# Copyright 2019 TerraPower, LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing some utility functions."""
from collections import defaultdict
import os
import unittest
import numpy as np
from armi import utils
from armi.reactor.tests.test_reactors import loadTestReactor
from armi.settings.caseSettings import Settings
from armi.tests import mockRunLogs
from armi.utils import (
directoryChangers,
getPowerFractions,
getCycleNames,
getAvailabilityFactors,
getStepLengths,
getCycleLengths,
getBurnSteps,
getMaxBurnSteps,
getNodesPerCycle,
getCycleNodeFromCumulativeStep,
getCycleNodeFromCumulativeNode,
getPreviousTimeNode,
getCumulativeNodeNum,
hasBurnup,
codeTiming,
safeCopy,
)
[docs]class TestGeneralUtils(unittest.TestCase):
[docs] def test_mergeableDictionary(self):
mergeableDict = utils.MergeableDict()
normalDict = {"luna": "thehusky", "isbegging": "fortreats", "right": "now"}
mergeableDict.merge(
{"luna": "thehusky"}, {"isbegging": "fortreats"}, {"right": "now"}
)
self.assertEqual(mergeableDict, normalDict)
[docs] def test_capStrLen(self):
# Test with strings
str1 = utils.capStrLen("sodium", 5)
self.assertEqual("so...", str1)
str1 = utils.capStrLen("potassium", 6)
self.assertEqual("pot...", str1)
str1 = utils.capStrLen("rubidium", 7)
self.assertEqual("rubi...", str1)
with self.assertRaises(Exception):
str1 = utils.capStrLen("sodium", 2)
[docs] def test_list2str(self):
# Test with list of strings
list1 = ["One", "Two"]
list2 = ["Three", "Four"]
str1 = "OneTwo"
str2 = utils.list2str(list1, 4, None, None)
self.assertEqual(str1, str2)
str1 = "One Two "
str2 = utils.list2str(list1, None, None, 5)
self.assertEqual(str1, str2)
str1 = "OneTwoThreeFour"
str2 = utils.list2str(list2, None, list1, None)
self.assertEqual(str1, str2)
str1 = "OneTwoThreeFourT...Four"
str2 = utils.list2str(list2, 4, list1, None)
self.assertEqual(str1, str2)
str1 = "OneTwoThreeFourT...FourThreeFour "
str2 = utils.list2str(list2, None, list1, 5)
self.assertEqual(str1, str2)
str1 = "OneTwoThreeFourT...FourThreeFour T... Four "
str2 = utils.list2str(list2, 4, list1, 5)
self.assertEqual(str1, str2)
[docs] def test_slantSplit(self):
x1 = utils.slantSplit(10.0, 4.0, 4)
x2 = utils.slantSplit(10.0, 4.0, 4, order="high first")
self.assertListEqual(x1, [1.0, 2.0, 3.0, 4.0])
self.assertListEqual(x2, [4.0, 3.0, 2.0, 1.0])
[docs] def test_prependToList(self):
a = ["hello", "world"]
b = [1, 2, 3]
utils.prependToList(a, b)
self.assertListEqual(a, [1, 2, 3, "hello", "world"])
[docs] def test_plotMatrix(self):
matrix = np.zeros([2, 2], dtype=float)
matrix[0, 0] = 1
matrix[0, 1] = 2
matrix[1, 0] = 3
matrix[1, 1] = 4
xtick = ([0, 1], ["1", "2"])
ytick = ([0, 1], ["1", "2"])
fname = "test_plotMatrix_testfile"
with directoryChangers.TemporaryDirectoryChanger():
utils.plotMatrix(matrix, fname, show=True, title="plot")
utils.plotMatrix(matrix, fname, minV=0, maxV=5, figsize=[3, 4])
utils.plotMatrix(matrix, fname, xticks=xtick, yticks=ytick)
[docs] def test_classesInHierarchy(self):
"""Tests the classesInHierarchy utility."""
# load the test reactor
_o, r = loadTestReactor()
# call the `classesInHierarchy` function
classCounts = defaultdict(lambda: 0)
utils.classesInHierarchy(r, classCounts, None)
# validate the `classesInHierarchy` function
self.assertGreater(len(classCounts), 30)
self.assertEqual(classCounts[type(r)], 1)
self.assertEqual(classCounts[type(r.core)], 1)
# further validate the Reactor heirarchy is in place
self.assertGreater(len(r.core.getAssemblies()), 50)
self.assertGreater(len(r.core.getBlocks()), 200)
[docs] def test_codeTiming(self):
"""Test that codeTiming preserves function attributes when it wraps a function."""
@codeTiming.timed
def testFunc():
"""Test function docstring."""
pass
self.assertEqual(getattr(testFunc, "__doc__"), "Test function docstring.")
self.assertEqual(getattr(testFunc, "__name__"), "testFunc")
[docs] def test_safeCopy(self):
with directoryChangers.TemporaryDirectoryChanger():
os.mkdir("dir1")
os.mkdir("dir2")
file1 = "dir1/file1.txt"
with open(file1, "w") as f:
f.write("Hello")
file2 = "dir1\\file2.txt"
with open(file2, "w") as f:
f.write("Hello2")
with mockRunLogs.BufferLog() as mock:
# Test Linuxy file path
self.assertEqual("", mock.getStdout())
safeCopy(file1, "dir2")
self.assertIn("Copied", mock.getStdout())
self.assertIn("file1", mock.getStdout())
self.assertIn("->", mock.getStdout())
# Clean up for next safeCopy
mock.emptyStdout()
# Test Windowsy file path
self.assertEqual("", mock.getStdout())
safeCopy(file2, "dir2")
self.assertIn("Copied", mock.getStdout())
self.assertIn("file2", mock.getStdout())
self.assertIn("->", mock.getStdout())
[docs]class CyclesSettingsTests(unittest.TestCase):
"""
Check reading of the various cycle history settings for both the detailed
and simple input options.
"""
detailedCyclesSettings = """
metadata:
version: uncontrolled
settings:
power: 1000000000.0
nCycles: 3
cycles:
- name: dog
cumulative days: [1, 2, 3]
power fractions: [0.1, 0.2, 0.3]
availability factor: 0.1
- cycle length: 10
burn steps: 5
power fractions: [0.2, 0.2, 0.2, 0.2, 0]
availability factor: 0.5
- name: ferret
step days: [3, R4]
power fractions: [0.3, R4]
runType: Standard
"""
simpleCyclesSettings = """
metadata:
version: uncontrolled
settings:
power: 1000000000.0
nCycles: 3
availabilityFactors: [0.1, R2]
cycleLengths: [1, 2, 3]
powerFractions: [0.1, 0.2, R1]
burnSteps: 3
runType: Standard
"""
powerFractionsDetailedSolution = [
[0.1, 0.2, 0.3],
[0.2, 0.2, 0.2, 0.2, 0],
[0.3, 0.3, 0.3, 0.3, 0.3],
]
powerFractionsSimpleSolution = [[0.1, 0.1, 0.1], [0.2, 0.2, 0.2], [0.2, 0.2, 0.2]]
cycleNamesDetailedSolution = ["dog", None, "ferret"]
cycleNamesSimpleSolution = [None, None, None]
availabilityFactorsDetailedSolution = [0.1, 0.5, 1]
availabilityFactorsSimpleSolution = [0.1, 0.1, 0.1]
stepLengthsDetailedSolution = [
[1, 1, 1],
[10 / 5 * 0.5, 10 / 5 * 0.5, 10 / 5 * 0.5, 10 / 5 * 0.5, 10 / 5 * 0.5],
[3, 3, 3, 3, 3],
]
stepLengthsSimpleSolution = [
[1 * 0.1 / 3, 1 * 0.1 / 3, 1 * 0.1 / 3],
[2 * 0.1 / 3, 2 * 0.1 / 3, 2 * 0.1 / 3],
[3 * 0.1 / 3, 3 * 0.1 / 3, 3 * 0.1 / 3],
]
cycleLengthsDetailedSolution = [30, 10, 15]
cycleLengthsSimpleSolution = [1, 2, 3]
burnStepsDetailedSolution = [3, 5, 5]
burnStepsSimpleSolution = [3, 3, 3]
nodesPerCycleDetailedSolution = [4, 6, 6]
nodesPerCycleSimpleSolution = [4, 4, 4]
maxBurnStepsDetailedSolution = 5
maxBurnStepsSimpleSolution = 3
def setUp(self):
self.standaloneDetailedCS = Settings()
self.standaloneDetailedCS.loadFromString(self.detailedCyclesSettings)
self.standaloneSimpleCS = Settings()
self.standaloneSimpleCS.loadFromString(self.simpleCyclesSettings)
[docs] def test_getPowerFractions(self):
self.assertEqual(
getPowerFractions(self.standaloneDetailedCS),
self.powerFractionsDetailedSolution,
)
self.assertEqual(
getPowerFractions(self.standaloneSimpleCS),
self.powerFractionsSimpleSolution,
)
[docs] def test_getCycleNames(self):
self.assertEqual(
getCycleNames(self.standaloneDetailedCS), self.cycleNamesDetailedSolution
)
self.assertEqual(
getCycleNames(self.standaloneSimpleCS), self.cycleNamesSimpleSolution
)
[docs] def test_getAvailabilityFactors(self):
self.assertEqual(
getAvailabilityFactors(self.standaloneDetailedCS),
self.availabilityFactorsDetailedSolution,
)
self.assertEqual(
getAvailabilityFactors(self.standaloneSimpleCS),
self.availabilityFactorsSimpleSolution,
)
[docs] def test_getStepLengths(self):
self.assertEqual(
getStepLengths(self.standaloneDetailedCS),
self.stepLengthsDetailedSolution,
)
self.assertEqual(
getStepLengths(self.standaloneSimpleCS),
self.stepLengthsSimpleSolution,
)
[docs] def test_getCycleLengths(self):
self.assertEqual(
getCycleLengths(self.standaloneDetailedCS),
self.cycleLengthsDetailedSolution,
)
self.assertEqual(
getCycleLengths(self.standaloneSimpleCS), self.cycleLengthsSimpleSolution
)
[docs] def test_getBurnSteps(self):
self.assertEqual(
getBurnSteps(self.standaloneDetailedCS), self.burnStepsDetailedSolution
)
self.assertEqual(
getBurnSteps(self.standaloneSimpleCS), self.burnStepsSimpleSolution
)
[docs] def test_hasBurnup(self):
self.assertTrue(hasBurnup(self.standaloneDetailedCS))
[docs] def test_getMaxBurnSteps(self):
self.assertEqual(
getMaxBurnSteps(self.standaloneDetailedCS),
self.maxBurnStepsDetailedSolution,
)
self.assertEqual(
getMaxBurnSteps(self.standaloneSimpleCS), self.maxBurnStepsSimpleSolution
)
[docs] def test_getNodesPerCycle(self):
self.assertEqual(
getNodesPerCycle(self.standaloneDetailedCS),
self.nodesPerCycleDetailedSolution,
)
self.assertEqual(
getNodesPerCycle(self.standaloneSimpleCS), self.nodesPerCycleSimpleSolution
)
[docs] def test_getCycleNodeFromCumulativeStep(self):
self.assertEqual(
getCycleNodeFromCumulativeStep(8, self.standaloneDetailedCS), (1, 4)
)
self.assertEqual(
getCycleNodeFromCumulativeStep(12, self.standaloneDetailedCS), (2, 3)
)
self.assertEqual(
getCycleNodeFromCumulativeStep(4, self.standaloneSimpleCS), (1, 0)
)
self.assertEqual(
getCycleNodeFromCumulativeStep(8, self.standaloneSimpleCS), (2, 1)
)
[docs] def test_getCycleNodeFromCumulativeNode(self):
self.assertEqual(
getCycleNodeFromCumulativeNode(8, self.standaloneDetailedCS), (1, 4)
)
self.assertEqual(
getCycleNodeFromCumulativeNode(12, self.standaloneDetailedCS), (2, 2)
)
self.assertEqual(
getCycleNodeFromCumulativeNode(3, self.standaloneSimpleCS), (0, 3)
)
self.assertEqual(
getCycleNodeFromCumulativeNode(8, self.standaloneSimpleCS), (2, 0)
)
with self.assertRaises(ValueError):
getCycleNodeFromCumulativeNode(-1, self.standaloneSimpleCS)
[docs] def test_getPreviousTimeNode(self):
with self.assertRaises(ValueError):
getPreviousTimeNode(0, 0, "foo")
self.assertEqual(getPreviousTimeNode(1, 1, self.standaloneSimpleCS), (1, 0))
self.assertEqual(getPreviousTimeNode(1, 0, self.standaloneSimpleCS), (0, 3))
self.assertEqual(getPreviousTimeNode(1, 0, self.standaloneDetailedCS), (0, 3))
self.assertEqual(getPreviousTimeNode(2, 4, self.standaloneDetailedCS), (2, 3))
[docs] def test_getCumulativeNodeNum(self):
self.assertEqual(getCumulativeNodeNum(2, 0, self.standaloneSimpleCS), 8)
self.assertEqual(getCumulativeNodeNum(1, 2, self.standaloneSimpleCS), 6)
self.assertEqual(getCumulativeNodeNum(2, 0, self.standaloneDetailedCS), 10)
self.assertEqual(getCumulativeNodeNum(1, 0, self.standaloneDetailedCS), 4)