"""Neural Ordinary Differential Equations dynamics."""
import torch
from torch import nn
from torchdiffeq import odeint
from torchdiffeq import odeint_adjoint

from .._base_nn_model import BaseModel
from ..custom_blocks.wrappers import CheckpointWrapper
from ..utils import get_model_device, find_best_match_indices

# pylint: disable=not-callable

# pylint: disable=invalid-name
class checkpointed_cls:
    """Class decorator whose instances build the decorated class and wrap the result in `CheckpointWrapper`."""

    def __init__(self, cls):
        # Keep a reference to the decorated class; it is only instantiated on call.
        self.cls = cls

    def __call__(self, *init_args, **init_kwargs):
        """Instantiate the stored class with the given arguments and return it wrapped."""
        module = self.cls(*init_args, **init_kwargs)
        return CheckpointWrapper(module)

class NeuralDifferentialEquation(BaseModel):
    """Module for modelling dynamics in the form of Neural Ordinary Differential Equation."""
    _attrs = ('states', 'rock', 'control', ('masks', 'control_t'), ('masks', 'time'))

    def __init__(self, time_derivative_module, time_scale=30.5):
        """
        Parameters
        ----------
        time_derivative_module: nn.Module
            Module, which computes the derivative of states over time,
            given params and instantaneous control.
        time_scale: float, optional
            Scales time axis. Default: 30.5
        """
        super().__init__()
        # Both integrators are built around the same derivative module instance.
        self._integrator = _Integrator(time_derivative_module, time_scale)
        self._checkpointed_integrator = _CheckpointedIntegrator(time_derivative_module, time_scale)
        self.last_n_integration_iter = 0

    def forward(self, init_state, params, control, control_t, t, *der_args, max_integration_timestep=100,
                solver='rk4', rtol=1e-4, atol=1e-4, use_adjoint=False, max_seq_len=None,
                use_outer_checkpointing=False, use_inner_checkpointing=True, **der_kwargs):
        """
        Parameters
        ----------
        init_state: torch.Tensor
            Initial state.
            Shape: (B, C_s, *dims)
        params: torch.Tensor
            Params tensor (rock mostly).
            Shape: (B, C_p, *dims)
        control: torch.Tensor
            Spatial control tensor.
            Shape: (B, T, C_c, *dims)
        control_t: torch.Tensor
            Tensor holding a sequence of time points for which control is given.
            Each time must be strictly larger than the previous time.
            Must be identical for all batch items.
            Shape: (B, T_c)
        t: torch.Tensor
            Tensor holding a sequence of time points for which to solve for states.
            The initial time point (i.e. time point of the init_state) should be the first element
            of this sequence, and each time must be strictly larger than the previous time.
            Must be identical for all batch items.
            Shape: (B, T)
        der_args: tuple
            Any additional args passed to the dynamics module.
        max_integration_timestep: int
            Maximal timestep can be made internally by the solver.
        solver: str
            Name of the solver to use. From the list:
            [
                'explicit_adams', 'fixed_adams', 'adams', 'tsit5', 'dopri5',
                'euler', 'midpoint', 'rk4', 'adaptive_heun',
            ]
        rtol: float
            Relative tolerance.
        atol: float
            Absolute tolerance.
        use_adjoint: bool
            If True, use Adjoint method at the backward pass.
        max_seq_len: int
            Maximal length of the sequence, that can be passed into single forward pass.
            If required solution has a greater length, several forward passes will be done.
            Only used when `use_checkpointing=True` is given in `der_kwargs`.
        use_outer_checkpointing: bool
            Passed as `use_checkpointing` to the CheckpointWrapper around the whole
            checkpointed integrator. Only used when `use_checkpointing=True` is given
            in `der_kwargs`. Default: False.
        use_inner_checkpointing: bool
            When `use_checkpointing=True` is given in `der_kwargs`, this value replaces
            the `use_checkpointing` entry forwarded to the `time_derivative_module`.
            Default: True.
        der_kwargs: dict
            Any additional named arguments passed to the `time_derivative_module`.
            The optional `use_checkpointing` entry (default False) selects the
            checkpointed integrator.

        Returns
        -------
        solution: torch.Tensor
            Predicted states.
            Shape: (B, T, C_s, *dims)
        """
        kwargs = dict(solver=solver, rtol=rtol, atol=atol, max_integration_timestep=max_integration_timestep,
                      use_adjoint=use_adjoint, derivative_args=der_args, derivative_kwargs=der_kwargs)
        use_checkpointing = der_kwargs.get('use_checkpointing', False)
        if use_checkpointing:
            kwargs['use_inner_checkpointing'] = use_inner_checkpointing

        # TODO will work only if timesteps across batch are equal
        assert torch.all(t - t[:1] == 0), "Solution timesteps should be equal for batch items."
        t = t[0]
        assert torch.all(control_t - control_t[:1] == 0), "Control timesteps should be equal for batch items."
        control_t = control_t[0]

        if use_checkpointing:
            self._checkpointed_integrator.wrapped_module.set_kwargs(**kwargs, max_seq_len=max_seq_len)
            states = self._checkpointed_integrator(
                init_state, params, control, control_t, t, use_checkpointing=use_outer_checkpointing
            )
            self.last_n_integration_iter = self._checkpointed_integrator.wrapped_module.last_n_integration_iter
        else:
            self._integrator.set_kwargs(**kwargs)
            states = self._integrator(init_state, params, control, control_t, t)
            self.last_n_integration_iter = self._integrator.last_n_integration_iter
        return states

    def make_training_iter(self, sample, loss_pattern, loss_between_differences=False,
                           tbptt_step=None, tbptt_tail=None, tbptt_loss_between_steps=True, **kwargs):
        """Make one training iteration, either plain or with truncated backpropagation
        through time (TBPTT).

        Parameters
        ----------
        sample: dict
            Sample from a dataset.
        loss_pattern: tuple
            Loss pattern between ground truth and predicted states.
        loss_between_differences: bool
            If True, loss is computed between differences of consecutive states.
            Not implemented for TBPTT.
        tbptt_step: int, optional
            Number of forward steps between consecutive backward passes in TBPTT.
        tbptt_tail: int, optional
            Maximal number of most recent steps kept for backpropagation in TBPTT.
        tbptt_loss_between_steps: bool
            If True, at each backward pass the loss is computed for every one of the last
            `tbptt_step` predictions; otherwise only for the most recent prediction.
        kwargs: dict
            Any additional named arguments passed to the forward pass.

        Returns
        -------
        loss
            Detached loss.

        Raises
        ------
        ValueError
            If only one of `tbptt_step` and `tbptt_tail` is specified.
        """
        if (tbptt_step is None) != (tbptt_tail is None):
            raise ValueError(
                'If TBPTT is assumed to be used, both `tbptt_step` and `tbptt_tail` should be specified.'
            )
        if tbptt_step is None:
            return self._make_training_iter(sample, loss_pattern, loss_between_differences, **kwargs)
        return self._make_tbptt_training_iter(
            sample, loss_pattern, loss_between_differences,
            tbptt_step, tbptt_tail, tbptt_loss_between_steps, **kwargs
        )

    def _make_training_iter(self, sample, loss_pattern, loss_between_differences, **kwargs):
        """Make one training iteration. Runs backward pass. Return loss of the prediction.

        Parameters
        ----------
        sample: dict
            Sample from a dataset. Should contain 'states', 'control', 'rock', 'time'
            and all the required masks.
        loss_pattern: tuple
            Loss pattern between ground truth and predicted states.
        loss_between_differences: bool
            If True, loss is computed not between states, but rather
            between "gradients" (differences) of states.
        kwargs: dict
            Any additional named arguments passed to the `time_derivative_module`.

        Returns
        -------
        loss
            Detached loss.
        """
        inp = self._get_attrs_from_sample(sample, *self._attrs, leave_attrs_device=('control', ))
        masks = self._get_masks_from_sample(sample, *[pattern['mask'] for pattern in loss_pattern])
        # The first time slice of the reference states is the initial condition.
        out = self(inp[0][:, 0], *inp[1:], **kwargs)
        ref, pred = self._get_reference_and_prediction(inp[0], out, loss_between_differences)
        loss = self._compute_loss(ref=ref, pred=pred, loss_pattern=loss_pattern, masks=masks)
        loss.backward()
        return loss.detach()

    def _make_tbptt_training_iter(self, sample, loss_pattern, loss_between_differences,
                                  tbptt_step, tbptt_tail, tbptt_loss_between_steps, **kwargs):
        """Make one training iteration with truncated backpropagation through time.

        The sequence is processed one control step at a time; every `tbptt_step` steps
        the loss is backpropagated through at most `tbptt_tail` most recent steps.

        Parameters
        ----------
        sample: dict
            Sample from a dataset.
        loss_pattern: tuple
            Loss pattern between ground truth and predicted states.
        loss_between_differences: bool
            Must be False; loss between differences is not implemented here.
        tbptt_step: int
            Number of forward steps between consecutive backward passes.
        tbptt_tail: int
            Maximal number of (input state, output state) pairs kept for backpropagation.
        tbptt_loss_between_steps: bool
            If True, the loss is computed for each of the last `tbptt_step` predictions
            at every backward pass; otherwise only for the most recent one.
        kwargs: dict
            Any additional named arguments passed to the forward pass.

        Returns
        -------
        loss
            Accumulated detached loss (the float 0. if no backward pass was made).

        Raises
        ------
        NotImplementedError
            If `loss_between_differences` is True.
        """
        if loss_between_differences:
            # TODO
            raise NotImplementedError('Loss between differences is not implemented for TBPTT.')
        args = self._get_attrs_from_sample(sample, *self._attrs, leave_attrs_device=self._attrs)
        # NOTE(review): `_attrs` lists five entries. If `_get_attrs_from_sample` returns one
        # value per entry, the fourth value bound to `time` below would be the 'control_t'
        # mask and the actual 'time' would end up in the remaining `args` - confirm against
        # the BaseModel implementation.
        states, rock, control, time = args[:4]  # pylint: disable=unbalanced-tuple-unpacking
        args = args[4:]
        masks = self._get_masks_from_sample(sample, *[pattern['mask'] for pattern in loss_pattern])
        device = get_model_device(self)

        s_i = states[:, 0].to(device)
        # Each buffer entry is (input state of a step, output state of the step);
        # `None` marks the initial state, which has no producing step.
        buffer = [(None, s_i)]
        outputs = []
        targets = []
        loss = 0.

        n_steps_total = control.shape[1]
        if not tbptt_loss_between_steps:
            n_steps_backward = n_steps_total // tbptt_step
        else:
            n_steps_backward = n_steps_total

        for i in range(n_steps_total):
            # Cut the graph at the step boundary; gradients are passed across it manually below.
            s_i = buffer[-1][1].detach()
            s_i.requires_grad = True
            c_i = control[:, i].unsqueeze(1).to(device)
            t_i = time[:, i:i+2].to(device)
            next_s = self(s_i, rock, c_i, t_i, *args, **kwargs)[:, 1]

            outputs.append(next_s.unsqueeze(1))
            targets.append(states[:, i+1].unsqueeze(1).to(device))
            while len(outputs) > tbptt_step:
                # Delete stuff that is too old
                del outputs[0]
                del targets[0]

            buffer.append((s_i, next_s))
            while len(buffer) > tbptt_tail:
                # Delete stuff that is too old
                del buffer[0]

            if (i + 1) % tbptt_step == 0:
                for j in range(tbptt_step):
                    ref, pred = self._get_reference_and_prediction(
                        targets[-j - 1], outputs[-j - 1], loss_between_differences
                    )
                    loss_i = self._compute_loss(ref=ref, pred=pred, loss_pattern=loss_pattern, masks=masks)
                    loss_i /= n_steps_backward
                    loss += loss_i.detach()
                    retain_graph = (j != 0) or (tbptt_tail > tbptt_step)
                    loss_i.backward(retain_graph=retain_graph)
                    if not tbptt_loss_between_steps:
                        break
                # Propagate the gradient backwards across the detached step boundaries.
                for j in range(tbptt_tail - 1):
                    # if we get all the way back to the initial state, stop
                    if buffer[-j - 2][0] is None:
                        break
                    curr_grad = buffer[-j - 1][0].grad
                    retain_graph = j < tbptt_tail - tbptt_step - 1
                    buffer[-j - 2][1].backward(curr_grad, retain_graph=retain_graph)
        return loss

    @staticmethod
    def _get_reference_and_prediction(inp, out, loss_between_differences):
        """Return the (reference, prediction) pair used for the loss.

        If `loss_between_differences` is True, both tensors are replaced by the
        differences between consecutive slices along dimension 1.
        """
        if loss_between_differences:
            return inp[:, 1:] - inp[:, :-1], out[:, 1:] - out[:, :-1]
        return inp, out
