Source code for mewarpx.diags_store.checkpoint_diagnostic

"""Class for installing a checkpoint diagnostic"""
import logging
import os

from mewarpx.diags_store.diag_base import WarpXDiagnostic
from mewarpx.mwxrun import mwxrun
from mewarpx.utils_store import init_restart_util
from pywarpx import callbacks, picmi

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class CheckPointDiagnostic(WarpXDiagnostic):

    def __init__(self, diag_steps,
                 name=init_restart_util.DEFAULT_CHECKPOINT_NAME,
                 clear_old_checkpoints=True, num_to_keep=2, **kwargs):
        """
        This class is a wrapper for creating checkpoints from which a
        simulation can be restarted. Adding flux diagnostic data to
        checkpoints is supported, but since this class has to be initialized
        before the simulation and flux diagnostics are initialized after the
        simulation, the user is responsible for adding the ``flux_diag``
        attribute to this object in a simulation input file.

        On construction the picmi checkpoint diagnostic is registered with
        the simulation (see :meth:`add_checkpoint`).
        :meth:`checkpoint_manager` is installed as an "after diagnostics"
        callback unless ``diag_steps`` equals the simulation's ``max_steps``.

        Arguments:
            diag_steps (int): Run the diagnostic with this period. Also plot
                on this period if enabled.
            name (str): The name of the diagnostic to be passed into the
                picmi checkpoint diagnostic. It is also used as the
                checkpoint prefix when old checkpoints are cleaned up.
            clear_old_checkpoints (bool): If True old checkpoints will be
                deleted after new ones are created.
            num_to_keep (int): Number of checkpoints to keep. Default 2.
            kwargs: For a list of valid keyword arguments see
                :class:`mewarpx.diags_store.diag_base.WarpXDiagnostic`
        """
        self.checkpoint_steps = diag_steps
        self.name = name
        self.clear_old_checkpoints = clear_old_checkpoints
        self.num_to_keep = num_to_keep
        # Flux diagnostic whose data should be stored alongside each
        # checkpoint; stays None unless the user assigns it afterwards.
        self.flux_diag = None

        super().__init__(diag_steps=diag_steps, **kwargs)

        # DIAG_DIR is not defined in this class; presumably it is provided
        # by WarpXDiagnostic - TODO confirm.
        self.write_dir = self.DIAG_DIR

        self.add_checkpoint()

        # if checkpoints will only be created with an interrupt signal or
        # the end of the simulation, we don't need to install the callback
        if self.checkpoint_steps != mwxrun.simulation.max_steps:
            callbacks.installafterdiagnostics(self.checkpoint_manager)

    def add_checkpoint(self):
        """Create the picmi checkpoint diagnostic and add it to the
        simulation.

        The diagnostic is built from this object's ``checkpoint_steps``
        (used as the period), ``name`` and ``write_dir``.
        """
        diagnostic = picmi.Checkpoint(
            period=self.checkpoint_steps,
            name=self.name,
            write_dir=self.write_dir
        )

        mwxrun.simulation.add_diagnostic(diagnostic)

    def checkpoint_manager(self, force_run=False):
        """Function executed on checkpoint steps to perform various tasks
        related to checkpoint management. These include copying the flux
        diagnostic data needed for a restart as well as deleting old
        checkpoints.

        Arguments:
            force_run (bool): If True, perform the checkpoint management
                tasks even if ``check_timestep()`` reports that this is not
                a diagnostic step.
        """
        if not force_run and not self.check_timestep():
            return

        # Save a copy of flux diagnostics, if present, to load when restarting.
        if self.flux_diag is not None:
            # If the timeseries were not updated on timestep, do so now.
            if self.flux_diag.last_run_step != mwxrun.get_it():
                self.flux_diag.update_ts_dict()
                self.flux_diag.last_run_step = mwxrun.get_it()
                # Only done where ``mwxrun.me == 0`` (presumably the root
                # MPI rank - verify against mwxrun).
                if mwxrun.me == 0:
                    self.flux_diag.update_fullhist_dict()

            if mwxrun.me == 0:
                # We use the unorthodox file extension .ckpt (for checkpoint)
                # so that we can continue to blindly move all .dpkl files from
                # EFS to S3 when running on AWS
                dst = os.path.join(
                    self.write_dir,
                    f"{self.name}{self.flux_diag.last_run_step:06d}",
                    "fluxdata.ckpt"
                )
                self.flux_diag.save(filepath=dst)

        # Cleanup does not depend on a flux diagnostic being attached.
        if self.clear_old_checkpoints and mwxrun.me == 0:
            init_restart_util.clean_old_checkpoints(
                checkpoint_prefix=self.name, num_to_keep=self.num_to_keep
            )