# Copyright 2019 TerraPower, LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import pickle
from typing import Any, Optional, List, Set
import sys
import numpy
import six
from armi import runLog
from armi.reactor.parameters import parameterDefinitions, exceptions
from armi.reactor.parameters.parameterDefinitions import (
SINCE_LAST_DISTRIBUTE_STATE,
SINCE_BACKUP,
SINCE_ANYTHING,
NEVER,
)
from armi.utils import units
GLOBAL_SERIAL_NUM = -1
"""
The serial number for all ParameterCollections
This is a counter of the number of instances of all types. They are useful for tracking
items through the history of a database.
.. warning::
This is not MPI safe. We also have not done anything to make it thread safe,
except that the GIL exists.
"""
def _getBaseParameterDefinitions():
pDefs = parameterDefinitions.ParameterDefinitionCollection()
pDefs.add(
parameterDefinitions.Parameter(
"serialNum",
units=units.UNITLESS,
description=(
"Unique serial integer for all objects in the ARMI Composite Tree. "
"The numbers are only unique for a simulation, on an MPI rank."
),
location=None,
saveToDB=True,
default=parameterDefinitions.NoDefault,
setter=parameterDefinitions.NoDefault,
categories=set(),
)
)
return pDefs
class _ParameterCollectionType(type):
"""
Simple metaclass to make sure that expected class attributes are present.
These attributes shouldn't be shared among different subclasses, so this
makes sure that each subclass gets its own.
"""
def __new__(mcl, name, bases, attrs):
attrs["pDefs"] = attrs.get("pDefs") or None
attrs["_ArmiObject"] = None
attrs["_allFields"] = []
return type.__new__(mcl, name, bases, attrs)
[docs]class ParameterCollection(metaclass=_ParameterCollectionType):
r"""An empty class for holding state information in the ARMI data structure.
A parameter collection stores one or more formally-defined values ("parameters").
Until a given ParameterCollection subclass has been instantiated, new parameters may
be added to its parameter definitions (e.g., from plugins). Upon first
instantiation, ``applyParameters()`` will be called, binding the parameter
definitions to the Collection class as descriptors.
It is illegal to redefine a parameter with the same name in the same class, or its
subclasses, and attempting to do so should result in exceptions in
``applyParameters()``.
Attributes
----------
_backup : str
A pickle dump of the __getstate__, or None.
_hist : dict
Keys are ``(paramName, timeStep)``.
assigned : int Flag
indicates the synchronization state of the parameter collection. This is used to
reduce the amount of information that is transmitted during database, and MPI
operations as well as determine the collection's state when exiting a
``Composite.retainState``.
This attribute when used with the ``Parameter.assigned`` attribute allows us to
efficiently perform many operations.
See Also
--------
armi.reactors.parameters
"""
pDefs: parameterDefinitions.ParameterDefinitionCollection = (
_getBaseParameterDefinitions()
)
_allFields: List[str] = []
# The ArmiObject class that this ParameterCollection belongs to
_ArmiObject = None
# A set of all instance attributes that are settable on an instance. This prevents
# inadvertent setting of values that aren't proper parameters. Named _slots, as
# it is used to emulate some of the behaviors of __slots__.
_slots: Set[str] = set()
def __init__(self, _state: Optional[List[Any]] = None):
"""
Create a new ParameterCollection instance.
Parameters
----------
_state:
Optional list of parameter values, ordered by _allFields. Passed values
should come from a call to __getstate__(). This should only be used
internally to this model.
"""
if self.pDefs is None or not self.pDefs.locked:
type(self).applyParameters()
assert self.pDefs.locked, (
"It looks like parameter definitions haven't been "
"set up yet for {}; be sure that applyAllParameters() is being called "
"somewhere.".format(type(self))
)
self._backup = None
# used by the history tracker when a parameter key is a tuple (name, timestep)
self._hist = {}
# Initialize all parameter values to **something**. This is crucial to getting
# the split-key dictionary memory savings in lieu of using __slots__!
if _state is None:
for pDef in self.paramDefs:
setattr(self, pDef.fieldName, pDef.default)
else:
for key, val in zip(self._allFields, _state):
self.__dict__[key] = val
self.assigned = NEVER
global GLOBAL_SERIAL_NUM
self.serialNum = GLOBAL_SERIAL_NUM = GLOBAL_SERIAL_NUM + 1
if self.serialNum > sys.maxsize:
runLog.warning(
"Created serial number larger than an integer. Current serial: {}".format(
GLOBAL_SERIAL_NUM
)
)
[docs] @classmethod
def applyParameters(cls):
"""
Apply the definitions from a ParameterDefinitionCollection as properties.
This places the parameter definitions in the associated
ParameterDefinitionCollection onto this ParameterCollection class as class
attributes. In the process it recursively calls the same method on base classes,
and adds their parameter definitions as well. Since each instance of Parameter
implements the descriptor protocol, these are effectively behaving as
``@property``-style accessors.
This function must act on each ParameterCollection subclass before the first
instance is created. Subsequent calls will short-circuit. Before calling this
method, it is possible to add more Parameters to the associated
ParameterDefinitionCollection, ``cls.pDefs``. After calling this method, the
ParameterDefinitionCollection will be locked, preventing any further additions.
This method is called in the ``__init__()`` method, but can also be called
proactively to compile the parameter definitions earlier, if desired.
See Also
--------
armi.reactor.parameters.parameterDefinitions.ParameterDefinitionCollection
"""
if bool(cls._allFields):
# Short-circuit if this has already been done
return
# Ensure that we have at least something to start with
cls.pDefs = cls.pDefs or parameterDefinitions.ParameterDefinitionCollection()
# Collect definitions from base ParameterCollection classes. E.g.,
# HelixParameterCollection also gets parameter definitions from
# ComponentParameterCollection.
if not cls.pDefs.locked:
basePDefs = parameterDefinitions.ParameterDefinitionCollection()
for base in [
b for b in cls.__bases__ if issubclass(b, ParameterCollection)
]:
base.applyParameters()
if base.pDefs is not None:
basePDefs.extend(base.pDefs)
# Check for duplicate parameter definitions
seen = set()
duplicates = set()
for name in cls.pDefs.names:
if name in seen:
duplicates.add(name)
seen.add(name)
if duplicates:
raise exceptions.ParameterDefinitionError(
"The following parameters were multiply-defined:\n {}".format(
duplicates
)
)
overriddenParameters = set(cls.pDefs.names).intersection(
set(basePDefs.names)
)
if overriddenParameters:
raise exceptions.ParameterDefinitionError(
"The following parameters "
"have been redefined in a subclass: {}\n"
"current type: {}\n"
"bases: {}".format(overriddenParameters, cls, cls.__bases__)
)
# Bind the parameter definitions as descriptors to the collection
for pd in cls.pDefs:
pd.collectionType = cls
setattr(cls, pd.name, pd)
parameterDefinitions.ALL_DEFINITIONS.add(pd)
cls.pDefs.extend(basePDefs)
# prevent the addition of new parameter definitions. This will lead to errors
# early, rather than mysterious attribute access errors later.
cls.pDefs.lock()
cls._allFields = list(
sorted(
["_backup", "_hist", "assigned"] + [pd.fieldName for pd in cls.pDefs]
)
)
cls._slots = set(cls._allFields).union({pd.name for pd in cls.pDefs})
def __repr__(self):
return "<{} assigned:{}>".format(self.__class__.__name__, self.assigned)
def __setattr__(self, key, value):
assert key in self._slots, (
"Trying to set undefined attribute `{}` on "
"a ParameterCollection!".format(key)
)
object.__setattr__(self, key, value)
def __deepcopy__(self, memo):
"""
Returns a new instance of ParameterCollection with a new ``serialNum``.
Notes
-----
This operates under the assumption that ``__deepcopy__`` is used when needing a
new instance, which should get its own serial number. This follows from the
assumption that parameter collections are typically copied when copying an
ArmiObject to which it may belong. In this case, serialNum needs to be
incremented so that the objects are unique. serialNum is special.
"""
# Grabbing state first and passing it into __init__() as a performance
# optimization. This avoids the extra work in __init__() of defaulting all of
# the parameters, only to set them in __setstate__(). Instead we pass them in,
# so that __init__() can set them.
state = copy.deepcopy(self.__getstate__(), memo)
memo[id(self)] = newPC = self.__class__(_state=state)
return newPC
def __reduce__(self):
"""
Implement pickle __reduce__ protocol.
We need to do this because most subclasses of ParameterCollection are created
from a metaclass, and are therefore not top-level objects and not trivially
picklable. This implementation works by asking the ArmiObject itself to give an
instance of its associated ParameterCollection class, then setting its state.
"""
assert type(self)._ArmiObject is not None, (
"Cannot reduce {}, since it does not have an associated ArmiObject, and is "
"therefore not tied to the world of the living.".format(type(self))
)
return type(self)._ArmiObject.getParameterCollection, (), self.__getstate__()
def __getstate__(self):
# reduce data to one giant list, ordered by _allFields (sorted). Use NoDefault
# when a value is missing
data = [
getattr(self, fieldName, parameterDefinitions.NoDefault)
for fieldName in self._allFields
]
return data
def __setstate__(self, state):
# does the reverse of __getstate__
for key, val in zip(self._allFields, state):
setattr(self, key, val)
def __getitem__(self, name):
try:
return getattr(self, name)
except TypeError: # allows for history parameter tuples
return self._hist[name]
except AttributeError:
raise exceptions.UnknownParameterError(
"Parameter {} is not defined for {}".format(name, type(self))
)
def __setitem__(self, name, value):
try:
setattr(self, name, value)
except TypeError: # allows for history parameter tuples
if isinstance(name, tuple):
self._hist[name] = value
else:
raise
except AttributeError: # for clarity
raise exceptions.UnknownParameterError(
"Cannot locate definition for parameter {} in {}".format(
name, type(self)
)
)
def __delitem__(self, name):
if isinstance(name, six.string_types):
pd = self.paramDefs[name]
if hasattr(self, pd.fieldName):
pd.assigned = SINCE_ANYTHING
delattr(self, pd.fieldName)
else:
del self._hist[name]
def __contains__(self, name):
if isinstance(name, six.string_types):
return hasattr(self, "_p_" + name)
else:
return name in self._hist
def __eq__(self, other):
if not isinstance(other, self.__class__):
return False
for pd in self.paramDefs:
fieldName = pd.fieldName
haveValue = (hasattr(self, fieldName), hasattr(other, fieldName))
if all(haveValue):
if getattr(self, fieldName) != getattr(self, fieldName):
return False
elif any(haveValue):
return False
return True
def __iter__(self):
return (
pd.name
for pd in self.paramDefs
if pd.assigned != NEVER
and getattr(self, pd.fieldName) is not parameterDefinitions.NoDefault
)
[docs] def items(self):
keys = list(iter(self))
return zip(keys, (getattr(self, key) for key in keys))
[docs] def get(self, key, default=None):
"""Return a requested parameter value, if possible.
This functions similarly to the same method on a dict or similar. If there is a
value present for the requested parameter on this parameter collection, return
it. Otherwise, return the supplied default. The main reason for using this is
for safely attempting to access a parameter that doesn't have a default value,
and may not have been set. Other methods for accessing parameters would raise
an exception.
"""
try:
return self[key]
except exceptions.ParameterError:
return default
[docs] def keys(self):
return list(iter(self)) + list(self._hist.keys())
[docs] def values(self):
paramVals = list(
getattr(self, pd.fieldName)
for pd in self.paramDefs
if hasattr(self, pd.fieldName)
)
return paramVals + list(self._hist.values())
[docs] def update(self, someDict):
for k, val in someDict.items():
self[k] = val
@property
def paramDefs(self) -> parameterDefinitions.ParameterDefinitionCollection:
r"""
Get the :py:class:`ParameterDefinitionCollection` associated with this instance.
This serves as both an alias for the pDefs class attribute, and as a read-only
accessor for them. Most non-paramter-system related interactions with an
object's ``ParameterCollection`` should go through this. In the future, it
probably makes sense to make the ``pDefs`` that the ``applyDefinitions`` and
``ResolveParametersMeta`` things are sensitive to more hidden from outside the
parameter system.
"""
return type(self).pDefs
[docs] def getSyncData(self):
"""
Get all changed parameters SINCE_LAST_DISTRIBUTE_STATE (or ``syncMpiState``).
If this ParmaterCollection (proxy for a ``Composite``) has been modified
``SINCE_LAST_DISTRIBUTE_STATE``, this will return a dictionary of parameter name
keys and values, otherwise ``None``.
"""
if self.assigned & SINCE_LAST_DISTRIBUTE_STATE:
syncData = {
paramDef.name: getattr(self, paramDef.fieldName)
for paramDef in self.paramDefs
if paramDef.assigned & SINCE_LAST_DISTRIBUTE_STATE
and paramDef.name in self
}
return syncData
return None
[docs] def backUp(self):
"""Back up the state in a Pickle."""
try:
self._backup = pickle.dumps(self.__getstate__())
# this reads as assigned & everything_but(SINCE_BACKUP)
self.assigned &= ~SINCE_BACKUP
except:
runLog.error("Attempted to pickle {}.".format(self))
raise
[docs] def restoreBackup(self, paramsToApply):
"""Restore the backed up the state in a from a pickle.
Parameters
----------
paramsToApply : list of ParmeterDefinitions
restores the state of all parameters not in `paramsToApply`
"""
currentData = dict()
if self.assigned & SINCE_BACKUP:
compParams = (pd for pd in paramsToApply.intersection(set(self.paramDefs)))
currentData = {
pd: getattr(self, pd.fieldName)
for pd in compParams
if hasattr(self, pd.fieldName)
}
self.__setstate__(pickle.loads(self._backup))
for pd, currentValue in currentData.items():
# correct for global paramDef.assigned assumption
retainedValue = getattr(self, pd.fieldName)
if isinstance(retainedValue, numpy.ndarray) or isinstance(
currentValue, numpy.ndarray
):
if (retainedValue != currentValue).any():
setattr(self, pd.fieldName, currentValue)
pd.assigned = SINCE_ANYTHING
self.assigned = SINCE_ANYTHING
elif retainedValue != currentValue:
setattr(self, pd.fieldName, currentValue)
pd.assigned = SINCE_ANYTHING
self.assigned = SINCE_ANYTHING
[docs]def collectPluginParameters(pm):
"""Apply parameters from plugins to their respective object classes."""
for pluginParamDefnCollections in pm.hook.defineParameters():
for klass, pDefs in pluginParamDefnCollections.items():
klass.pDefs.extend(pDefs)
[docs]def applyAllParameters(klass=None):
klass = klass or ParameterCollection
klass.applyParameters()
for derived in klass.__subclasses__():
applyAllParameters(derived)