# pylint: disable=too-many-lines
"""Dataset wrappers for Fields."""
import os
import pickle
import inspect
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from ..field import Field
from ..field.base_component import BaseComponent
from ..field.utils import recursive_insensitive_glob, hasnested, overflow_safe_mean, get_spatial_perf
from .utils import get_config, STATES_KEYWORD
from .transforms import ToNumpy, Normalize, Compose, RemoveBatchDimension, AddBatchDimension, \
Transform, NON_NORMALIZED_ATTRS
# Components whose sampled arrays carry a leading time dimension.
SEQUENTIAL_ATTRS = ['STATES', 'CONTROL']
# Table keywords exported without their index column.
TABLES_WITHOUT_INDEX = ['DENSITY']
# Filler value marking invalid (non-active / out-of-bounds) neighbour indices.
INVALID_VALUE_FILLER = -1
# Mapping from control keywords to the corresponding results keywords.
CONTROL_TO_RESULTS_KW = {'WBHP': 'BHPT'}
def safe_check(comp, state, expected, default=False):
    """Check that a component's state attribute equals the expected value.

    Parameters
    ----------
    comp: object
        Component expected to expose a ``state`` namespace.
    state: str
        Name of the state attribute to inspect.
    expected: object
        Value to compare the state attribute against.
    default: object, optional
        Returned when ``comp.state`` or the requested attribute is not defined.

    Returns
    -------
    bool or type of ``default``
    """
    try:
        return getattr(comp.state, state) == expected
    except AttributeError:
        # State (or the attribute) is not defined yet — fall back to default.
        return default
