# Source code for armi.bookkeeping.db.tests.test_database3

# Copyright 2019 TerraPower, LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the Database3 class."""
from distutils.spawn import find_executable
import subprocess
import unittest

import h5py
import numpy

from armi.bookkeeping.db import _getH5File
from armi.bookkeeping.db import database3
from armi.bookkeeping.db.databaseInterface import DatabaseInterface
from armi.reactor import parameters
from armi.reactor.tests.test_reactors import loadTestReactor, reduceTestReactorRings
from armi.settings.fwSettings.globalSettings import CONF_SORT_REACTOR
from armi.tests import TEST_ROOT
from armi.utils import getPreviousTimeNode
from armi.utils.directoryChangers import TemporaryDirectoryChanger

# Determine whether a git executable is available on the PATH; tests that need git
# are skipped when it is not. Plain "git" is preferred, "git.exe" is the fallback,
# and GIT_EXE stays None when neither is found.
# NOTE(review): distutils (including find_executable) was removed in Python 3.12;
# shutil.which is the standard-library replacement.
GIT_EXE = next(
    (exe for exe in ("git", "git.exe") if find_executable(exe) is not None),
    None,
)


class TestDatabase3(unittest.TestCase):
    """Tests for the Database3 class."""

    def setUp(self):
        """Load the test reactor and open a fresh database for the running test.

        The reactor is reduced to 3 rings. The database file is named after the
        running test method (``<testMethodName>.h5``) and is created inside a
        TemporaryDirectoryChanger.
        """
        self.td = TemporaryDirectoryChanger()
        self.td.__enter__()
        # Register the directory exit as a cleanup immediately. Cleanups run even
        # when setUp fails part-way (tearDown does not), so the working directory is
        # always restored. Cleanups also run *after* tearDown, which preserves the
        # ordering: close db -> restore reactor state -> leave directory.
        self.addCleanup(self.td.__exit__, None, None, None)

        self.o, self.r = loadTestReactor(
            TEST_ROOT, customSettings={"reloadDBName": "reloadingDB.h5"}
        )
        reduceTestReactorRings(self.r, self.o.cs, maxNumRings=3)

        self.dbi = DatabaseInterface(self.r, self.o.cs)
        self.dbi.initDB(fName=self._testMethodName + ".h5")
        self.db: database3.Database3 = self.dbi.database
        self.stateRetainer = self.r.retainState().__enter__()

        # used to test location-based history; populated by makeShuffleHistory
        self.centralAssemSerialNums = []
        self.centralTopBlockSerialNums = []

    def tearDown(self):
        """Close the database and restore the reactor state retained in setUp.

        The temporary directory is exited afterwards by the cleanup registered in
        setUp.
        """
        self.db.close()
        self.stateRetainer.__exit__()