@checkpointed_cls
class _CheckpointedIntegrator(nn.Module):
    """Class which wraps an `_Integrator` forward passes with CheckpointWrapper.
    Also, it provides capability to split long sequences into separate smaller forward passes."""
    def __init__(self, derivative_module, time_scale):
        """
        Parameters
        ----------
        derivative_module: nn.Module
            Module, which computes the derivative of states over time,
            given params and instantaneous control.
        time_scale: float
            Scales time axis.
        """
        super().__init__()
        self._integrator = CheckpointWrapper(
            _Integrator(derivative_module, time_scale)
        )
        self.max_seq_len = None
        self.last_n_integration_iter = 0

    def forward(self, init_state, params, control, control_t, t):
        """
        Parameters
        ----------
        init_state: torch.Tensor
            Initial state.
            Shape: (B, C_s, *dims).
        params: torch.Tensor
            Params tensor (rock mostly).
            Shape: (B, C_p, *dims).
        control: torch.Tensor
            Spatial control tensor.
            Shape: (B, T_c, C_c, *dims).
        control_t: torch.Tensor
            Tensor holding a sequence of time points for which control is given.
            Each time must be strictly larger than the previous time.
            Shape: (T_c, )
        t: torch.Tensor
            Tensor holding a sequence of time points for which to solve for states.
            The initial time point (i.e. time point of the init_state) should be the first element
            of this sequence, and each time must be strictly larger than the previous time.
            Shape: (T, )

        Returns
        -------
        solution: torch.Tensor
            Predicted states.
            Shape: (B, T, C_s, *dims).
        """
        states = [init_state.unsqueeze(1)]
        s_prev = init_state
        t_pieces = self._get_t_pieces(t)
        for t_piece in t_pieces:
            s_piece = self._integrator(s_prev, params, control, control_t, t_piece, use_checkpointing=True)
            # The first state of a piece duplicates the previous piece's last state.
            # The last one is kept only if its time point is requested in `t`.
            if t_piece[-1] in t:
                states.append(s_piece[:, 1:])
            else:
                states.append(s_piece[:, 1:-1])
            s_prev = s_piece[:, -1]
        states = torch.cat(states, dim=1)
        # Only the counter of the last integrated piece is kept here.
        self.last_n_integration_iter = self._integrator.wrapped_module.last_n_integration_iter
        return states

    def _get_t_pieces(self, t):
        """Split time sequence into pieces."""
        if self.max_seq_len:
            return self._get_t_pieces_from_max_seq_len(t)
        return self._get_t_pieces_from_solution_t(t)

    def _get_t_pieces_from_max_seq_len(self, t):
        """Split time sequence into pieces by the given maximal length."""
        pieces = []
        integration_t = self._integrator.wrapped_module.get_integration_t(t)
        edge_ind = torch.arange(
            start=0, end=integration_t.shape[0] + self.max_seq_len, step=self.max_seq_len
        )
        for left, right in zip(edge_ind[:-1], edge_ind[1:]):
            # Neighbouring pieces share their boundary time point.
            piece = integration_t[left: right + 1]
            if piece.shape[0] == 1:
                continue
            inner_piece = piece[1:-1]
            if inner_piece.shape[0]:
                # Keep only those inner time points which are present in the solution `t`.
                in_t = torch.any(t.view(-1, 1) == inner_piece.view(1, -1), dim=0)
                inner_piece = inner_piece[in_t]
            if inner_piece.shape[0]:
                piece = torch.cat([
                    piece[0].unsqueeze(0),
                    inner_piece,
                    piece[-1].unsqueeze(0)
                ])
            else:
                piece = torch.stack([piece[0], piece[-1]])
            pieces.append(piece)
        return pieces

    @staticmethod
    def _get_t_pieces_from_solution_t(t):
        """Split time sequence into pieces by the points where solution is mandatory."""
        return [torch.stack([left, right]) for left, right in zip(t[:-1], t[1:])]

    def set_kwargs(self, **kwargs):
        """Set kwargs for the wrapped `_Integrator` instance.

        `max_seq_len` and `use_inner_checkpointing` are consumed here; the rest
        is passed to `_Integrator.set_kwargs`.
        """
        self.max_seq_len = kwargs.pop('max_seq_len', None)
        use_inner_checkpointing = kwargs.pop('use_inner_checkpointing', True)
        # Override the flag forwarded to the derivative module, if it was given at all.
        if 'use_checkpointing' in kwargs['derivative_kwargs']:
            kwargs['derivative_kwargs']['use_checkpointing'] = use_inner_checkpointing
        self._integrator.wrapped_module.set_kwargs(**kwargs)