class FieldDataset(Dataset):  # pylint: disable=too-many-instance-attributes
    """Baseclass for dataset of fields loaded with similar configs."""

    default_sample_attrs = {
        'MASKS': ['ACTNUM', 'TIME'],
        'GRID': [],
        'ROCK': ['PORO', 'PERMX', 'PERMY', 'PERMZ'],
        'STATES': ['PRESSURE', 'RS', 'SGAS', 'SOIL', 'SWAT'],
        'CONTROL': ['BHPT'],
    }
    # Components sampled as nested dicts rather than stacked arrays.
    _attrs_sampled_as_dict = ('MASKS', 'GRID', 'TABLES')

    def __init__(self, src, sample_attrs=None, fmt=('dat', 'data', 'hdf5'), subset_generator=None,
                 unravel_model=None, from_samples=False, allow_change_preloaded=False):
        """
        Parameters
        ----------
        src: str, Field, FieldSample or list of Fields or FieldSamples
            Path to a directory containing fields for the dataset or preloaded Fields
        sample_attrs: dict
            Attributes to be represented in samples
        fmt: str or tuple
            Format in which fields are represented
        subset_generator: callable or None
            Function generating subsequences for sequential attrs (states, control)
            Should return array-like objects with timestep indices
            If None, full sequences will be sampled
        unravel_model: bool or None
            Whether or not to unravel loaded models.
            If None, will be inferred from sample_attrs (set to False if 'neighbours'
            or 'distances' keys are presented)
        from_samples: bool
            If True, tries to load samples from previously dumped dataset (with FieldDataset.dump_samples).
            The sample_attrs will not affect the content of the loaded samples.
            The transforms will still be applied.
        allow_change_preloaded: bool
            If True, preloaded models may be modified in place (ravel/unravel,
            well recalculations) to match the dataset's settings.
        """
        # TODO: add possibility to make subsets of timesteps limited to constant control
        super().__init__()
        if isinstance(fmt, str):
            fmt = (fmt, )
        files = []
        self.root_dir = None
        self.preloaded = None
        if isinstance(src, str):
            for f in fmt:
                files += recursive_insensitive_glob(src, pattern='*.%s' % f, return_relative=True)
            self.root_dir = src
        else:
            self.preloaded = np.atleast_1d(src)
        self.fmt = fmt
        self.files = files
        self.transform = None
        self._sample_attrs = None
        self.sample_attrs = sample_attrs if sample_attrs is not None else self.default_sample_attrs
        self.from_samples = from_samples
        self.allow_change_preloaded = allow_change_preloaded
        self.config = get_config()
        # TODO make config dependent on the sample attrs
        self.subset_generator = subset_generator
        # Dataset-wide statistics, filled in by calculate_statistics/load_statistics.
        self.mean = None
        self.std = None
        self.min = None
        self.max = None
        # Dispatch tables mapping attribute names to their getters.
        self.masks_getter_map = {
            'ACTNUM': self._get_actnum,
            'WELL_MASK': self._get_well_mask,
            'NAMED_WELL_MASK': self._get_named_well_mask,
            'NEIGHBOURS': self._get_neighbours,
            'INVALID_NEIGHBOURS_MASK': self._get_invalid_neighbours_mask,
            'TIME': self._get_time,
            'CF_MASK': self._get_connection_factors,
            'PERF_MASK': self._get_perforation_mask
        }
        self.grid_getter_map = {
            'DISTANCES': self._get_distances,
            'XYZ': self._get_xyz
        }
        self.attrs_getter_map = {
            'STATES': self._get_states,
            'ROCK': self._get_rock,
            'CONTROL': self._get_control
        }
        # Attributes that only make sense for raveled (non-spatial) models.
        invalid_unravel_attrs = {
            'MASKS': ['NEIGHBOURS'],
            'GRID': ['DISTANCES']
        }
        if unravel_model:
            for comp, attrs in invalid_unravel_attrs.items():
                for attr in attrs:
                    if hasnested(self.sample_attrs, comp, attr):
                        raise ValueError('Can not unravel model and sample %s simultaneously.' % attr)
        if unravel_model is None:
            # Infer: unravel unless a ravel-only attribute was requested.
            ravel = False
            for comp, attrs in invalid_unravel_attrs.items():
                for attr in attrs:
                    ravel = ravel or hasnested(self.sample_attrs, comp, attr)
            unravel_model = not ravel
        self.unravel_model = unravel_model

    def __len__(self):
        if self.preloaded is not None:
            return len(self.preloaded)
        return len(self.files)

    def __getitem__(self, idx):
        if self.from_samples:
            sample = self._load_sample(idx)
        else:
            sample = self._get_sample(idx)
        if self.transform:
            sample = self.transform(sample)
        return sample

    def _get_sample(self, idx):  # pylint: disable=too-many-branches
        """Get sample from the dataset

        Parameters
        ----------
        idx: int, torch.Tensor
            Index of the field

        Returns
        -------
        sample: FieldSample
        """
        # TODO time and batch dimensions are transposed
        if torch.is_tensor(idx):
            idx = idx.tolist()
        if self.subset_generator is not None:
            sequence_subset = list(self.subset_generator())
            if not sequence_subset:
                raise ValueError('subset_generator should not generate empty subsets!')
        else:
            sequence_subset = None
        config = self.config.copy()
        config[STATES_KEYWORD] = {'attrs': self.config[STATES_KEYWORD]['attrs'], 'subset': sequence_subset}
        if 'Aquifers' in config:
            del config['Aquifers']
        if self.preloaded is None:
            model = self._load_model(idx, config)
        else:
            if isinstance(self.preloaded[idx], FieldSample):
                # Already a sample — nothing to compute.
                return self.preloaded[idx]
            model = self._get_preloaded(idx)
        sample = {}
        getter_kwargs = dict(
            sequence_subset=sequence_subset, fill_invalid_neighbours=INVALID_VALUE_FILLER, neighbouring_radius=1
        )
        for comp, attrs in self.sample_attrs.items():
            if comp == 'MASKS':
                sample[comp] = {}
                # All mask getters are invoked; those not requested in
                # sample_attrs return None and are pruned below.
                for attr, mask_getter in self.masks_getter_map.items():
                    # FIXME
                    if not self.unravel_model and attr.upper() in ('CF_MASK', 'PERF_MASK'):
                        continue
                    sample[comp][attr] = mask_getter(model, **getter_kwargs)
            elif comp == 'GRID':
                sample[comp] = {}
                for attr in attrs:
                    sample[comp][attr] = self.grid_getter_map[attr](model, **getter_kwargs)
            elif comp == 'TABLES':
                sample[comp] = {}
                sample[comp] = self._get_tables(model, attrs)
            elif comp == 'CONTROL':
                res = self.attrs_getter_map[comp](model, attrs, **getter_kwargs)
                sample[comp] = res['control']
                sample['MASKS']['CONTROL_T'] = res['t']
            else:
                sample[comp] = self.attrs_getter_map[comp](model, attrs, **getter_kwargs)
        sample = FieldSample(field=model, dataset=self, **sample)
        # Drop masks that were not requested (their getters returned None).
        for key in list(sample.masks.keys()):
            if sample.masks[key] is None:
                del sample.masks[key]
        if not self.unravel_model:
            sample.as_ravel(inplace=True)
        return sample

    def _get_preloaded(self, idx):
        """Get a field from preloaded."""
        model = self.preloaded[idx]
        if self.allow_change_preloaded:
            if model.state.spatial != self.unravel_model:
                if self.unravel_model:
                    model.to_spatial()
                else:
                    model.ravel()
            if 'CONTROL' in self.sample_attrs:
                # Bring wells to a fully computed, consistent state.
                if not model.wells.state.all_tracks_complete:
                    model.wells.drop_incomplete()
                if not model.wells.state.has_blocks:
                    model.wells.get_wellblocks(model.grid)
                if not model.wells.state.full_perforation:
                    model.wells.apply_perforations()
                if not model.wells.state.all_tracks_inside:
                    model.wells.drop_outside()
                if model.meta['MODEL_TYPE'] == 'ECL':
                    model.wells.compute_events(grid=model.grid)
        else:
            # Not allowed to modify — the model must already be consistent.
            assert model.state.spatial == self.unravel_model
            if 'CONTROL' in self.sample_attrs:
                assert model.wells.state.all_tracks_complete
                assert model.wells.state.has_blocks
                assert model.wells.state.full_perforation
                assert model.wells.state.all_tracks_inside
        return model

    def _load_model(self, idx, config=None, force_wells_calculations=False):
        """Loads field by index.

        Parameters
        ----------
        idx: int
        config: dict, optional
            Config used while loading the model
        force_wells_calculations: bool
            If True, recompute well data even if the state claims it is ready.

        Returns
        -------
        model: Field
        """
        _, fmt = os.path.splitext(self.files[idx])
        fmt = fmt.strip('.').lower()
        config = self.config if config is None else config
        if fmt == 'hdf5':
            if 'subset' not in config[STATES_KEYWORD] or config[STATES_KEYWORD]['subset'] is None:
                # No subsetting requested — load everything as stored.
                config = None
            else:
                for comp in config:
                    config[comp]['attrs'] = None
        model = Field(path=os.path.join(self.root_dir, self.files[idx]), config=config,
                      encoding='auto:10000', loglevel='ERROR')
        model.load(raise_errors=False)
        if self.unravel_model:
            model.to_spatial()
        if 'CONTROL' in self.sample_attrs:
            if not safe_check(model.wells, 'all_tracks_complete', True) or force_wells_calculations:
                model.wells.drop_incomplete()
            if not safe_check(model.wells, 'has_blocks', True) or force_wells_calculations:
                model.wells.get_wellblocks(grid=model.grid)
            if not safe_check(model.wells, 'full_perforation', True) or force_wells_calculations:
                model.wells.apply_perforations()
            if not safe_check(model.wells, 'all_tracks_inside', True) or force_wells_calculations:
                model.wells.drop_outside()
            if model.meta['MODEL_TYPE'] == 'ECL':
                model.wells.compute_events(grid=model.grid)
        if not self.unravel_model:
            model.ravel()
        return model

    def _load_sample(self, idx):
        """Load a previously dumped sample from disk."""
        sample = FieldSample(os.path.join(self.root_dir, self.files[idx]))
        sample.load()
        return sample

    def _get_actnum(self, model, **kwargs):
        """Get ACTNUM of the model"""
        _ = kwargs
        if hasattr(model.grid, 'actnum'):
            # bool instead of removed alias np.bool (NumPy >= 1.20).
            return model.grid.actnum.astype(bool)
        # No ACTNUM defined: treat all cells as active.
        actnum = np.ones(model.grid.dimens, dtype=bool)
        return actnum if self.unravel_model else actnum.ravel(order='F')

    def _get_well_mask(self, model, **kwargs):
        """Get well mask of the model."""
        _ = kwargs
        if hasnested(self.sample_attrs, 'MASKS', 'WELL_MASK') or 'CONTROL' in self.sample_attrs:
            return model.well_mask != ''
        return None

    def _get_named_well_mask(self, model, **kwargs):
        """Get per-well boolean masks keyed by well name."""
        _ = kwargs
        if hasnested(self.sample_attrs, 'MASKS', 'NAMED_WELL_MASK') or 'CONTROL' in self.sample_attrs:
            well_mask = model.well_mask
            named_well_mask = {}
            for well in model.wells:
                named_well_mask[well.name] = well_mask == well.name
            return named_well_mask
        return None

    def _get_neighbours(self, model, fill_invalid_neighbours=INVALID_VALUE_FILLER, neighbouring_radius=-1, **kwargs):
        """Get connectivity matrix of cells presented in the model."""
        if 'MASKS' in self.sample_attrs and 'NEIGHBOURS' in self.sample_attrs['MASKS']:
            neighbours = model.grid.get_neighbors_matrix(
                connectivity=neighbouring_radius,
                fill_value=fill_invalid_neighbours,
                ravel_index=True
            )
            # Indices are with respect to the full vectors: with active and non-active cells
            # We want indices with respect to the vector of active cells
            # NOTE(review): _get_sample's getter_kwargs do not include 'MASKS' —
            # this lookup appears to rely on the caller supplying it; verify.
            full_to_active_ind = kwargs['MASKS']['ACTNUM'].copy().astype(int)
            full_to_active_ind[full_to_active_ind == 1] = np.arange(full_to_active_ind.sum())
            # Map the fill value (-1) to itself via the appended sentinel.
            full_to_active_ind = np.concatenate([full_to_active_ind, [-1]])
            neighbours = full_to_active_ind[neighbours.ravel()].reshape(neighbours.shape)
            # Neighbours should include the cell itself
            itself_ind = np.arange(neighbours.shape[0])[:, np.newaxis]
            neighbours = np.concatenate([itself_ind, neighbours], axis=1)
            return neighbours
        return None

    def _get_invalid_neighbours_mask(self, model, fill_invalid_neighbours=INVALID_VALUE_FILLER,
                                     neighbouring_radius=-1, **kwargs):
        """Get mask of invalid neighbours (non-active or out of geometric bounds)."""
        _ = kwargs
        if hasnested(self.sample_attrs, 'GRID', 'DISTANCES'):
            neighbours = model.grid.get_neighbors_matrix(
                connectivity=neighbouring_radius,
                fill_value=fill_invalid_neighbours,
                ravel_index=True
            )
            return neighbours == INVALID_VALUE_FILLER
        return None

    @staticmethod
    def _get_time(model, sequence_subset=None, **kwargs):
        """Get time in days associated with states timesteps relative to model start date."""
        _ = kwargs
        dates = model.result_dates
        sec_in_day = 86400
        t = (dates - model.start).total_seconds().values / sec_in_day
        return t if sequence_subset is None else t[sequence_subset]

    def _get_connection_factors(self, model, sequence_subset=None, **kwargs):
        """Get spatial connection factors for the requested date range, if sampled."""
        # FIXME calls the field's method twice
        _ = kwargs
        if sequence_subset is not None:
            res_dates = model.result_dates
            if res_dates.size:
                res_dates = res_dates[sequence_subset]
            date_range = (res_dates[0], res_dates[-1])
        else:
            date_range = None
        if hasnested(self.sample_attrs, 'MASKS', 'CF_MASK'):
            return model.get_spatial_connection_factors_and_perforation_ratio(date_range=date_range)[0]
        return None

    def _get_perforation_mask(self, model, sequence_subset=None, **kwargs):
        """Get spatial perforation ratio for the requested date range, if sampled."""
        # FIXME calls the field's method twice
        _ = kwargs
        if sequence_subset is not None:
            res_dates = model.result_dates
            if res_dates.size:
                res_dates = res_dates[sequence_subset]
            date_range = (res_dates[0], res_dates[-1])
        else:
            date_range = None
        if hasnested(self.sample_attrs, 'MASKS', 'PERF_MASK'):
            return model.get_spatial_connection_factors_and_perforation_ratio(date_range=date_range)[1]
        return None

    @staticmethod
    def to_dates(model, t):
        """Restore actual dates from time deltas."""
        dates = model.start + np.array([pd.Timedelta(i, unit='day') for i in t])
        return pd.to_datetime(dates)

    @staticmethod
    def _get_distances(model, fill_invalid_neighbours=INVALID_VALUE_FILLER, neighbouring_radius=-1, **kwargs):
        """Get matrix of distances for neighbouring cells."""
        _ = kwargs
        return model.grid.calculate_neighbours_distances(
            connectivity=neighbouring_radius,
            fill_value=fill_invalid_neighbours
        )

    @staticmethod
    def _get_xyz(model, **kwargs):
        """Get cell vertex coordinates."""
        _ = kwargs
        return model.grid.xyz

    def _get_states(self, model, attrs, sequence_subset=None, **kwargs):
        """Get stacked states sequence."""
        _ = kwargs
        if (self.preloaded is not None) and (sequence_subset is not None):
            # Preloaded models hold full sequences — subset here instead of at load time.
            return np.stack([getattr(model.states, attr)[sequence_subset] for attr in attrs], axis=1)
        return np.stack([getattr(model.states, attr) for attr in attrs], axis=1)

    @staticmethod
    def _get_rock(model, attrs, **kwargs):
        """Get stacked rock attributes."""
        _ = kwargs
        return np.stack([getattr(model.rock, attr) for attr in attrs], axis=0)

    @staticmethod
    def _get_tables(model, attrs):
        """Get sample table data"""
        return {
            attr: getattr(model.tables, attr).to_numpy() if attr in TABLES_WITHOUT_INDEX
            else getattr(model.tables, attr).to_numpy(include_index=True) for attr in attrs
        }

    @staticmethod
    def _get_control(model, attrs, sequence_subset=None, **kwargs):
        """Get control in a spatial form (defined for all cells, meaningful values in
        perforated cells, other cells are filled with zeros) with corresponding dates.
        """
        _ = kwargs
        if sequence_subset is not None:
            res_dates = model.result_dates
            if res_dates.size:
                res_dates = res_dates[sequence_subset]
            date_range = (res_dates[0], res_dates[-1])
        else:
            date_range = None
        # Perf-mask pseudo-attributes are computed separately, not by the field.
        filtered_attrs = attrs.copy()
        if 'PROD_PERF_MASK' in attrs:
            filtered_attrs.remove('PROD_PERF_MASK')
        if 'INJE_PERF_MASK' in attrs:
            filtered_attrs.remove('INJE_PERF_MASK')
        output = model.get_spatial_well_control(filtered_attrs, date_range=date_range, fill_shut=0., fill_outside=0.)
        if 'PROD_PERF_MASK' in attrs or 'INJE_PERF_MASK' in attrs:
            control = []
            i = 0  # column index into the field-computed control only
            for attr in attrs:
                if attr == 'PROD_PERF_MASK':
                    control.append(get_spatial_perf(model, sequence_subset, mode='PROD'))
                elif attr == 'INJE_PERF_MASK':
                    control.append(get_spatial_perf(model, sequence_subset, mode='INJE'))
                else:
                    control.append(output['control'][:, i][:, None])
                    i += 1
            output['control'] = np.concatenate(control, axis=1)
        return output

    def dump_samples(self, path, n_epoch=1, prefix=None, state=True, **kwargs):
        """Dump samples from the dataset.

        Parameters
        ----------
        path: str
            Path to the directory for dump.
        n_epoch: int
            Number of times to pass through the dataset.
        prefix: str, None
            Prefix for dumped samples.
        state: bool
            If True, dump the state of the samples
        kwargs: dict
            Additional named arguments for sample.dump

        Returns
        -------
        dataset: FieldDataset
            Self, unchanged.
        """
        # makedirs(exist_ok=True) also covers nested, not-yet-existing paths.
        os.makedirs(path, exist_ok=True)
        prefix = prefix + '_' if prefix is not None else ''
        i = 0
        for _ in range(n_epoch):
            for sample in self:
                sample.dump(os.path.join(path, prefix + str(i) + '.hdf5'), state=state, **kwargs)
                i += 1
        return self

    def convert_to_other_fmt(self, new_root_dir, new_fmt='hdf5', results_to_events=True, **kwargs):
        """Convert dataset to a new format.

        Parameters
        ----------
        new_root_dir: str
            Directory to save converted dataset
        new_fmt: str
            Extension to use
        results_to_events: bool
            If True, derive well events from results before dumping.
        kwargs: dict
            Any additional named arguments passed to Field.dump

        Returns
        -------
        FieldDataset
        """
        if not os.path.exists(new_root_dir):
            os.makedirs(new_root_dir)
        for i, path in enumerate(self.files):
            path, _ = os.path.splitext(path)
            if os.path.split(path)[0]:
                # exist_ok: several files may live in the same subdirectory.
                os.makedirs(os.path.join(new_root_dir, os.path.split(path)[0]), exist_ok=True)
            path = '.'.join([path, new_fmt])
            model = self._load_model(i, force_wells_calculations=True)
            if results_to_events:
                model.wells.results_to_events(grid=model.grid)
            config = None if new_fmt == 'hdf5' else self.config
            model.dump(path=os.path.join(new_root_dir, path), config=config, **kwargs)
        # Re-point the dataset at the converted copy.
        self.__init__(
            src=new_root_dir,
            sample_attrs=self.sample_attrs,
            fmt=(new_fmt, ),
            subset_generator=self.subset_generator
        )
        return self

    @property
    def filtered_statistics(self):
        """Filters out non-normalized attrs and attrs, which are not presented in `sample_attrs`, from statistics."""
        filtered_stats = dict()
        for key, value in zip(('MEAN', 'STD', 'MIN', 'MAX'), (self.mean, self.std, self.min, self.max)):
            if value is None:
                raise RuntimeError("Dataset's statistics are not calculated!")
            filtered_stat = {
                comp: {} for comp in self.sample_attrs
                if comp not in NON_NORMALIZED_ATTRS and len(self.sample_attrs[comp]) > 0
            }
            for comp in filtered_stat:
                if comp not in value:
                    raise ValueError('Component "%s" is not presented in calculated statistics.' % comp)
                for attr in self.sample_attrs[comp]:
                    if attr not in value[comp]:
                        raise ValueError('Attribute "%s" of component "%s" is not presented in calculated statistics.'
                                         % (attr, comp))
                    filtered_stat[comp][attr] = value[comp][attr]
                if comp not in self._attrs_sampled_as_dict:
                    # Stacked components get stacked statistics, in sample_attrs order.
                    filtered_stat[comp] = np.stack(
                        [filtered_stat[comp][attr] for attr in self.sample_attrs[comp]]
                    )
            filtered_stats[key] = filtered_stat
        return filtered_stats

    def calculate_statistics(self):  # pylint: disable=too-many-branches
        """Calculate mean and std values for the attributes of the dataset."""
        # Change sampling behavior for statistics' calculation.
        subset_generator, self.subset_generator = self.subset_generator, None
        unravel_model, self.unravel_model = self.unravel_model, False
        mean, mean_of_squares, std, minim, maxim = {}, {}, {}, {}, {}
        for comp in self.sample_attrs:
            if comp not in NON_NORMALIZED_ATTRS:
                mean[comp] = {attr: [] for attr in self.sample_attrs[comp]}
                mean_of_squares[comp] = {attr: [] for attr in self.sample_attrs[comp]}
                std[comp] = {attr: [] for attr in self.sample_attrs[comp]}
                minim[comp] = {attr: [] for attr in self.sample_attrs[comp]}
                maxim[comp] = {attr: [] for attr in self.sample_attrs[comp]}
        for i in range(len(self)):
            m, m_sq, mn, mx = self._get_model_statistics(i)
            for comp in m:  # pylint: disable=consider-using-dict-items
                for attr in m[comp]:
                    mean[comp][attr].append(m[comp][attr])
                    mean_of_squares[comp][attr].append(m_sq[comp][attr])
                    minim[comp][attr].append(mn[comp][attr])
                    maxim[comp][attr].append(mx[comp][attr])
        for comp in mean:  # pylint: disable=consider-using-dict-items
            for attr in mean[comp]:
                mean[comp][attr] = np.mean(mean[comp][attr], axis=0)
                mean_of_squares[comp][attr] = np.mean(mean_of_squares[comp][attr], axis=0)
                # std via E[x^2] - E[x]^2; abs guards tiny negative round-off.
                std[comp][attr] = np.sqrt(np.abs(mean_of_squares[comp][attr] - mean[comp][attr]**2))
                minim[comp][attr] = np.min(minim[comp][attr], axis=0)
                maxim[comp][attr] = np.max(maxim[comp][attr], axis=0)
        # Recover old sampling behavior
        self.subset_generator = subset_generator
        self.unravel_model = unravel_model
        self.mean, self.std, self.min, self.max = mean, std, minim, maxim
        return self

    def _get_model_statistics(self, idx):
        """Get mean and mean of squares for the attributes of the model."""
        sample = self._get_sample(idx)
        mean, mean_of_squares, minim, maxim = {}, {}, {}, {}
        for comp in sample.keys():
            if comp.upper() in NON_NORMALIZED_ATTRS:
                continue
            # Control statistics are computed over perforated cells only.
            mask = sample.masks.well_mask if comp.upper() == 'CONTROL' else None
            mean[comp] = dict()
            mean_of_squares[comp] = dict()
            minim[comp] = dict()
            maxim[comp] = dict()
            if comp in self._attrs_sampled_as_dict:
                ax = 0
                for attr, arr in sample[comp].items():
                    if attr.upper() == 'DISTANCES':
                        # float instead of removed alias np.float (NumPy >= 1.20);
                        # mask out invalid neighbours via NaNs.
                        arr = arr.copy().astype(float)
                        arr[sample.masks.invalid_neighbours_mask] = np.nan
                    mean[comp][attr] = np.nanmean(arr, axis=ax)
                    mean_of_squares[comp][attr] = np.nanmean(np.power(arr, 2), axis=ax)
                    minim[comp][attr] = np.nanmin(arr, axis=ax)
                    maxim[comp][attr] = np.nanmax(arr, axis=ax)
            else:
                ax = 1 if comp not in SEQUENTIAL_ATTRS else (0, 2)
                if mask is not None:
                    arr = sample[comp][..., mask]
                else:
                    arr = sample[comp]
                comp_mean = overflow_safe_mean(arr, axis=ax)
                comp_mean_of_squares = overflow_safe_mean(np.power(arr, 2), axis=ax)
                comp_min = np.min(arr, axis=ax)
                comp_max = np.max(arr, axis=ax)
                for i, attr in enumerate(self.sample_attrs[comp]):
                    mean[comp][attr] = comp_mean[i]
                    mean_of_squares[comp][attr] = comp_mean_of_squares[i]
                    minim[comp][attr] = comp_min[i]
                    maxim[comp][attr] = comp_max[i]
        return mean, mean_of_squares, minim, maxim

    def dump_statistics(self, path):
        """Dump mean and std values of the dataset into a file."""
        if self.std is None or self.mean is None or self.min is None or self.max is None:
            raise RuntimeError("Dataset's statistics are not calculated!")
        with open(path, 'wb') as f:
            pickle.dump([self.mean, self.std, self.min, self.max], f)

    def load_statistics(self, path):
        """Load mean and std values of the dataset from a file."""
        with open(path, 'rb') as f:
            self.mean, self.std, self.min, self.max = pickle.load(f)
        # Normalize all component/attribute keys to upper case.
        for kind in ('mean', 'std', 'min', 'max'):
            stats = getattr(self, kind)
            upper_stats = {}
            for comp, value in stats.items():
                if isinstance(value, dict):
                    upper_stats[comp.upper()] = {}
                    for attr, arr in value.items():
                        upper_stats[comp.upper()][attr.upper()] = arr
                else:
                    upper_stats[comp.upper()] = value
            setattr(self, kind, upper_stats)
        if 'CONTROL' in self.sample_attrs and 'CONTROL' in self.mean:
            for kind in ('mean', 'std', 'min', 'max'):
                stats = getattr(self, kind)
                # Iterate over a snapshot: pop/insert below mutate the dict.
                for k in list(stats['CONTROL']):
                    if k in CONTROL_TO_RESULTS_KW and CONTROL_TO_RESULTS_KW[k] in self.sample_attrs['CONTROL']:
                        stats['CONTROL'][CONTROL_TO_RESULTS_KW[k]] = stats['CONTROL'].pop(k)

    @property
    def sample_attrs(self):
        """Attributes represented in the samples."""
        return self._sample_attrs

    @sample_attrs.setter
    def sample_attrs(self, x):
        # Keys and attribute names are stored upper-case for consistency.
        self._sample_attrs = {
            comp.upper(): [attr.upper() for attr in x[comp]] for comp in x
        }