[docs] def test_writeToDB(self): """Test writing to the database. .. test:: Write a single time step of data to the database. :id: T_ARMI_DB_TIME :tests: R_ARMI_DB_TIME """ self.r.p.cycle = 0 self.r.p.timeNode = 0 self.r.p.cycleLength = 0 # Adding some nonsense in, to test NoDefault params self.r.p.availabilityFactor = parameters.NoDefault # validate that the H5 file gets bigger after the write self.assertEqual(list(self.db.h5db.keys()), ["inputs"]) self.db.writeToDB(self.r) self.assertEqual(sorted(self.db.h5db.keys()), ["c00n00", "inputs"]) # check the keys for a single time step keys = [ "Circle", "Core", "DerivedShape", "Helix", "HexAssembly", "HexBlock", "Hexagon", "Reactor", "SpentFuelPool", "layout", ] self.assertEqual(sorted(self.db.h5db["c00n00"].keys()), sorted(keys)) # validate availabilityFactor did not make it into the H5 file rKeys = [ "maxAssemNum", "cycle", "cycleLength", "flags", "serialNum", "timeNode", ] self.assertEqual( sorted(self.db.h5db["c00n00"]["Reactor"].keys()), sorted(rKeys) )
[docs] def test_getH5File(self): """ Get the h5 file for the database, because that file format is language-agnostic. .. test:: Show the database is H5-formatted. :id: T_ARMI_DB_H5 :tests: R_ARMI_DB_H5 """ with self.assertRaises(TypeError): _getH5File(None) h5 = _getH5File(self.db) self.assertEqual(type(h5), h5py.File)
[docs] def makeHistory(self): """Walk the reactor through a few time steps and write them to the db.""" for cycle, node in ((cycle, node) for cycle in range(2) for node in range(2)): self.r.p.cycle = cycle self.r.p.timeNode = node # something that splitDatabase won't change, so that we can make sure that # the right data went to the right new groups/cycles self.r.p.cycleLength = cycle self.db.writeToDB(self.r)
[docs] def makeShuffleHistory(self): """Walk the reactor through a few time steps with some shuffling.""" # Serial numbers *are not stable* (i.e., they can be different between test runs # due to parallelism and test run order). However, they are the simplest way to # check correctness of location-based history tracking. So we stash the serial # numbers at the location of interest so we can use them later to check our work. self.centralAssemSerialNums = [] self.centralTopBlockSerialNums = [] grid = self.r.core.spatialGrid for cycle in range(2): a1 = self.r.core.childrenByLocator[grid[cycle, 0, 0]] a2 = self.r.core.childrenByLocator[grid[0, 0, 0]] olda1Loc = a1.spatialLocator a1.moveTo(a2.spatialLocator) a2.moveTo(olda1Loc) c = self.r.core.childrenByLocator[grid[0, 0, 0]] self.centralAssemSerialNums.append(c.p.serialNum) self.centralTopBlockSerialNums.append(c[-1].p.serialNum) for node in range(2): self.r.p.cycle = cycle self.r.p.timeNode = node # something that splitDatabase won't change, so that we can make sure # that the right data went to the right new groups/cycles self.r.p.cycleLength = cycle self.db.writeToDB(self.r) # add some more data that isnt written to the database to test the # DatabaseInterface API self.r.p.cycle = 2 self.r.p.timeNode = 0 self.r.p.cycleLength = cycle self.r.core[0].p.chargeTime = 2 # add some fake missing parameter data to test allowMissing self.db.h5db["c00n00/Reactor/missingParam"] = "i don't exist"
def _compareArrays(self, ref, src): """ Compare two numpy arrays. Comparing numpy arrays that may have unsavory data (NaNs, Nones, jagged data, etc.) is really difficult. For now, convert to a list and compare element-by-element. """ self.assertEqual(type(ref), type(src)) if isinstance(ref, numpy.ndarray): ref = ref.tolist() src = src.tolist() for v1, v2 in zip(ref, src): # Entries may be None if isinstance(v1, numpy.ndarray): v1 = v1.tolist() if isinstance(v2, numpy.ndarray): v2 = v2.tolist() self.assertEqual(v1, v2) def _compareRoundTrip(self, data): """Make sure that data is unchanged by packing/unpacking.""" packed, attrs = database3.packSpecialData(data, "testing") roundTrip = database3.unpackSpecialData(packed, attrs, "testing") self._compareArrays(data, roundTrip)
[docs] def test_prepRestartRun(self): """ This test is based on the armiRun.yaml case that is loaded during the `setUp` above. In that cs, `reloadDBName` is set to 'reloadingDB.h5', `startCycle` = 1, and `startNode` = 2. The nonexistent 'reloadingDB.h5' must first be created here for this test. .. test:: Runs can be restarted from a snapshot. :id: T_ARMI_SNAPSHOT_RESTART :tests: R_ARMI_SNAPSHOT_RESTART """ # first successfully call to prepRestartRun o, r = loadTestReactor( TEST_ROOT, customSettings={"reloadDBName": "reloadingDB.h5"} ) cs = o.cs reduceTestReactorRings(r, cs, maxNumRings=3) ratedPower = cs["power"] startCycle = cs["startCycle"] startNode = cs["startNode"] cyclesSetting = [ {"step days": [1000, 1000], "power fractions": [1, 1]}, {"step days": [1000, 1000], "power fractions": [1, 1]}, {"step days": [1000, 1000], "power fractions": [1, 1]}, ] cycleP, nodeP = getPreviousTimeNode(startCycle, startNode, cs) cyclesSetting[cycleP]["power fractions"][nodeP] = 0.5 numCycles = 2 numNodes = 2 cs = cs.modified( newSettings={ "nCycles": numCycles, "cycles": cyclesSetting, "reloadDBName": "something_fake.h5", } ) # create a db based on the cs dbi = DatabaseInterface(r, cs) dbi.initDB(fName="reloadingDB.h5") db = dbi.database # populate the db with some things for cycle, node in ( (cycle, node) for cycle in range(numCycles) for node in range(numNodes) ): r.p.cycle = cycle r.p.timeNode = node r.p.cycleLength = sum(cyclesSetting[cycle]["step days"]) r.core.p.power = ratedPower * cyclesSetting[cycle]["power fractions"][node] db.writeToDB(r) db.close() self.dbi.prepRestartRun() # prove that the reloaded reactor has the correct power self.assertEqual(self.o.r.p.cycle, cycleP) self.assertEqual(self.o.r.p.timeNode, nodeP) self.assertEqual(cyclesSetting[cycleP]["power fractions"][nodeP], 0.5) self.assertEqual( self.o.r.core.p.power, ratedPower * cyclesSetting[cycleP]["power fractions"][nodeP], ) # now make the cycle histories clash and confirm that an error is thrown cs = 
cs.modified( newSettings={ "cycles": [ {"step days": [666, 666], "power fractions": [1, 1]}, {"step days": [666, 666], "power fractions": [1, 1]}, {"step days": [666, 666], "power fractions": [1, 1]}, ], } ) # create a db based on the cs dbi = DatabaseInterface(r, cs) dbi.initDB(fName="reloadingDB.h5") db = dbi.database # populate the db with something for cycle, node in ( (cycle, node) for cycle in range(numCycles) for node in range(numNodes) ): r.p.cycle = cycle r.p.timeNode = node r.p.cycleLength = 2000 db.writeToDB(r) db.close() with self.assertRaises(ValueError): self.dbi.prepRestartRun()
[docs] def test_computeParents(self): # The below arrays represent a tree structure like this: # 71 -----------------------. # | \ # 12--.-----.------. 72 # / | \ \ \ # 22 30 4---. 6 18-. # / | | | \ \ / | \ # 8 17 2 32 52 62 1 9 10 # # This should cover a handful of corner cases numChildren = [2, 5, 2, 0, 0, 1, 0, 3, 0, 0, 0, 0, 3, 0, 0, 0, 0] serialNums = [71, 12, 22, 8, 17, 30, 2, 4, 32, 53, 62, 6, 18, 1, 9, 10, 72] expected_1 = [None, 71, 12, 22, 22, 12, 30, 12, 4, 4, 4, 12, 12, 18, 18, 18, 71] expected_2 = [ None, None, 71, 12, 12, 71, 12, 71, 12, 12, 12, 71, 71, 12, 12, 12, None, ] expected_3 = [ None, None, None, 71, 71, None, 71, None, 71, 71, 71, None, None, 71, 71, 71, None, ] self.assertEqual( database3.Layout.computeAncestors(serialNums, numChildren), expected_1 ) self.assertEqual( database3.Layout.computeAncestors(serialNums, numChildren, 2), expected_2 ) self.assertEqual( database3.Layout.computeAncestors(serialNums, numChildren, 3), expected_3 )
[docs] def test_load(self): """Load a reactor at different time steps, from the database. .. test:: Load the reactor from the database. :id: T_ARMI_DB_R_LOAD :tests: R_ARMI_DB_R_LOAD """ self.makeShuffleHistory() with self.assertRaises(KeyError): _r = self.db.load(0, 0) # default load, should pass without error _r = self.db.load(0, 0, allowMissing=True) # show that we can use negative indices to load r = self.db.load(0, -2, allowMissing=True) self.assertEqual(r.p.timeNode, 1) with self.assertRaises(ValueError): # makeShuffleHistory only populates 2 nodes, but the case settings # defines 3, so we must check -4 before getting an error self.db.load(0, -4, allowMissing=True) del self.db.h5db["c00n00/Reactor/missingParam"] _r = self.db.load(0, 0, allowMissing=False) # we shouldn't be able to set the fileName if a file is open with self.assertRaises(RuntimeError): self.db.fileName = "whatever.h5"
[docs] def test_loadSortSetting(self): self.makeShuffleHistory() # default load, should pass without error r0 = self.db.load(0, 0, allowMissing=True) # test that the reactor loads differently, dependent on the setting cs = self.db.loadCS() cs = cs.modified(newSettings={CONF_SORT_REACTOR: False}) r1 = self.db.load(0, 0, cs=cs, allowMissing=True) # the reactor / core should be the same size self.assertEqual(len(r0), len(r1)) self.assertEqual(len(r0.core), len(r1.core))
[docs] def test_history(self): self.makeShuffleHistory() grid = self.r.core.spatialGrid testAssem = self.r.core.childrenByLocator[grid[0, 0, 0]] testBlock = testAssem[-1] # Test assem hist = self.db.getHistoryByLocation( testAssem, params=["chargeTime", "serialNum"] ) expectedSn = { (c, n): self.centralAssemSerialNums[c] for c in range(2) for n in range(2) } self.assertEqual(expectedSn, hist["serialNum"]) # test block hists = self.db.getHistoriesByLocation( [testBlock], params=["serialNum"], timeSteps=[(0, 0), (1, 0)] ) expectedSn = {(c, 0): self.centralTopBlockSerialNums[c] for c in range(2)} self.assertEqual(expectedSn, hists[testBlock]["serialNum"]) # cant mix blocks and assems, since they are different distance from core with self.assertRaises(ValueError): self.db.getHistoriesByLocation([testAssem, testBlock], params=["serialNum"]) # if requested time step isnt written, return no content hist = self.dbi.getHistory( self.r.core[0], params=["chargeTime", "serialNum"], byLocation=True ) self.assertIn((2, 0), hist["chargeTime"].keys()) self.assertEqual(hist["chargeTime"][(2, 0)], 2)
[docs] def test_auxData(self): path = self.db.getAuxiliaryDataPath((2, 0), "test_stuff") self.assertEqual(path, "c02n00/test_stuff") with self.assertRaises(KeyError): self.db.genAuxiliaryData((-1, -1))
# TODO: This should be expanded.
[docs] def test_replaceNones(self): """Super basic test that we handle Nones correctly in database read/writes.""" data3 = numpy.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) data1 = numpy.array([1, 2, 3, 4, 5, 6, 7, 8]) data1iNones = numpy.array([1, 2, None, 5, 6]) data1fNones = numpy.array([None, 2.0, None, 5.0, 6.0]) data2fNones = numpy.array( [None, [[1.0, 2.0, 6.0], [2.0, 3.0, 4.0]]], dtype=object ) dataJag = numpy.array( [[[1, 2], [3, 4]], [[1, 2, 3], [4, 5, 6], [7, 8, 9]]], dtype=object ) dataJagNones = numpy.array( [[[1, 2], [3, 4]], [[1], [1]], [[1, 2, 3], [4, 5, 6], [7, 8, 9]]], dtype=object, ) dataDict = numpy.array( [{"bar": 2, "baz": 3}, {"foo": 4, "baz": 6}, {"foo": 7, "bar": 8}] ) self._compareRoundTrip(data3) self._compareRoundTrip(data1) self._compareRoundTrip(data1iNones) self._compareRoundTrip(data1fNones) self._compareRoundTrip(data2fNones) self._compareRoundTrip(dataJag) self._compareRoundTrip(dataJagNones) self._compareRoundTrip(dataDict)
[docs] def test_mergeHistory(self): self.makeHistory() # put some big data in an HDF5 attribute. This will exercise the code that pulls # such attributes into a formal dataset and a reference. self.r.p.cycle = 1 self.r.p.timeNode = 0 tnGroup = self.db.getH5Group(self.r) database3.Database3._writeAttrs( tnGroup["layout/serialNum"], tnGroup, { "fakeBigData": numpy.eye(6400), "someString": "this isn't a reference to another dataset", }, ) db_path = "restartDB.h5" db2 = database3.Database3(db_path, "w") with db2: db2.mergeHistory(self.db, 2, 2) self.r.p.cycle = 1 self.r.p.timeNode = 0 tnGroup = db2.getH5Group(self.r) # this test is a little bit implementation-specific, but nice to be explicit self.assertEqual( tnGroup["layout/serialNum"].attrs["fakeBigData"], "@/c01n00/attrs/0_fakeBigData", ) # exercise the _resolveAttrs function attrs = database3.Database3._resolveAttrs( tnGroup["layout/serialNum"].attrs, tnGroup ) self.assertTrue(numpy.array_equal(attrs["fakeBigData"], numpy.eye(6400))) keys = sorted(db2.keys()) self.assertEqual(len(keys), 4) self.assertEqual(keys[:3], ["/c00n00", "/c00n01", "/c01n00"])
[docs] def test_splitDatabase(self): self.makeHistory() self.db.splitDatabase( [(c, n) for c in (0, 1) for n in range(2)], "-all-iterations" ) # Closing to copy back from fast path self.db.close() with h5py.File("test_splitDatabase.h5", "r") as newDb: self.assertEqual(newDb["c00n00/Reactor/cycle"][()], 0) self.assertEqual(newDb["c00n00/Reactor/cycleLength"][()][0], 0) self.assertNotIn("c03n00", newDb) self.assertEqual(newDb.attrs["databaseVersion"], database3.DB_VERSION) # validate that the min set of meta data keys exists meta_data_keys = [ "appName", "armiLocation", "databaseVersion", "hostname", "localCommitHash", "machines", "platform", "platformArch", "platformRelease", "platformVersion", "pluginPaths", "python", "startTime", "successfulCompletion", "user", "version", ] for meta_key in meta_data_keys: self.assertIn(meta_key, newDb.attrs) self.assertIsNotNone(newDb.attrs[meta_key]) # test an edge case - no DB to split with self.assertRaises(ValueError): self.db.h5db = None self.db.splitDatabase( [(c, n) for c in (0, 1) for n in range(2)], "-all-iterations" )
[docs] @unittest.skipIf(GIT_EXE is None, "This test needs Git.") def test_grabLocalCommitHash(self): """Test of static method to grab a local commit hash with ARMI version.""" # 1. test outside a Git repo localHash = database3.Database3.grabLocalCommitHash() self.assertEqual(localHash, "unknown") # 2. test inside an empty git repo try: code = subprocess.run( ["git", "init", "."], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, ).returncode except FileNotFoundError: print("Skipping this test because it is being run outside a git repo.") return self.assertEqual(code, 0) localHash = database3.Database3.grabLocalCommitHash() self.assertEqual(localHash, "unknown") # 3. test inside a git repo with one tag # commit the empty repo code = subprocess.run( ["git", "commit", "--allow-empty", "-m", '"init"', "--author", '"sam <>"'], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, ).returncode if code == 128: # GitHub Actions blocks certain kinds of Git commands return # create a tag off our new commit code = subprocess.run( ["git", "tag", "thanks", "-m", '"you_rock"'], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, ).returncode self.assertEqual(code, 0) # test that we recover the correct commit hash localHash = database3.Database3.grabLocalCommitHash() self.assertEqual(localHash, "thanks") # delete the .git directory code = subprocess.run( ["git", "clean", "-f"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL ).returncode self.assertEqual(code, 0) code = subprocess.run( ["git", "clean", "-f", "-d"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, ).returncode self.assertEqual(code, 0)
[docs] def test_fileName(self): # test the file name getter self.assertEqual(str(self.db.fileName), "test_fileName.h5") # test the file name setter self.db.close() self.db.fileName = "thing.h5" self.assertEqual(str(self.db.fileName), "thing.h5")
[docs] def test_readInputsFromDB(self): """Test that we can read inputs from the database. .. test:: Save and retrieve settings from the database. :id: T_ARMI_DB_CS :tests: R_ARMI_DB_CS .. test:: Save and retrieve blueprints from the database. :id: T_ARMI_DB_BP :tests: R_ARMI_DB_BP """ inputs = self.db.readInputsFromDB() self.assertEqual(len(inputs), 3) # settings self.assertGreater(len(inputs[0]), 100) self.assertIn("settings:", inputs[0]) self.assertEqual(len(inputs[1]), 0) # blueprints self.assertGreater(len(inputs[2]), 100) self.assertIn("custom isotopics:", inputs[2]) self.assertIn("blocks:", inputs[2])
[docs] def test_deleting(self): self.assertEqual(type(self.db), database3.Database3) del self.db self.assertFalse(hasattr(self, "db")) self.db = self.dbi.database
[docs] def test_open(self): with self.assertRaises(ValueError): self.db.open()
[docs] def test_loadCS(self): cs = self.db.loadCS() self.assertEqual(cs["numProcessors"], 1) self.assertEqual(cs["nCycles"], 6)
[docs] def test_loadBlueprints(self): bp = self.db.loadBlueprints() self.assertIsNone(bp.nuclideFlags) self.assertEqual(len(bp.assemblies), 0)