class _Integrator(nn.Module):
    """Integrator for a Neural Ordinary Differential Equation."""
    kwarg_keys = (
        'max_integration_timestep', 'solver', 'rtol', 'atol',
        'use_adjoint', 'derivative_args', 'derivative_kwargs'
    )

    def __init__(self, derivative_module, time_scale):
        """
        Parameters
        ----------
        derivative_module: nn.Module
            Module, which computes the derivative of states over time,
            given params and instantaneous control.
        time_scale: float
            Scales time axis.
        """
        super().__init__()
        self.derivative_module = _TimeDerivativeModule(
            derivative_module
        )
        self.time_scale = time_scale
        self.last_n_integration_iter = 0
        # Integration settings stay unset until `set_kwargs` is called.
        for key in self.kwarg_keys:
            setattr(self, key, None)

    def set_kwargs(self, **kwargs):
        """Assigns kwargs to self for following implicit use.

        Raises
        ------
        ValueError
            If a key is not listed in `kwarg_keys`.
        """
        for key, value in kwargs.items():
            if key not in self.kwarg_keys:
                raise ValueError(f'Unexpected key found: {key}')
            setattr(self, key, value)

    def forward(self, init_state, params, control, control_t, t):
        """
        Parameters
        ----------
        init_state: torch.Tensor
            Initial state.
            Shape: (B, C_s, *dims)
        params: torch.Tensor
            Params tensor (rock mostly).
            Shape: (B, C_p, *dims)
        control: torch.Tensor
            Spatial control tensor.
            Shape: (B, T_c, C_c, *dims)
        control_t: torch.Tensor
            Tensor holding a sequence of time points for which control is given.
            Each time must be strictly larger than the previous time.
            Shape: (T_c, )
        t: torch.Tensor
            Tensor holding a sequence of time points for which to solve for states.
            The initial time point (i.e. time point of the init_state) should be the first element
            of this sequence, and each time must be strictly larger than the previous time.
            Shape: (T, )

        Returns
        -------
        solution: torch.Tensor
            Predicted states.
            Shape: (B, T, C_s, *dims)
        """
        # TODO time for solution and time for control should be different
        self.derivative_module.set_params(params)
        self.derivative_module.set_control(control, control_t / self.time_scale)
        self.derivative_module.set_additional_forward_args(self.derivative_args, self.derivative_kwargs)
        integration_t = self.get_integration_t(t)
        integrator = odeint if not self.use_adjoint else odeint_adjoint
        # The solver output has time as the first dimension; swap it with the batch dimension.
        states = integrator(
            self.derivative_module, init_state, integration_t / self.time_scale,
            method=self.solver, rtol=self.rtol, atol=self.atol
        ).transpose(1, 0)
        self.last_n_integration_iter = self.derivative_module.finish_integration()
        # Drop the intermediate time points, which were added only to bound the solver's step.
        solution_indices = find_best_match_indices(search_for=t, search_in=integration_t)
        states = states[:, solution_indices]
        return states

    def get_integration_t(self, solution_t):
        """Get the set of timesteps, passed to the integrator.
        It can be longer than the `solution_t` due to the bounded length of solver's step."""
        t = [solution_t[0].unsqueeze(0)]
        for prev, cur in zip(solution_t[:-1], solution_t[1:]):
            prev, cur = float(prev), float(cur)
            piece = torch.arange(prev, cur, step=self.max_integration_timestep)
            # Make sure the right end of the interval is present in the piece.
            if prev == cur or piece[-1] != cur:
                piece = torch.cat([piece, torch.tensor([cur])])
            # The left end is already stored as the last point of the previous piece.
            t.append(piece[1:].to(solution_t.device))
        return torch.cat(t)


