""" Control code used to terminate simulation based on
a set of user defined criteria
"""
import logging
import os
import signal
from mewarpx.diags_store.checkpoint_diagnostic import CheckPointDiagnostic
from mewarpx.diags_store.diag_base import WarpXDiagnostic
from mewarpx.mwxrun import mwxrun
from mewarpx.utils_store import util as mwxutil
from pywarpx import callbacks
logger = logging.getLogger(__name__)
class SimControl(WarpXDiagnostic):

    """
    The main simulation driving class. It evaluates whether to continue or
    terminate the simulation based on a set of user defined criteria.
    """

    def __init__(self, total_steps, diag_steps, criteria=None,
                 checkpoint=False, dump_period=None, **kwargs):
        """
        Generate and install functions to perform after a step #.

        Arguments:
            total_steps (int): Total number of steps to perform in the
                simulation. This value is installed as the simulation's
                ``max_steps``; :meth:`run` flags the simulation as done once
                the iteration count reaches it.
            diag_steps (int): Steps between diagnostic output
            criteria (list): list of user defined functions or list of user
                defined tuples(function, kwargs) that each return a True or
                False value. A False value requests termination.
            checkpoint (bool): Whether or not simulation checkpoints should be
                made
            dump_period (int): Checkpoints will be created every
                diag_steps*dump_period steps. If the dump_period is not given,
                checkpoints will only be created when the simulation ends (
                including interrupts due to a TERMINATE signal)
            kwargs: Forwarded unchanged to ``CheckPointDiagnostic`` when
                ``checkpoint`` is True; unused otherwise.

        Raises:
            AssertionError: If ``total_steps`` is smaller than 1.
        """
        if total_steps < 1:
            raise AssertionError("total_steps must be >= 1")

        # ensure that simulation runs through last step specified
        mwxrun.simulation.max_steps = int(total_steps)

        # Parallel lists: crit_list[i] is called with **crit_args_list[i].
        self.crit_list = []
        self.crit_args_list = []
        self._write_func = None
        self.diag_steps = diag_steps
        self.checkpoint = checkpoint
        self.initialize_criteria(criteria)
        self.sim_done = False

        super().__init__(diag_steps=diag_steps)

        logger.info(
            f"Diagnostic time set to {self.diag_steps*mwxrun.dt:.2e} "
            f"({self.diag_steps} steps)."
        )
        logger.info(
            f"Total simulation time set to {total_steps*mwxrun.dt:.2e} "
            f"({total_steps} steps)."
        )

        # register the TERM signal to be handled as a break signal
        mwxrun.simulation.break_signals = "SIGTERM"

        # Install checkpointing if required
        if self.checkpoint:
            if dump_period is None:
                # No period requested: a single interval spanning the run.
                dump_period = total_steps
            else:
                dump_period = self.diag_steps*dump_period

            self.checkpoint_diag = CheckPointDiagnostic(dump_period, **kwargs)

            # register the USR1 signal to be handled as a checkpoint signal
            mwxrun.simulation.checkpoint_signals = "SIGUSR1"
            # install a callback for a signalled checkpointing event
            callbacks.installoncheckpointsignal(self.trigger_checkpoint)

        # install a callback to check whether any termination criteria is met
        callbacks.installafterstep(self.check_criteria)

    def add_checker(self, criterion):
        """Install a single function to check.

        Arguments:
            criterion (func or tuple): Either a function or a tuple of (func,
                kwargs_dict) where the kwargs_dict will be passed to the func.
        """
        if callable(criterion):
            self.crit_list.append(criterion)
            self.crit_args_list.append({})
        else:
            # Anything that is not callable is treated as (func, kwargs_dict).
            self.crit_list.append(criterion[0])
            self.crit_args_list.append(criterion[1])

    def initialize_criteria(self, criteria_list):
        """Install the full initial list of criteria.

        Arguments:
            criteria_list (list or None): Criteria in any form accepted by
                :meth:`add_checker`. ``None`` or an empty list installs
                nothing.
        """
        if criteria_list:
            for crit in criteria_list:
                self.add_checker(crit)

    def check_criteria(self):
        """Sends an interrupt signal to the current process if a termination
        criteria is satisfied.

        Does nothing unless ``self.check_timestep()`` is truthy. Otherwise
        every installed criterion is evaluated (none is skipped, even after
        one has already failed), and a falsy result marks the simulation as
        done. If the simulation is done, the names of the criteria that
        failed during this call are logged and SIGTERM is sent to this
        process.
        """
        # NOTE(review): check_timestep is not defined in this class;
        # presumably inherited from WarpXDiagnostic - confirm.
        if not self.check_timestep():
            return

        terminate_statement = 'SimControl: Termination from criteria: '
        for criterion, crit_kwargs in zip(self.crit_list,
                                          self.crit_args_list):
            continue_flag = criterion(**crit_kwargs)
            if continue_flag:
                continue

            self.sim_done = True
            # functools.partial objects and callable instances have no
            # __name__; fall back to repr() so that logging the reason for
            # termination cannot itself raise.
            name = getattr(criterion, "__name__", None)
            if name is None:
                name = repr(criterion)
            terminate_statement += f"{name} "

        if self.sim_done:
            logger.info(terminate_statement)
            # SIGTERM was registered as a break signal in __init__.
            os.kill(os.getpid(), signal.SIGTERM)

    def trigger_checkpoint(self):
        """Helper function to trigger the checkpoint diagnostic to write a
        checkpoint. Does nothing when checkpointing was not requested."""
        if self.checkpoint:
            self.checkpoint_diag.checkpoint_manager(force_run=True)

    def write_results(self):
        """Create results.txt file, and write to it if write_func is set.
        The file signifies that the simulation ran to completion."""
        results_string = ""
        if callable(self._write_func):
            results_string = self._write_func()

        # Only the process with mwxrun.me == 0 touches the file system
        # (presumably the root MPI rank - confirm).
        if mwxrun.me == 0:
            mwxutil.mkdir_p(WarpXDiagnostic.DIAG_DIR)
            # Append mode: an existing results.txt is extended, not replaced.
            with open(
                os.path.join(WarpXDiagnostic.DIAG_DIR, "results.txt"), 'a'
            ) as results_file:
                results_file.write(results_string)

    def set_write_func(self, func):
        """Sets a function for writing to results.txt file.

        Arguments:
            func (function): Returns a string that will be written to the
                results.txt file

        Raises:
            ValueError: If ``func`` is not callable.
        """
        if not callable(func):
            raise ValueError("The write func is not callable")
        self._write_func = func

    def run(self):
        """Executes the WarpX loop.

        After the loop returns, a checkpoint is forced (if checkpointing is
        enabled) and, when the simulation is flagged as done, results.txt is
        written.
        """
        mwxrun.simulation.step()

        # create fluxdiag checkpoint file if checkpointing is installed
        self.trigger_checkpoint()

        # check if the simulation completed the total number of steps
        if mwxrun.get_it() >= mwxrun.simulation.max_steps:
            self.sim_done = True
            logger.info("SimControl: Total steps reached.")

        if self.sim_done:
            # create file to signal that simulation ran to completion
            self.write_results()