class FieldSample(BaseComponent):
    """Class representing the samples from the dataset.

    Parameters
    ----------
    path: str, optional
        Path to the file. Only HDF5 files are supported at the moment.
    field: Field, optional
    dataset: FieldDataset, optional
    state: dict, optional
    sample: dict-like, optional
    """

    class _decorators:
        """Decorators for the FieldSample."""

        @classmethod
        def without_batch_dimension(cls, method):
            """Decorates sample methods to be applied without the batch dimension."""
            def decorated(instance, inplace=False, **kwargs):
                batch_dimension = instance.state.batch_dimension if hasattr(instance.state,
                                                                            'batch_dimension') else False
                if batch_dimension:
                    # Strip the batch dim, run the method, then restore it.
                    instance = instance.transformed(RemoveBatchDimension, inplace=inplace)
                    inplace = True
                instance = method(instance, inplace=inplace, **kwargs)
                if batch_dimension:
                    instance = instance.transformed(AddBatchDimension, inplace=inplace)
                return instance
            return decorated

    def __init__(self, path=None, field=None, dataset=None, state=None, **sample):
        super().__init__(**sample)
        self._path = path
        self._field = field
        self.sample_attrs = dataset.sample_attrs if dataset is not None else None
        self.dataset = dataset
        if state is not None:
            self.init_state(**state)

    def _nested_dicts_to_base_components(self, class_name, d):
        """Recursively wrap nested dicts into BaseComponent instances."""
        if isinstance(d, dict):
            d = BaseComponent(class_name=class_name, **d)
            for key, value in d.items():
                value = self._nested_dicts_to_base_components(key, value)
                setattr(d, key, value)
        return d

    def __setattr__(self, key, value):
        # Public attributes get dict values auto-wrapped into BaseComponents.
        if key[0] != '_':
            value = self._nested_dicts_to_base_components(key.upper(), value)
        super().__setattr__(key, value)

    def empty_like(self):
        """Get an empty sample with the same state and the structure of embedded BaseComponents (if any)."""
        empty = super().empty_like()
        empty = FieldSample(field=self.field, dataset=self.dataset, state=empty.state.as_dict(), **dict(empty))
        empty.sample_attrs = self.sample_attrs
        return empty

    def copy(self):
        """Get a copy of the sample."""
        copy = super().copy()
        copy.dataset = self.dataset
        copy.field = self.field
        copy.sample_attrs = self.sample_attrs
        return copy

    def dump(self, path, **kwargs):
        """Dump the sample into a file.

        Parameters
        ----------
        path: str
            Path to the file.
        kwargs: dict
            Additional named arguments passed to BaseComponent's dump method.
        """
        fname = os.path.basename(path)
        fmt = os.path.splitext(fname)[1].strip('.')
        if fmt.upper() == 'HDF5':
            if hasattr(self.state, 'tensor') and self.state.tensor:
                # Torch tensors are not dumpable directly — convert first.
                out = self.transformed(ToNumpy)
                return out.dump(path, **kwargs)
            for state, value in self.state.as_dict().items():
                if issubclass(value.__class__, BaseComponent):
                    if state == 'sample_attributes':
                        # Strings are stored as fixed-width bytes in HDF5.
                        for k, v in value.items():
                            value[k] = np.array(v, dtype='S16')
                    setattr(self, state, value)
                    self.set_state(**{state: 'base_component'})
            return self._dump_hdf5(path, **kwargs)
        raise NotImplementedError('File format {} not supported.'.format(fmt))

    def load(self, **kwargs):
        """Load sample from a file.

        Parameters
        ----------
        kwargs: dict
            Additional named arguments passed to the load method.

        Returns
        -------
        sample: FieldSample
            Sample with loaded data.
        """
        if self._path is None:
            raise RuntimeError('You should specify a path before loading!')
        fname = os.path.basename(self._path)
        fmt = os.path.splitext(fname)[1].strip('.')
        if fmt.upper() == 'HDF5':
            self._load_hdf5(self._path, **kwargs)
        else:
            raise NotImplementedError('File format {} not supported.'.format(fmt))
        # Restore states that were dumped as embedded BaseComponents.
        for state, value in self.state.as_dict().items():
            if value == 'base_component':
                value = getattr(self, state)
                if state == 'sample_attributes':
                    for k, v in value.items():
                        value[k] = list(v.astype('U'))
                self.set_state(**{state: value})
                delattr(self, state)
        return self

    @property
    def field(self):
        """Link to the parent field."""
        return self._field

    @field.setter
    def field(self, x):
        if x is not None and not isinstance(x, Field):
            raise ValueError('Can assign only instances of the class %s!' % str(Field))
        self._field = x

    @property
    def dataset(self):
        """Link to the parent dataset."""
        return self._dataset

    @dataset.setter
    def dataset(self, x):
        if x is not None and not isinstance(x, FieldDataset):
            raise ValueError('Can assign only instances of the class %s!\nGiven %s' % (str(FieldDataset), type(x)))
        self._dataset = x
        if x is not None:
            self.init_state(
                spatial=x.unravel_model,
                cropped_at_mask=None if x.unravel_model else 'ACTNUM'
            )
            try:
                self.init_state(dataset_statistics=self._nested_dicts_to_base_components(
                    'DATASET_STATISTICS', x.filtered_statistics
                ))
            except RuntimeError:
                # Statistics are simply not calculated yet — nothing to attach.
                pass

    def at_wells(self, inplace=False):
        """Crop all the spatial arrays to the perforated cells. Ravel if needed.

        Parameters
        ----------
        inplace: bool

        Returns
        -------
        sample: FieldSample
            Cropped sample.
        """
        return self.as_ravel(inplace=inplace, crop_at_mask='WELL_MASK')

    @_decorators.without_batch_dimension
    def as_spatial(self, inplace=False):
        """Transform the sample's arrays to the spatial form.

        Parameters
        ----------
        inplace: bool

        Returns
        -------
        sample: FieldSample
        """
        raise NotImplementedError()

    # pylint: disable=too-many-nested-blocks
    @_decorators.without_batch_dimension
    def as_ravel(self, inplace=False, crop_at_mask='ACTNUM'):
        """Ravel the sample's arrays.

        Parameters
        ----------
        inplace: bool
        crop_at_mask: str
            Name of the binary mask to crop the raveled arrays at.

        Returns
        -------
        sample: FieldSample
        """
        out = self if inplace else self.empty_like()
        if self.state.spatial:
            for comp in self.keys():  # pylint: disable=consider-using-dict-items
                if comp.upper() == 'TABLES':
                    out[comp] = self[comp]
                    continue
                if comp.upper() in ('MASKS', 'GRID'):
                    for attr in self[comp].keys():
                        if attr.upper() in ('TIME', 'CONTROL_T'):
                            # Purely temporal arrays: nothing spatial to ravel.
                            out[comp][attr] = self[comp][attr]
                            continue
                        if attr.upper() == 'NAMED_WELL_MASK':
                            for well in self[comp][attr].keys():
                                new_shape = (-1,) + tuple(self[comp][attr][well].shape[3:])
                                out[comp][attr][well] = \
                                    self[comp][attr].reshape(attr=well, newshape=new_shape, order='F', inplace=False)
                            continue
                        if attr.upper() in ('CF_MASK', 'PERF_MASK'):
                            # Spatial dims are the trailing three here.
                            new_shape = tuple(self[comp][attr].shape[:-3]) + (-1,)
                        else:
                            new_shape = (-1,) + tuple(self[comp][attr].shape[3:])
                        out[comp][attr] = self[comp].reshape(attr=attr, newshape=new_shape, order='F', inplace=False)
                else:
                    new_shape = tuple(self[comp].shape[:-3]) + (-1, )
                    out[comp] = self.reshape(attr=comp, newshape=new_shape, order='F', inplace=False)
            out.set_state(spatial=False)
        if crop_at_mask != self.state.cropped_at_mask:
            if self.state.cropped_at_mask is not None:
                out = self._uncrop_from_mask(out, self.state.cropped_at_mask)
            out = self._crop_at_mask(out, crop_at_mask)
        return out

    @staticmethod
    def _crop_at_mask(obj, mask_name):
        """Crop a sample at a given binary mask.

        Parameters
        ----------
        obj: FieldSample
            Sample to be cropped.
        mask_name: str
            Name of the mask from the sample['MASKS'].

        Returns
        -------
        obj: FieldSample
            Cropped sample.
        """
        assert not obj.state.spatial
        mask = obj.masks[mask_name]
        if isinstance(mask, torch.Tensor):
            mask = mask.bool()
        else:
            mask = mask.astype(bool)
        for comp in obj.keys():
            if comp.upper() == 'TABLES':
                continue
            if comp.upper() in ('MASKS', 'GRID'):
                for attr in obj[comp].keys():
                    if attr.upper() in ('CF_MASK', 'PERF_MASK'):
                        obj[comp][attr] = obj[comp][attr][..., mask]
                    elif attr.upper() == 'NAMED_WELL_MASK':
                        for well in obj[comp][attr].keys():
                            obj[comp][attr][well] = obj[comp][attr][well][..., mask]
                    elif attr.upper() not in (mask_name.upper(), 'TIME', 'CONTROL_T') and obj[comp][attr] is not None:
                        obj[comp][attr] = obj[comp][attr][mask]
            else:
                obj[comp] = obj[comp][..., mask]
        obj.set_state(cropped_at_mask=mask_name.upper())
        return obj

    @staticmethod
    def _uncrop_from_mask(obj, mask_name):
        """Reverse operation to the crop_at_mask.

        Parameters
        ----------
        obj: FieldSample
        mask_name: str

        Returns
        -------
        obj: FieldSample
        """
        raise NotImplementedError()

    def _initialize_transform(self, transforms):
        """Initialize transforms before application."""
        if not isinstance(transforms, (list, tuple, Compose)):
            transforms = [transforms]
        initialized_transforms = []
        for t in transforms:
            if inspect.isclass(t):
                if issubclass(t, Normalize):
                    # Normalize needs the dataset statistics stored in the state.
                    initialized_transforms.append(t(
                        mean=self.state.dataset_statistics.mean,
                        std=self.state.dataset_statistics.std,
                        unravel_model=self.state.spatial
                    ))
                else:
                    initialized_transforms.append(t())
            else:
                initialized_transforms.append(t)
        return Compose(initialized_transforms)

    @property
    def sample_attrs(self):
        """Attributes represented in the sample."""
        return self.state.sample_attributes

    @sample_attrs.setter
    def sample_attrs(self, x):
        x = None if x is None else {comp.upper(): [attr.upper() for attr in x[comp]] for comp in x.keys()}
        x = self._nested_dicts_to_base_components('SAMPLE_ATTRIBUTES', x)
        if hasattr(self.state, 'SAMPLE_ATTRIBUTES'):
            self.set_state(sample_attributes=x)
        else:
            self.init_state(sample_attributes=x)

    @property
    def device(self):
        """Get the sample's device (if it is in Torch format)

        Returns
        -------
        device: torch.device
        """
        # Find the first array-like value, descending at most two dict levels.
        ref = None
        for _, value in self.items():
            if isinstance(value, BaseComponent):
                for _, arr in value.items():
                    if isinstance(arr, BaseComponent):
                        for _, mask in arr.items():
                            ref = mask
                            break
                    else:
                        ref = arr
                    break
            else:
                ref = value
        if ref is None:
            raise RuntimeError('The sample is empty!')
        if isinstance(ref, torch.Tensor):
            return ref.device
        raise RuntimeError('The sample should be in the PyTorch format! Found: %s' % type(ref))

    def to(self, device, inplace=True):
        """Change the sample's device (if it is in Torch format).

        Parameters
        ----------
        device: str, torch.device
        inplace: bool

        Returns
        -------
        sample: FieldSample
            Sample at the new device
        """
        if self.device == device:
            return self if inplace else self.copy()
        out = self if inplace else self.empty_like()
        for comp, value in self.items():
            if isinstance(value, BaseComponent):
                for attr, arr in value.items():
                    if isinstance(arr, BaseComponent):
                        for well, mask in arr.items():
                            out[comp][attr][well] = mask.to(device)
                    else:
                        out[comp][attr] = arr.to(device)
            else:
                out[comp] = value.to(device)
        return out
class SequenceSubset:
    """Baseclass for generating subsets of sequences."""

    def __init__(self, size, low, high, **kwargs):
        """
        Parameters
        ----------
        size: int
            Length of the generated sequences
        low: int
            Minimal possible timestep (inclusive)
        high: int
            Maximal possible timestep (exclusive)
        kwargs: optional
            Ignored; accepted for subclass compatibility.
        """
        self.size = size
        self.low = low
        self.high = high
        _ = kwargs

    def __call__(self):
        """Generate subset of timesteps."""
        raise NotImplementedError('Abstract method is not implemented.')
class RandomSubsequence(SequenceSubset):
    """Generator of random contiguous timestep subsequences."""

    def __call__(self):
        """Sample a contiguous run of ``size`` timesteps within [low, high)."""
        # +1 so a subsequence ending exactly at high-1 can be drawn.
        start = np.random.randint(low=self.low, high=self.high - self.size + 1)
        subset = np.arange(start, start + self.size)
        return np.sort(subset)