class _TimeDerivativeModule(nn.Module):
    """Module which computes derivative of states over time.
    Serves as a transition to `odeint` function from `torchdiffeq`.
    Adds a control and parameters to the forward pass implicitly
    (since `odeint` assume no control or additional parameters used)."""
    def __init__(self, layers):
        """
        Parameters
        ----------
        layers: nn.Module
            Module, which computes the derivative of states over time,
            given params and instantaneous control.
        """
        super().__init__()
        self.layers = layers
        self._n_integration_iter = 0
        self._params = None
        self._control = None
        self._additional_forward_args = None

    def set_params(self, params):
        """Set params for following forward passes."""
        self._params = params

    def set_control(self, control, t):
        """Set control for following forward passes."""
        self._control = {'t': t, 'values': control}

    def set_additional_forward_args(self, args, kwargs):
        """Set additional args and kwargs for following forward passes."""
        self._additional_forward_args = {'args': args, 'kwargs': kwargs}

    def _get_control_at_t(self, t):
        """Get control corresponding to the given timestep from the sequence of controls."""
        i = find_best_match_indices(search_for=t, search_in=self._control['t'], less_or_equal=True).squeeze()
        # Clamp to the last available control.
        if i >= self._control['values'].shape[1]:
            i = self._control['values'].shape[1] - 1
        return self._control['values'][:, i].to(self._control['t'].device)

    def finish_integration(self):
        """Clear the cache after the integration is finished.
        Return number of forward passes made during the integration."""
        n_integration_iter = self._n_integration_iter
        self._n_integration_iter = 0
        return n_integration_iter

    def forward(self, t, state):
        """Compute time-derivative of states at given state and timestep,
        and implicitly set control and parameters."""
        self._n_integration_iter += 1
        return self._inner_loop_forward(
            state, self._params, self._get_control_at_t(t),
            *self._additional_forward_args['args'],
            **self._additional_forward_args['kwargs']
        )

    def _inner_loop_forward(self, state, params, control, *args, **kwargs):
        """Compute time-derivative ds/dt."""
        # State, params and control are concatenated along dimension 1 before the layers.
        out = torch.cat([state, params, control], dim=1)
        out = self.layers(out, *args, **kwargs)
        return out