Training
NNTrainer
-
class NNTrainer
Class for neural-network training.
-
train_model(model, dataset, n_epoch=1, lr=0.001, parameters_to_optimize=None, loss_pattern=None, log_every=None, evaluate_every=None, test_frac=0, dump_best_parameters=None, optimizer=<class 'torch.optim.adam.Adam'>, scheduler=functools.partial(<class 'torch.optim.lr_scheduler.LambdaLR'>, lr_lambda=<function lr_lambda>), **kwargs)
Train the model on the given dataset.
- Parameters
model (BaseModel) – Model to train.
dataset (FieldDataset) – Dataset to use.
n_epoch (int) – Number of passes (epochs) through the dataset.
lr (float) – Learning rate.
parameters_to_optimize (None or tuple of model's parameters) – The loss is optimized over this set of parameters. If None, all model parameters are optimized.
loss_pattern (tuple or Any) – For more information, see losses.standardize_loss_pattern.
log_every (int or None) – Print logs every log_every iterations. If None, do not print logs.
evaluate_every (int or None) – Evaluate the model on the test dataset every evaluate_every iterations. If None, no evaluation is performed.
test_frac (float) – Fraction of scenarios put into the test dataset.
dump_best_parameters (str or None) – Path to which the best model (as measured on the test set) is dumped. A dump can be made at each evaluation iteration. If None, the model is not dumped automatically.
optimizer (torch.optim.Optimizer) – Optimizer to use. Default: torch.optim.Adam
scheduler (torch.optim.lr_scheduler._LRScheduler) – Scheduler to use. Default: torch.optim.lr_scheduler.LambdaLR
kwargs (dict) – Any additional named arguments for the model’s forward pass.
- Returns
model – Trained model.
train_loss_legend (list) – Training loss legend.
test_loss_legend (list) – Evaluation loss legend.
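The loop below is a minimal sketch of the kind of procedure train_model describes, written in plain PyTorch rather than with NNTrainer itself. The parameters n_epoch, lr, test_frac, evaluate_every and log_every mirror the documented ones; everything else is an assumption: the function name `sketch_train` is hypothetical, scenarios are a list of (x, y) tensor pairs standing in for a FieldDataset, the loss is plain MSE instead of a loss_pattern, and the LambdaLR multiplier is constant because the library's default lr_lambda is not documented here.

```python
import copy
import random

import torch
from torch import nn


def sketch_train(model, scenarios, n_epoch=1, lr=1e-3, test_frac=0.0,
                 evaluate_every=None, log_every=None):
    """Hypothetical train/evaluate loop; not the DeepField implementation."""
    # Split the scenarios into train and test parts according to test_frac.
    scenarios = list(scenarios)
    random.shuffle(scenarios)
    n_test = int(len(scenarios) * test_frac)
    test, train = scenarios[:n_test], scenarios[n_test:]

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Assumption: constant multiplier; the library's default lr_lambda is unknown.
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda _: 1.0)
    loss_fn = nn.MSELoss()  # Assumption: plain MSE instead of loss_pattern.

    train_losses, test_losses = [], []
    best_loss, best_state = float("inf"), None
    step = 0
    for _ in range(n_epoch):
        for x, y in train:
            model.train()
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()
            scheduler.step()
            train_losses.append(loss.item())
            step += 1
            if log_every and step % log_every == 0:
                print(f"step {step}: train loss {loss.item():.4f}")
            if evaluate_every and test and step % evaluate_every == 0:
                model.eval()
                with torch.no_grad():
                    test_loss = sum(loss_fn(model(xt), yt).item() for xt, yt in test) / len(test)
                test_losses.append(test_loss)
                # Keep the best parameters in memory; train_model can instead
                # dump them to the dump_best_parameters path.
                if test_loss < best_loss:
                    best_loss, best_state = test_loss, copy.deepcopy(model.state_dict())
    return model, train_losses, test_losses, best_state


# Usage on synthetic data: 20 scenarios of a noiseless linear relation.
torch.manual_seed(0)
random.seed(0)
w_true = torch.tensor([[2.0], [-1.0]])
scenarios = []
for _ in range(20):
    x = torch.randn(8, 2)
    scenarios.append((x, x @ w_true))

model, train_losses, test_losses, best_state = sketch_train(
    nn.Linear(2, 1), scenarios, n_epoch=50, lr=0.05, test_frac=0.2, evaluate_every=16)
```

With test_frac=0.2, four of the twenty scenarios are held out, so each epoch has 16 training iterations and evaluate_every=16 triggers one evaluation per epoch.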
-
Factories
Tools for fast custom module creation.
-
sequential_factory(n, conv_module, in_ch, ch, *args, use_norm=True, use_nonlin=True, residual_connections=(), wrappers=(<class 'deepfield.metamodelling.custom_blocks.wrappers.MultiInputSequential'>, <class 'deepfield.metamodelling.custom_blocks.wrappers.TimeInvariantWrapper'>), **kwargs)
Helper for creating sequential modules. The module is built by stacking instances of the chosen conv_module.
- Parameters
n (int) – Number of conv_modules to stack.
conv_module (nn.Module) – Base module from which sequential is constructed.
in_ch (int) – Number of input channels in the resulting sequential.
ch (int, list) – Number of output channels of each conv_module. If an int is passed, the same number of output channels is used for all modules. If a list is passed, its entries are used; the list should have length n.
args (tuple) – Any additional positional args passed to the constructor of conv_module. Each arg is either a list or any other value. If an arg is not a list, the same value is used for all modules.
use_norm (bool, nn.Module) – If bool, whether to apply normalization between modules (before the nonlinearity). If nn.Module, this module is used as the normalization. Default: True
use_nonlin (bool, nn.Module) – If bool, whether to apply a nonlinearity between modules. If nn.Module, this module is used as the nonlinearity. Default: True
residual_connections (tuple) – Tuple of pairs, each marking two layers to connect. A pair (i, j) connects the input of the i-th layer to the output of the j-th layer. Pairs should not cross each other. Default: ()
wrappers (tuple, list) – Wrappers applied after the component modules are constructed. Order matters. Default: (MultiInputSequential, TimeInvariantWrapper)
kwargs (dict) – Any additional named args passed to the constructor of conv_module. Each value is either a list or any other value. If a value is not a list, the same value is used for all modules.
- Returns
out – Composed modules, with normalization, nonlinearities and wrappers applied as configured.
- Return type
list, nn.Module
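The per-module argument rule for ch, args and kwargs (a list supplies one value per module, anything else is repeated for all n modules) can be sketched in plain Python. The helpers `broadcast_arg` and `per_module_arguments` are hypothetical illustrations, not part of deepfield. The sketch assumes that every list must have exactly n entries, as documented for ch, and that each module's input channels equal the previous module's output channels.

```python
def broadcast_arg(value, n):
    """Hypothetical helper: expand a value into one entry per module."""
    if isinstance(value, list):
        if len(value) != n:
            raise ValueError(f"expected a list of length {n}, got {len(value)}")
        return list(value)
    # Any non-list value (including tuples such as kernel sizes) is repeated n times.
    return [value] * n


def per_module_arguments(n, in_ch, ch, *args, **kwargs):
    """Hypothetical helper: (in_ch, out_ch, args, kwargs) for each of n stacked modules."""
    out_chs = broadcast_arg(ch, n)
    # Assumption: each module consumes the previous module's output channels.
    in_chs = [in_ch] + out_chs[:-1]
    per_args = [broadcast_arg(a, n) for a in args]
    per_kwargs = {k: broadcast_arg(v, n) for k, v in kwargs.items()}
    return [
        (in_chs[i], out_chs[i],
         tuple(a[i] for a in per_args),
         {k: v[i] for k, v in per_kwargs.items()})
        for i in range(n)
    ]


specs = per_module_arguments(3, 4, [8, 16, 32], 3, padding=1, dilation=[1, 2, 4])
for spec in specs:
    print(spec)
# (4, 8, (3,), {'padding': 1, 'dilation': 1})
# (8, 16, (3,), {'padding': 1, 'dilation': 2})
# (16, 32, (3,), {'padding': 1, 'dilation': 4})
```

Each tuple could then be unpacked into a constructor such as nn.Conv2d(in_ch, out_ch, *args, **kwargs), where the positional 3 is the kernel size. Distinguishing lists from tuples is what lets a tuple-valued argument like a 2-D kernel size be shared by all modules.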
Utils
Miscellaneous utils.
-
class LinearInterpolator(points, values, at_bounds='linear')
Piecewise linear interpolator.
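Piecewise linear interpolation can be sketched in a few lines of plain Python. `PiecewiseLinear` is a hypothetical 1-D stand-in, not the deepfield class. It assumes that at_bounds='linear' means values outside the range of points are extrapolated from the nearest segment, and it does not cover any other at_bounds modes the library may support.

```python
from bisect import bisect_right


class PiecewiseLinear:
    """Hypothetical 1-D piecewise linear interpolator with linear extrapolation."""

    def __init__(self, points, values):
        if len(points) != len(values) or len(points) < 2:
            raise ValueError("need at least two points and one value per point")
        pairs = sorted(zip(points, values))
        self.points = [p for p, _ in pairs]
        self.values = [v for _, v in pairs]
        if any(a == b for a, b in zip(self.points, self.points[1:])):
            raise ValueError("points must be distinct")

    def __call__(self, x):
        # Find the segment containing x; clamping to the first/last segment
        # makes queries outside the range use linear extrapolation.
        i = bisect_right(self.points, x) - 1
        i = min(max(i, 0), len(self.points) - 2)
        x0, x1 = self.points[i], self.points[i + 1]
        y0, y1 = self.values[i], self.values[i + 1]
        return y0 + (y1 - y0) * (x - x0) / (x1 - x0)


f = PiecewiseLinear([0, 1, 3], [0, 10, 14])
print(f(0.5), f(2), f(4), f(-1))  # 5.0 12.0 16.0 -10.0
```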
-
find_best_match_indices(search_for, search_in, less_or_equal=False, greater_or_equal=False)
For each element of search_for, find the index of the closest element in search_in.
- Parameters
- Returns
indices – Shape: search_for.shape. If the required neighbour is not found, -1 is used in place of the index.
- Return type
torch.Tensor
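The matching rule can be illustrated with a binary-search sketch in plain Python. The function name `best_match_indices` is hypothetical, it returns a list rather than a torch.Tensor, and its reading of the flags is an assumption: less_or_equal=True restricts matches to elements not greater than the query, greater_or_equal=True to elements not less than it, and with both flags off the nearest element on either side wins. Setting both flags at once is not handled by this sketch.

```python
from bisect import bisect_left, bisect_right


def best_match_indices(search_for, search_in, less_or_equal=False, greater_or_equal=False):
    """Hypothetical sketch: for each query, index of its best match in search_in, or -1."""
    if less_or_equal and greater_or_equal:
        raise ValueError("this sketch does not handle both flags at once")
    # Sort once while remembering original positions, so each query is a binary search.
    order = sorted(range(len(search_in)), key=search_in.__getitem__)
    keys = [search_in[i] for i in order]
    result = []
    for x in search_for:
        below = bisect_right(keys, x) - 1  # position of the last key <= x, or -1
        above = bisect_left(keys, x)       # position of the first key >= x, or len(keys)
        candidates = []
        if below >= 0 and not greater_or_equal:
            candidates.append(below)
        if above < len(keys) and not less_or_equal:
            candidates.append(above)
        if not candidates:
            result.append(-1)  # no admissible neighbour, mirroring the documented -1
        else:
            best = min(candidates, key=lambda j: abs(keys[j] - x))
            result.append(order[best])
    return result


search_in = [10, 0, 5]
queries = [4, 6, -1, 11]
print(best_match_indices(queries, search_in))                         # [2, 2, 1, 0]
print(best_match_indices(queries, search_in, less_or_equal=True))     # [1, 2, -1, 0]
print(best_match_indices(queries, search_in, greater_or_equal=True))  # [2, 0, 1, -1]
```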
-
get_model_device(model)
Get the device on which the model is located.
- Parameters
model (nn.Module) –
- Returns
device
- Return type
torch.device
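A common way to find a module's device in PyTorch is to read it from the module's first parameter. The sketch below uses a hypothetical name, `model_device`, and assumes all parameters share one device. It falls back to buffers and treats a module with no tensors as living on the CPU; the actual get_model_device may behave differently in those edge cases.

```python
import torch
from torch import nn


def model_device(model: nn.Module) -> torch.device:
    """Hypothetical sketch: device of the model's first parameter (or buffer)."""
    for tensor in model.parameters():
        return tensor.device
    for tensor in model.buffers():
        return tensor.device
    # Assumption: a module without parameters or buffers is treated as CPU-resident.
    return torch.device("cpu")


model = nn.Linear(3, 2)
print(model_device(model))  # cpu, since modules are created on the CPU by default
```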