Commit 1aa87d03 authored by Thomas Baumann's avatar Thomas Baumann
Browse files

Cosmetic changes before restructuring the controller to allow for problem-dependent adaptivity

parent de5b625b
......@@ -12,6 +12,8 @@ def filter_stats(stats, process=None, time=None, level=None, iter=None, type=Non
level (int): the requested level index
iter (int): the requested iteration count
type (str): string to describe the requested type of value
recomputed (bool): if True, filter out intermediate values that have no impact on the solution because the
associated step was restarted; if False, filter out the restarted values instead; use None to keep both
Returns:
dict: dictionary containing only the entries corresponding to the filter
"""
......@@ -19,9 +21,11 @@ def filter_stats(stats, process=None, time=None, level=None, iter=None, type=Non
# check which steps have been recomputed
if recomputed is not None:
restarts = np.array(sort_stats(filter_stats(stats, process=None, time=None, iter=None, type='recomputed',
recomputed=None), sortby='time'))
# this will contain a 2d array with all times and whether they have been recomputed
restarts = np.array(get_sorted(stats, process=None, time=None, iter=None, type='recomputed',
recomputed=None, sortby='time'))
else:
# dummy values for when no filtering of restarts is desired
restarts = np.array([[None, None]])
for k, v in stats.items():
......@@ -33,7 +37,10 @@ def filter_stats(stats, process=None, time=None, level=None, iter=None, type=Non
(k.type == type or type is None):
if k.time in restarts[:, 0]:
if restarts[restarts[:, 0] == k.time][0, 1] == float(recomputed):
# Each time appears at most once in `restarts`, so mask by time and take the single matching entry.
# Its second element is a float flag recording whether a restart was performed at this time; compare
# it to the value requested via `recomputed`.
if restarts[restarts[:, 0] == k.time][0][1] == float(recomputed):
result[k] = v
else:
result[k] = v
......@@ -82,5 +89,5 @@ def get_list_of_types(stats):
return type_list
def get_sorted(stats, process=None, time=None, level=None, iter=None, type=None, recomputed=None, sortby='time'):
    """
    Utility for filtering the stats and then sorting them in a single call.

    Args:
        stats (dict): raw statistics from a controller run
        process (int): the requested process rank
        time (float): the requested simulation time
        level (int): the requested level index
        iter (int): the requested iteration count
        type (str): string to describe the requested type of value
        recomputed (bool): if not None, keep only entries whose 'recomputed' flag matches this value
            (see `filter_stats`); None disables filtering by restarts
        sortby (str): attribute to sort the entries by

    Returns:
        list: sorted list of (sortby, value) tuples from the filtered stats
    """
    # Forward `recomputed` to filter_stats — without this the parameter is dead and
    # restart filtering silently never happens.
    return sort_stats(filter_stats(stats, process=process, time=time, level=level, iter=iter, type=type,
                                   recomputed=recomputed), sortby=sortby)
......@@ -6,7 +6,7 @@ import dill
from pySDC.core.Controller import controller
from pySDC.core import Step as stepclass
from pySDC.core.Errors import ControllerError, CommunicationError, ParameterError
from pySDC.implementations.controller_classes.error_estimator import ErrorEstimator_nonMPI
from pySDC.implementations.controller_classes.error_estimator import get_ErrorEstimator_nonMPI
class controller_nonMPI(controller):
......@@ -79,7 +79,6 @@ class controller_nonMPI(controller):
self.params.use_HotRod
self.params.use_extrapolation_estimate = self.params.use_extrapolation_estimate or self.params.use_HotRod
self.store_uold = self.params.use_iteration_estimator or self.params.use_embedded_estimate
self.restart = [False] * num_procs
if self.params.use_adaptivity:
if 'e_tol' not in description['level_params'].keys():
raise ParameterError('Please supply "e_tol" in the level parameters')
......@@ -91,7 +90,7 @@ s to have a constant order in time for adaptivity. Setting restol=0')
if self.params.use_HotRod and self.params.HotRod_tol == np.inf:
self.logger.warning('Hot Rod needs a detection threshold, which is now set to infinity, such that a restart\
is never triggered!')
self.error_estimator = ErrorEstimator_nonMPI().get_estimator(self)
self.error_estimator = get_ErrorEstimator_nonMPI(self)
def check_iteration_estimator(self, MS):
"""
......@@ -183,9 +182,10 @@ s to have a constant order in time for adaptivity. Setting restol=0')
while not done:
done = self.pfasst(MS_active)
if True in self.restart: # restart part of the block
restarts = [S.status.restart for S in MS_active]
if True in restarts: # restart part of the block
# find the place after which we need to restart
restart_at = np.where(self.restart)[0][0]
restart_at = np.where(restarts)[0][0]
# store values in the current block that don't need restarting
if restart_at > 0:
......@@ -266,8 +266,6 @@ s to have a constant order in time for adaptivity. Setting restol=0')
for lvl in self.MS[p].levels:
lvl.status.time = time[p]
self.restart = [False] * len(self.MS)
@staticmethod
def recv(target, source, tag=None):
"""
......@@ -537,7 +535,7 @@ s to have a constant order in time for adaptivity. Setting restol=0')
if self.params.use_adaptivity:
self.adaptivity(local_MS_running)
self.resilence(local_MS_running)
self.resilience(local_MS_running)
for S in local_MS_running:
......@@ -762,7 +760,7 @@ s to have a constant order in time for adaptivity. Setting restol=0')
"""
raise ControllerError('Unknown stage, got %s' % local_MS_running[0].status.stage) # TODO
def resilence(self, local_MS_running):
def resilience(self, local_MS_running):
"""
Call various functions that are supposed to provide some sort of resilience from here
"""
......@@ -770,10 +768,11 @@ s to have a constant order in time for adaptivity. Setting restol=0')
if self.params.use_HotRod:
self.hotrod(local_MS_running)
# make sure controller and steps are on the same page about restarting
# a step gets restarted because it wants to or because any earlier step wants to
restart = False
for p in range(len(local_MS_running)):
local_MS_running[p].status.restart = local_MS_running[p].status.restart or any(self.restart[:p + 1])
self.restart[p] = local_MS_running[p].status.restart
restart = restart or local_MS_running[p].status.restart
local_MS_running[p].status.restart = restart
def hotrod(self, local_MS_running):
"""
......@@ -792,7 +791,7 @@ s to have a constant order in time for adaptivity. Setting restol=0')
if None not in [l.status.error_extrapolation_estimate, l.status.error_embedded_estimate]:
diff = l.status.error_extrapolation_estimate - l.status.error_embedded_estimate
if diff > self.params.HotRod_tol:
self.restart[i] = True
S.status.restart = True
def adaptivity(self, MS):
"""
......@@ -817,14 +816,11 @@ s to have a constant order in time for adaptivity. Setting restol=0')
order = S.status.iter # embedded error estimate is same order as time marching
assert L.status.error_embedded_estimate is not None, 'Make sure to estimate the embedded error before call\
ing adaptivity!'
h_opt = L.params.dt * 0.9 * (L.params.e_tol / L.status.error_embedded_estimate)**(1. / order)
# distribute step sizes
L.status.dt_new = h_opt
L.status.dt_new = L.params.dt * 0.9 * (L.params.e_tol / L.status.error_embedded_estimate)**(1. / order)
# check whether to move on or restart
if L.status.error_embedded_estimate >= L.params.e_tol:
self.restart[i] = True
S.status.restart = True
def adaptivity_update_step_sizes(self, active_slots):
......@@ -832,10 +828,11 @@ ing adaptivity!'
Update the step sizes computed in adaptivity here, since this can get arbitrarily elaborate
"""
# figure out where the block is restarted
if True in self.restart:
restart_at = np.where(self.restart)[0][0]
restarts = [self.MS[p].status.restart for p in active_slots]
if True in restarts:
restart_at = np.where(restarts)[0][0]
else:
restart_at = len(self.restart) - 1
restart_at = len(restarts) - 1
# record the step sizes to restart with
new_steps = [None] * len(self.MS[restart_at].levels)
......
......@@ -263,8 +263,9 @@ class _ErrorEstimator_nonMPI_BlockGS(_ErrorEstimatorBase):
S = MS[i]
for j in range(len(S.levels)):
L = S.levels[j]
semi_global_errors[i][j] = max([abs(L.uold[-1] - L.u[-1]), np.finfo(float).eps])
L.status.error_embedded_estimate = abs(semi_global_errors[i][j] - semi_global_errors[i - 1][j])
semi_global_errors[i][j] = abs(L.uold[-1] - L.u[-1])
L.status.error_embedded_estimate = max([abs(semi_global_errors[i][j] - semi_global_errors[i - 1][j]),
np.finfo(float).eps])
class _ErrorEstimator_nonMPI_no_memory_overhead_BlockGS(_ErrorEstimator_nonMPI_BlockGS):
......@@ -340,19 +341,15 @@ ate!')
self.embedded_estimate_local_error(MS[:i + 1])
def get_ErrorEstimator_nonMPI(controller):
    """
    This function should be called from the controller and return the correct version of the error estimator based on
    the chosen parameters.

    Args:
        controller (controller_nonMPI): the controller requesting an error estimator

    Returns:
        _ErrorEstimator_nonMPI_BlockGS: an error estimator instance matching the controller configuration
    """
    # If the block holds enough steps, the variant without extra memory overhead can gather the history
    # it needs from the block itself; otherwise fall back to the variant that stores its own copies.
    # NOTE(review): the threshold (maxiter + 4) // 2 is taken as given here — confirm against the estimator classes.
    if len(controller.MS) >= (controller.MS[0].params.maxiter + 4) // 2:
        return _ErrorEstimator_nonMPI_no_memory_overhead_BlockGS(controller)
    else:
        return _ErrorEstimator_nonMPI_BlockGS(controller)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment