"""Dataset wrappers for Fields."""
import os
import pickle
import inspect
import numpy as np
import pandas as pd
import torch

from torch.utils.data import Dataset

from ..field import Field
from ..field.base_component import BaseComponent
from ..field.utils import recursive_insensitive_glob, hasnested, overflow_safe_mean, get_spatial_perf
from .utils import get_config, STATES_KEYWORD
from .transforms import ToNumpy, Normalize, Compose, RemoveBatchDimension, AddBatchDimension, \
    Transform
# NOTE(review): INVALID_VALUE_FILLER, NON_NORMALIZED_ATTRS, SEQUENTIAL_ATTRS, TABLES_WITHOUT_INDEX and
# CONTROL_TO_RESULTS_KW are used below, but their import lines were lost from this copy of the module;
# restore them from the upstream source before running.


def safe_check(comp, state, expected, default=False):
    """Check that a component's state has the expected value.

    Parameters
    ----------
    comp: object
        Component with a ``state`` attribute.
    state: str
        Name of the state attribute to check.
    expected: any
        Value the state attribute is compared with.
    default: any
        Value returned when ``comp.state`` or the requested state attribute is not defined.

    Returns
    -------
    out: bool or type of ``default``
        ``getattr(comp.state, state) == expected`` or ``default`` if the state is undefined.
    """
    try:
        return getattr(comp.state, state) == expected
    except AttributeError:
        return default

class FieldDataset(Dataset):  # pylint: disable=too-many-instance-attributes
    """Baseclass for dataset of fields loaded with similar configs."""

    default_sample_attrs = {
        'MASKS': ['ACTNUM', 'TIME'],
        'GRID': [],
        'ROCK': ['PORO', 'PERMX', 'PERMY', 'PERMZ'],
        'STATES': ['PRESSURE', 'RS', 'SGAS', 'SOIL', 'SWAT'],
        'CONTROL': ['BHPT'],
    }

    # Components sampled as a dict {attr: array} instead of a single array stacked over attrs.
    _attrs_sampled_as_dict = ('MASKS', 'GRID', 'TABLES')

    def __init__(self, src, sample_attrs=None, fmt=('dat', 'data', 'hdf5'), subset_generator=None,
                 unravel_model=None, from_samples=False, allow_change_preloaded=False):
        """
        Parameters
        ----------
        src: str, Field, FieldSample or list of Fields or FieldSamples
            Path to a directory containing fields for the dataset or preloaded Fields
        sample_attrs: dict
            Attributes to be represented in samples
        fmt: str or tuple
            Format in which fields are represented
        subset_generator: callable or None
            Function generating subsequences for sequential attrs (states, control)
            Should return array-like objects with timestep indices
            If None, full sequences will be sampled
        unravel_model: bool or None
            Whether or not to unravel loaded models
            If None, will be inferred from sample_attrs (set to False if 'neighbours' or 'distances'
            keys are presented)
        from_samples: bool
            If True, tries to load samples from previously dumped dataset (with FieldDataset.dump_samples).
            The sample_attrs will not affect the content of the loaded samples.
            The transforms will still be applied.
        allow_change_preloaded: bool
            If True, preloaded Fields are modified in place (raveled/unraveled, wells processed)
            to match the dataset settings. If False, their state is only checked with assertions.

        Raises
        ------
        ValueError
            If `unravel_model` is True while NEIGHBOURS or DISTANCES are requested in `sample_attrs`.
        """
        # TODO: add possibility to make subsets of timesteps limited to constant control
        super().__init__()
        if isinstance(fmt, str):
            fmt = (fmt, )
        files = []
        self.root_dir = None
        self.preloaded = None
        if isinstance(src, str):
            for f in fmt:
                files += recursive_insensitive_glob(src, pattern='*.%s' % f, return_relative=True)
            self.root_dir = src
        else:
            self.preloaded = np.atleast_1d(src)
        self.fmt = fmt
        self.files = files
        self.transform = None
        self._sample_attrs = None
        self.sample_attrs = sample_attrs if sample_attrs is not None else self.default_sample_attrs
        self.from_samples = from_samples
        self.allow_change_preloaded = allow_change_preloaded
        self.config = get_config()  # TODO make config dependent on the sample attrs
        self.subset_generator = subset_generator
        self.mean = None
        self.std = None
        self.min = None
        self.max = None
        self.masks_getter_map = {
            'ACTNUM': self._get_actnum,
            'WELL_MASK': self._get_well_mask,
            'NAMED_WELL_MASK': self._get_named_well_mask,
            'NEIGHBOURS': self._get_neighbours,
            'INVALID_NEIGHBOURS_MASK': self._get_invalid_neighbours_mask,
            'TIME': self._get_time,
            'CF_MASK': self._get_connection_factors,
            'PERF_MASK': self._get_perforation_mask
        }
        self.grid_getter_map = {
            'DISTANCES': self._get_distances,
            'XYZ': self._get_xyz
        }
        self.attrs_getter_map = {
            'STATES': self._get_states,
            'ROCK': self._get_rock,
            'CONTROL': self._get_control
        }

        # Attributes that can not be sampled from unraveled (spatial) models.
        invalid_unravel_attrs = {
            'MASKS': ['NEIGHBOURS'],
            'GRID': ['DISTANCES']
        }
        requested_invalid = [attr for comp, attrs in invalid_unravel_attrs.items()
                             for attr in attrs if hasnested(self.sample_attrs, comp, attr)]
        if unravel_model and requested_invalid:
            raise ValueError('Can not unravel model and sample %s simultaneously.' % requested_invalid[0])
        if unravel_model is None:
            unravel_model = not requested_invalid
        self.unravel_model = unravel_model

    def __len__(self):
        if self.preloaded is not None:
            return len(self.preloaded)
        return len(self.files)

    def __getitem__(self, idx):
        if self.from_samples:
            sample = self._load_sample(idx)
        else:
            sample = self._get_sample(idx)
        if self.transform:
            sample = self.transform(sample)
        return sample

    def _get_sample(self, idx):  # pylint: disable=too-many-branches
        """Get sample from the dataset

        Parameters
        ----------
        idx: int, torch.Tensor
            Index of the field

        Returns
        -------
        sample: FieldSample
        """
        # TODO time and batch dimensions are transposed
        if torch.is_tensor(idx):
            idx = idx.tolist()

        if self.subset_generator is not None:
            sequence_subset = list(self.subset_generator())
            if not sequence_subset:
                raise ValueError('subset_generator should not generate empty subsets!')
        else:
            sequence_subset = None

        # Shallow copy: only the STATES entry is replaced, the other entries are shared with self.config.
        config = self.config.copy()
        config[STATES_KEYWORD] = {'attrs': self.config[STATES_KEYWORD]['attrs'], 'subset': sequence_subset}
        if 'Aquifers' in config:
            del config['Aquifers']

        if self.preloaded is None:
            model = self._load_model(idx, config)
        else:
            if isinstance(self.preloaded[idx], FieldSample):
                return self.preloaded[idx]
            model = self._get_preloaded(idx)

        sample = {}
        getter_kwargs = dict(
            sequence_subset=sequence_subset,
            fill_invalid_neighbours=INVALID_VALUE_FILLER,
            neighbouring_radius=1
        )
        for comp, attrs in self.sample_attrs.items():
            if comp == 'MASKS':
                # setdefault keeps CONTROL_T if CONTROL was processed before MASKS.
                masks = sample.setdefault(comp, {})
                # FIXME every registered mask getter is called; most of them return None when their
                # mask is not requested, and None entries are dropped below.
                for attr, mask_getter in self.masks_getter_map.items():
                    if not self.unravel_model and attr.upper() in ('CF_MASK', 'PERF_MASK'):
                        continue
                    # Masks computed so far are passed on: _get_neighbours needs ACTNUM.
                    masks[attr] = mask_getter(model, MASKS=masks, **getter_kwargs)
            elif comp == 'GRID':
                sample[comp] = {attr: self.grid_getter_map[attr](model, **getter_kwargs) for attr in attrs}
            elif comp == 'TABLES':
                sample[comp] = self._get_tables(model, attrs)
            elif comp == 'CONTROL':
                res = self.attrs_getter_map[comp](model, attrs, **getter_kwargs)
                sample[comp] = res['control']
                # Control dates are stored among the masks; create MASKS if it is not sampled (yet).
                sample.setdefault('MASKS', {})['CONTROL_T'] = res['t']
            else:
                sample[comp] = self.attrs_getter_map[comp](model, attrs, **getter_kwargs)

        if 'MASKS' in sample:
            sample['MASKS'] = {key: value for key, value in sample['MASKS'].items() if value is not None}
        sample = FieldSample(field=model, dataset=self, **sample)
        if not self.unravel_model:
            sample.as_ravel(inplace=True)
        return sample

    def _get_preloaded(self, idx):
        """Get a field from preloaded.

        With `allow_change_preloaded` the field is converted in place to match the dataset settings
        (ravel/unravel, wells preprocessing when CONTROL is sampled); otherwise its state is only asserted.
        """
        model = self.preloaded[idx]
        if self.allow_change_preloaded:
            if model.state.spatial != self.unravel_model:
                if self.unravel_model:
                    model.to_spatial()
                else:
                    model.ravel()
            if 'CONTROL' in self.sample_attrs:
                if not model.wells.state.all_tracks_complete:
                    model.wells.drop_incomplete()
                if not model.wells.state.has_blocks:
                    model.wells.get_wellblocks(model.grid)
                if not model.wells.state.full_perforation:
                    model.wells.apply_perforations()
                if not model.wells.state.all_tracks_inside:
                    model.wells.drop_outside()
                if model.meta['MODEL_TYPE'] == 'ECL':
                    model.wells.compute_events(grid=model.grid)
        else:
            assert model.state.spatial == self.unravel_model
            if 'CONTROL' in self.sample_attrs:
                assert model.wells.state.all_tracks_complete
                assert model.wells.state.has_blocks
                assert model.wells.state.full_perforation
                assert model.wells.state.all_tracks_inside
        return model

    def _load_model(self, idx, config=None, force_wells_calculations=False):
        """Loads field by index.

        Parameters
        ----------
        idx: int
        config: dict, optional
            Config used while loading the model. Defaults to `self.config`. It is never modified.
        force_wells_calculations: bool
            If True, wells preprocessing is run even if the wells state reports it as already done.

        Returns
        -------
        model: Field
        """
        _, fmt = os.path.splitext(self.files[idx])
        fmt = fmt.strip('.').lower()
        config = self.config if config is None else config
        if fmt == 'hdf5':
            if config[STATES_KEYWORD].get('subset') is None:
                config = None
            else:
                # Build new dicts instead of assigning in place: `config` is either self.config
                # or a shallow copy of it, so in-place edits would clobber the dataset's config.
                config = {comp: dict(comp_config, attrs=None) for comp, comp_config in config.items()}
        model = Field(path=os.path.join(self.root_dir, self.files[idx]), config=config,
                      encoding='auto:10000', loglevel='ERROR')
        model.load(raise_errors=False)
        if self.unravel_model:
            model.to_spatial()
        if 'CONTROL' in self.sample_attrs:
            if not safe_check(model.wells, 'all_tracks_complete', True) or force_wells_calculations:
                model.wells.drop_incomplete()
            if not safe_check(model.wells, 'has_blocks', True) or force_wells_calculations:
                model.wells.get_wellblocks(grid=model.grid)
            if not safe_check(model.wells, 'full_perforation', True) or force_wells_calculations:
                model.wells.apply_perforations()
            if not safe_check(model.wells, 'all_tracks_inside', True) or force_wells_calculations:
                model.wells.drop_outside()
            if model.meta['MODEL_TYPE'] == 'ECL':
                model.wells.compute_events(grid=model.grid)
        if not self.unravel_model:
            model.ravel()
        return model

    def _load_sample(self, idx):
        """Load a previously dumped sample by index."""
        sample = FieldSample(os.path.join(self.root_dir, self.files[idx]))
        sample.load()
        return sample

    def _get_actnum(self, model, **kwargs):
        """Get ACTNUM of the model as a boolean array (all ones if the grid has no actnum)."""
        _ = kwargs
        if hasattr(model.grid, 'actnum'):
            # builtin bool: the np.bool alias was removed in NumPy 1.24
            return getattr(model.grid, 'actnum').astype(bool)
        actnum = np.ones(model.grid.dimens, dtype=bool)
        return actnum if self.unravel_model else actnum.ravel(order='F')

    def _get_well_mask(self, model, **kwargs):
        """Get well mask of the model (True where `model.well_mask` is non-empty)."""
        _ = kwargs
        if hasnested(self.sample_attrs, 'MASKS', 'WELL_MASK') or 'CONTROL' in self.sample_attrs:
            return model.well_mask != ''
        return None

    def _get_named_well_mask(self, model, **kwargs):
        """Get per-well boolean masks keyed by well name."""
        _ = kwargs
        if hasnested(self.sample_attrs, 'MASKS', 'NAMED_WELL_MASK') or 'CONTROL' in self.sample_attrs:
            well_mask = model.well_mask
            # NOTE(review): assumes wells expose `name` matching the values stored in `model.well_mask`.
            return {well.name: well_mask == well.name for well in model.wells}
        return None

    def _get_neighbours(self, model, fill_invalid_neighbours=INVALID_VALUE_FILLER, neighbouring_radius=-1,
                        **kwargs):
        """Get connectivity matrix of cells presented in the model.

        Each row starts with the cell's own index followed by its neighbours; indices refer to the
        vector of active cells. ACTNUM is taken from `kwargs['MASKS']` (the masks computed so far by
        `_get_sample`) when present, otherwise from the model.
        """
        if 'MASKS' not in self.sample_attrs or 'NEIGHBOURS' not in self.sample_attrs['MASKS']:
            return None
        neighbours = model.grid.get_neighbors_matrix(
            connectivity=neighbouring_radius,
            fill_value=fill_invalid_neighbours,
            ravel_index=True
        )
        actnum = (kwargs.get('MASKS') or {}).get('ACTNUM')
        if actnum is None:
            actnum = self._get_actnum(model)
        # Indices are with respect to the full vectors: with active and non-active cells
        # We want indices with respect to the vector of active cells
        full_to_active_ind = np.asarray(actnum).astype(int)
        full_to_active_ind[full_to_active_ind == 1] = np.arange(full_to_active_ind.sum())
        # Trailing -1 so that an index of -1 is mapped to -1 again.
        full_to_active_ind = np.concatenate([full_to_active_ind, [-1]])
        neighbours = full_to_active_ind[neighbours.ravel()].reshape(neighbours.shape)
        # Neighbours should include the cell itself
        itself_ind = np.arange(neighbours.shape[0])[:, np.newaxis]
        return np.concatenate([itself_ind, neighbours], axis=1)

    def _get_invalid_neighbours_mask(self, model, fill_invalid_neighbours=INVALID_VALUE_FILLER,
                                     neighbouring_radius=-1, **kwargs):
        """Get mask of invalid neighbours (non-active or out of geometric bounds)."""
        _ = kwargs
        if hasnested(self.sample_attrs, 'GRID', 'DISTANCES'):
            neighbours = model.grid.get_neighbors_matrix(
                connectivity=neighbouring_radius,
                fill_value=fill_invalid_neighbours,
                ravel_index=True
            )
            # Compare with the filler actually used above, not with the default constant.
            return neighbours == fill_invalid_neighbours
        return None

    @staticmethod
    def _get_time(model, sequence_subset=None, **kwargs):
        """Get time in days associated with states timesteps relative to model start date."""
        _ = kwargs
        dates = model.result_dates
        sec_in_day = 86400
        t = (dates - model.start).total_seconds().values / sec_in_day
        return t if sequence_subset is None else t[sequence_subset]

    @staticmethod
    def _subset_date_range(model, sequence_subset):
        """Get (first, last) result dates of the subset, or None for full sequences / no results."""
        if sequence_subset is None:
            return None
        res_dates = model.result_dates
        if not res_dates.size:
            return None
        res_dates = res_dates[sequence_subset]
        return res_dates[0], res_dates[-1]

    def _get_connection_factors(self, model, sequence_subset=None, **kwargs):
        """Get spatial connection factors for the dates covered by the subset."""
        # FIXME calls the field's method twice
        _ = kwargs
        if hasnested(self.sample_attrs, 'MASKS', 'CF_MASK'):
            date_range = self._subset_date_range(model, sequence_subset)
            return model.get_spatial_connection_factors_and_perforation_ratio(date_range=date_range)[0]
        return None

    def _get_perforation_mask(self, model, sequence_subset=None, **kwargs):
        """Get spatial perforation ratio for the dates covered by the subset."""
        # FIXME calls the field's method twice
        _ = kwargs
        if hasnested(self.sample_attrs, 'MASKS', 'PERF_MASK'):
            date_range = self._subset_date_range(model, sequence_subset)
            return model.get_spatial_connection_factors_and_perforation_ratio(date_range=date_range)[1]
        return None

    @staticmethod
    def to_dates(model, t):
        """Restore actual dates from time deltas (in days from the model start)."""
        dates = model.start + np.array([pd.Timedelta(i, unit='day') for i in t])
        return pd.to_datetime(dates)

    @staticmethod
    def _get_distances(model, fill_invalid_neighbours=INVALID_VALUE_FILLER, neighbouring_radius=-1, **kwargs):
        """Get matrix of distances for neighbouring cells."""
        _ = kwargs
        return model.grid.calculate_neighbours_distances(
            connectivity=neighbouring_radius,
            fill_value=fill_invalid_neighbours
        )

    @staticmethod
    def _get_xyz(model, **kwargs):
        """Get cell coordinates of the model grid."""
        _ = kwargs
        return model.grid.xyz

    def _get_states(self, model, attrs, sequence_subset=None, **kwargs):
        """Get stacked states sequence (attrs stacked along axis 1)."""
        _ = kwargs
        # Loaded models are already restricted to the subset through the config;
        # preloaded ones are sliced here.
        if (self.preloaded is not None) and (sequence_subset is not None):
            return np.stack([getattr(model.states, attr)[sequence_subset] for attr in attrs], axis=1)
        return np.stack([getattr(model.states, attr) for attr in attrs], axis=1)

    @staticmethod
    def _get_rock(model, attrs, **kwargs):
        """Get stacked rock attributes (attrs stacked along axis 0)."""
        _ = kwargs
        return np.stack([getattr(model.rock, attr) for attr in attrs], axis=0)

    @staticmethod
    def _get_tables(model, attrs):
        """Get sample table data"""
        return {
            attr: (getattr(model.tables, attr).to_numpy() if attr in TABLES_WITHOUT_INDEX
                   else getattr(model.tables, attr).to_numpy(include_index=True))
            for attr in attrs
        }

    @staticmethod
    def _get_control(model, attrs, sequence_subset=None, **kwargs):
        """Get control in a spatial form (defined for all cells, meaningful values in perforated cells,
        other cells are filled with zeros) with corresponding dates.

        PROD_PERF_MASK / INJE_PERF_MASK are computed with `get_spatial_perf` and inserted at their
        positions in `attrs`; all other attrs come from `model.get_spatial_well_control`.
        """
        _ = kwargs
        date_range = FieldDataset._subset_date_range(model, sequence_subset)
        perf_modes = {'PROD_PERF_MASK': 'PROD', 'INJE_PERF_MASK': 'INJE'}
        filtered_attrs = [attr for attr in attrs if attr not in perf_modes]
        output = model.get_spatial_well_control(filtered_attrs, date_range=date_range,
                                                fill_shut=0., fill_outside=0.)
        if any(attr in perf_modes for attr in attrs):
            control = []
            i = 0  # column index into the control computed for filtered_attrs
            for attr in attrs:
                if attr in perf_modes:
                    control.append(get_spatial_perf(model, sequence_subset, mode=perf_modes[attr]))
                else:
                    control.append(output['control'][:, i][:, None])
                    i += 1
            output['control'] = np.concatenate(control, axis=1)
        return output

    def set_transform(self, transform):
        """Set transforms to be applied to each sample

        Parameters
        ----------
        transform: class
            Class of transform to apply
            list of Classes can be used to compose several transforms

        Returns
        -------
        out: FieldDataset

        Raises
        ------
        RuntimeError
            If a Normalize transform is requested before statistics are calculated.
            The previously set transform is kept in that case.
        """
        if not isinstance(transform, (list, tuple)):
            transform = [transform]
        transforms = []
        for t in transform:
            if inspect.isclass(t) and issubclass(t, Transform):
                if issubclass(t, Normalize):
                    if self.std is None or self.mean is None:
                        raise RuntimeError("Dataset's statistics are not calculated!")
                    stats = self.filtered_statistics
                    transforms.append(t(
                        mean=stats['MEAN'],
                        std=stats['STD'],
                        unravel_model=self.unravel_model
                    ))
                else:
                    transforms.append(t())
            else:
                transforms.append(t)
        self.transform = Compose(transforms)
        return self

    def dump_samples(self, path, n_epoch=1, prefix=None, state=True, **kwargs):
        """Dump samples from the dataset.

        Parameters
        ----------
        path: str
            Path to the directory for dump (created, with parents, if missing).
        n_epoch: int
            Number of times to pass through the dataset.
        prefix: str, None
            Prefix for dumped samples.
        state: bool
            If True, dump the state of the samples
        kwargs: dict
            Additional named arguments for sample.dump

        Returns
        -------
        out: FieldDataset
        """
        os.makedirs(path, exist_ok=True)
        prefix = prefix + '_' if prefix is not None else ''
        i = 0
        for _ in range(n_epoch):
            for idx in range(len(self)):
                sample = self[idx]
                sample.dump(os.path.join(path, prefix + str(i) + '.hdf5'), state=state, **kwargs)
                i += 1
        return self

    def convert_to_other_fmt(self, new_root_dir, new_fmt='hdf5', results_to_events=True, **kwargs):
        """Convert dataset to a new format.

        Parameters
        ----------
        new_root_dir: str
            Directory to save converted dataset
        new_fmt: str
            Extension to use
        results_to_events: bool
            If True, call `model.wells.results_to_events` before dumping each model.
        kwargs: dict
            Any additional named arguments passed to Field.dump

        Returns
        -------
        FieldDataset
            The dataset itself, re-initialised on the converted files.
        """
        os.makedirs(new_root_dir, exist_ok=True)
        for i, path in enumerate(self.files):
            path, _ = os.path.splitext(path)
            subdir = os.path.dirname(path)
            if subdir:
                # Several files may share a subdirectory.
                os.makedirs(os.path.join(new_root_dir, subdir), exist_ok=True)
            path = '.'.join([path, new_fmt])
            model = self._load_model(i, force_wells_calculations=True)
            if results_to_events:
                model.wells.results_to_events(grid=model.grid)
            config = None if new_fmt == 'hdf5' else self.config
            model.dump(path=os.path.join(new_root_dir, path), config=config, **kwargs)
        self.__init__(
            src=new_root_dir,
            sample_attrs=self.sample_attrs,
            fmt=(new_fmt, ),
            subset_generator=self.subset_generator,
            unravel_model=self.unravel_model
        )
        return self

    @property
    def filtered_statistics(self):
        """Filters out non-normalized attrs and attrs, which are not presented in `sample_attrs`,
        from statistics."""
        filtered_stats = dict()
        for key, value in zip(('MEAN', 'STD', 'MIN', 'MAX'), (self.mean, self.std, self.min, self.max)):
            if value is None:
                raise RuntimeError("Dataset's statistics are not calculated!")
            filtered_stat = {
                comp: {} for comp in self.sample_attrs
                if comp not in NON_NORMALIZED_ATTRS and len(self.sample_attrs[comp]) > 0
            }
            for comp in filtered_stat:
                if comp not in value:
                    raise ValueError('Component "%s" is not presented in calculated statistics.' % comp)
                for attr in self.sample_attrs[comp]:
                    if attr not in value[comp]:
                        raise ValueError('Attribute "%s" of component "%s" is not presented in calculated statistics.'
                                         % (attr, comp))
                    filtered_stat[comp][attr] = value[comp][attr]
                if comp not in self._attrs_sampled_as_dict:
                    filtered_stat[comp] = np.stack(
                        [filtered_stat[comp][attr] for attr in self.sample_attrs[comp]]
                    )
            filtered_stats[key] = filtered_stat
        return filtered_stats

    def calculate_statistics(self):  # pylint: disable=too-many-branches
        """Calculate mean, std, min and max values for the attributes of the dataset.

        Statistics are computed on full sequences of raveled models; the original sampling
        settings are restored afterwards, even if an error occurs.
        """
        # Change sampling behavior for statistics' calculation.
        subset_generator, self.subset_generator = self.subset_generator, None
        unravel_model, self.unravel_model = self.unravel_model, False
        try:
            def empty_stats():
                return {comp: {attr: [] for attr in self.sample_attrs[comp]}
                        for comp in self.sample_attrs if comp not in NON_NORMALIZED_ATTRS}

            mean, mean_of_squares, std, minim, maxim = (empty_stats() for _ in range(5))
            for i in range(len(self)):
                m, m_sq, mn, mx = self._get_model_statistics(i)
                for comp in m:  # pylint: disable=consider-using-dict-items
                    for attr in m[comp]:
                        mean[comp][attr].append(m[comp][attr])
                        mean_of_squares[comp][attr].append(m_sq[comp][attr])
                        minim[comp][attr].append(mn[comp][attr])
                        maxim[comp][attr].append(mx[comp][attr])

            for comp in mean:  # pylint: disable=consider-using-dict-items
                for attr in mean[comp]:
                    mean[comp][attr] = np.mean(mean[comp][attr], axis=0)
                    mean_of_squares[comp][attr] = np.mean(mean_of_squares[comp][attr], axis=0)
                    # abs guards against tiny negative values from floating-point error
                    std[comp][attr] = np.sqrt(np.abs(mean_of_squares[comp][attr] - mean[comp][attr]**2))
                    minim[comp][attr] = np.min(minim[comp][attr], axis=0)
                    maxim[comp][attr] = np.max(maxim[comp][attr], axis=0)
        finally:
            # Recover old sampling behavior
            self.subset_generator = subset_generator
            self.unravel_model = unravel_model

        self.mean, self.std, self.min, self.max = mean, std, minim, maxim
        return self

    def _get_model_statistics(self, idx):
        """Get mean, mean of squares, min and max for the attributes of the model."""
        sample = self._get_sample(idx)
        mean, mean_of_squares, minim, maxim = {}, {}, {}, {}
        for comp in sample.keys():
            if comp.upper() in NON_NORMALIZED_ATTRS:
                continue
            # Control is restricted to well cells (other cells are zero-filled by _get_control).
            mask = sample.masks.well_mask if comp.upper() == 'CONTROL' else None
            mean[comp] = dict()
            mean_of_squares[comp] = dict()
            minim[comp] = dict()
            maxim[comp] = dict()
            if comp in self._attrs_sampled_as_dict:
                ax = 0
                for attr, arr in sample[comp].items():
                    if attr.upper() == 'DISTANCES':
                        # np.float64: the np.float alias was removed in NumPy 1.24
                        arr = arr.copy().astype(np.float64)
                        arr[sample.masks.invalid_neighbours_mask] = np.nan
                    mean[comp][attr] = np.nanmean(arr, axis=ax)
                    mean_of_squares[comp][attr] = np.nanmean(np.power(arr, 2), axis=ax)
                    minim[comp][attr] = np.nanmin(arr, axis=ax)
                    maxim[comp][attr] = np.nanmax(arr, axis=ax)
            else:
                ax = 1 if comp not in SEQUENTIAL_ATTRS else (0, 2)
                arr = sample[comp][..., mask] if mask is not None else sample[comp]
                comp_mean = overflow_safe_mean(arr, axis=ax)
                comp_mean_of_squares = overflow_safe_mean(np.power(arr, 2), axis=ax)
                comp_min = np.min(arr, axis=ax)
                comp_max = np.max(arr, axis=ax)
                for i, attr in enumerate(self.sample_attrs[comp]):
                    mean[comp][attr] = comp_mean[i]
                    mean_of_squares[comp][attr] = comp_mean_of_squares[i]
                    minim[comp][attr] = comp_min[i]
                    maxim[comp][attr] = comp_max[i]
        return mean, mean_of_squares, minim, maxim

    def dump_statistics(self, path):
        """Dump mean, std, min and max values of the dataset into a pickle file."""
        if self.std is None or self.mean is None or self.min is None or self.max is None:
            raise RuntimeError("Dataset's statistics are not calculated!")
        with open(path, 'wb') as f:
            pickle.dump([self.mean, self.std, self.min, self.max], f)

    def load_statistics(self, path):
        """Load mean, std, min and max values of the dataset from a pickle file.

        Component and attribute names are upper-cased. CONTROL keys listed in CONTROL_TO_RESULTS_KW
        are renamed when the mapped name is sampled.
        """
        # NOTE: pickle.load can execute arbitrary code; only load statistics files from trusted sources.
        with open(path, 'rb') as f:
            self.mean, self.std, self.min, self.max = pickle.load(f)
        for kind in ('mean', 'std', 'min', 'max'):
            stats = getattr(self, kind)
            upper_stats = {}
            for comp, value in stats.items():
                if isinstance(value, dict):
                    upper_stats[comp.upper()] = {attr.upper(): arr for attr, arr in value.items()}
                else:
                    upper_stats[comp.upper()] = value
            setattr(self, kind, upper_stats)
        if 'CONTROL' in self.sample_attrs and 'CONTROL' in self.mean:
            control_attrs = self.sample_attrs['CONTROL']
            for kind in ('mean', 'std', 'min', 'max'):
                control_stats = getattr(self, kind)['CONTROL']
                # Iterate over a snapshot: keys are renamed (popped and inserted) inside the loop.
                for k in list(control_stats):
                    if k in CONTROL_TO_RESULTS_KW and CONTROL_TO_RESULTS_KW[k] in control_attrs:
                        control_stats[CONTROL_TO_RESULTS_KW[k]] = control_stats.pop(k)

    @property
    def sample_attrs(self):
        """Attributes represented in the samples."""
        return self._sample_attrs

    @sample_attrs.setter
    def sample_attrs(self, x):
        self._sample_attrs = {
            comp.upper(): [attr.upper() for attr in x[comp]] for comp in x
        }
class FieldSample(BaseComponent):
    """Class representing the samples from the dataset.

    Parameters
    ----------
    path: str, optional
        Path to the file. Only HDF5 files are supported at the moment.
    field: Field, optional
    dataset: FieldDataset, optional
    state: dict, optional
    sample: dict-like, optional
    """

    class _decorators:
        """Decorators for the FieldSample."""

        @classmethod
        def without_batch_dimension(cls, method):
            """Decorates sample methods to be applied without the batch dimension.

            If the sample's state has a truthy `batch_dimension`, it is removed before calling
            `method` and added back to the result afterwards.
            """
            import functools  # pylint: disable=import-outside-toplevel

            @functools.wraps(method)
            def decorated(instance, inplace=False, **kwargs):
                batch_dimension = (instance.state.batch_dimension
                                   if hasattr(instance.state, 'batch_dimension') else False)
                if batch_dimension:
                    instance = instance.transformed(RemoveBatchDimension, inplace=inplace)
                    # The instance is now a private (or already in-place) object.
                    inplace = True
                instance = method(instance, inplace=inplace, **kwargs)
                if batch_dimension:
                    instance = instance.transformed(AddBatchDimension, inplace=inplace)
                return instance
            return decorated

    def __init__(self, path=None, field=None, dataset=None, state=None, **sample):
        super().__init__(**sample)
        self._path = path
        self._field = field
        self.sample_attrs = dataset.sample_attrs if dataset is not None else None
        self.dataset = dataset
        if state is not None:
            self.init_state(**state)

    def _nested_dicts_to_base_components(self, class_name, d):
        """Recursively convert nested dicts into BaseComponents; other values are returned unchanged."""
        if isinstance(d, dict):
            d = BaseComponent(class_name=class_name, **d)
            for key, value in d.items():
                value = self._nested_dicts_to_base_components(key, value)
                setattr(d, key, value)
        return d

    def __setattr__(self, key, value):
        # Public attributes holding dicts are stored as BaseComponents.
        if key[0] != '_':
            value = self._nested_dicts_to_base_components(key.upper(), value)
        super().__setattr__(key, value)

    def empty_like(self):
        """Get an empty sample with the same state and the structure of embedded BaseComponents (if any)."""
        empty = super().empty_like()
        empty = FieldSample(field=self.field, dataset=self.dataset, state=empty.state.as_dict(), **dict(empty))
        empty.sample_attrs = self.sample_attrs
        return empty

    def copy(self):
        """Get a copy of the sample."""
        copy = super().copy()
        copy.dataset = self.dataset
        copy.field = self.field
        copy.sample_attrs = self.sample_attrs
        return copy

    def dump(self, path, **kwargs):
        """Dump the sample into a file.

        Note that BaseComponent-valued states are moved into the sample's attributes (their state is
        replaced with the 'base_component' marker), i.e. the sample itself is modified.

        Parameters
        ----------
        path: str
            Path to the file.
        kwargs: dict
            Additional named arguments passed to BaseComponent's dump method.

        Raises
        ------
        NotImplementedError
            If the file extension is not HDF5.
        """
        fname = os.path.basename(path)
        fmt = os.path.splitext(fname)[1].strip('.')

        if fmt.upper() == 'HDF5':
            if hasattr(self.state, 'tensor') and self.state.tensor:
                out = self.transformed(ToNumpy)
                return out.dump(path, **kwargs)
            for state, value in self.state.as_dict().items():
                if issubclass(value.__class__, BaseComponent):
                    if state == 'sample_attributes':
                        for k, v in value.items():
                            # dtype='S' sizes the byte strings to the longest name; a fixed 'S16'
                            # silently truncated names such as 'INVALID_NEIGHBOURS_MASK'.
                            value[k] = np.array(v, dtype='S')
                    setattr(self, state, value)
                    self.set_state(**{state: 'base_component'})
            return self._dump_hdf5(path, **kwargs)
        raise NotImplementedError('File format {} not supported.'.format(fmt))

    def load(self, **kwargs):
        """Load sample from a file.

        Parameters
        ----------
        kwargs: dict
            Additional named arguments passed to the load method.

        Returns
        -------
        sample: FieldSample
            Sample with loaded data.

        Raises
        ------
        RuntimeError
            If the sample has no path.
        NotImplementedError
            If the file extension is not HDF5.
        """
        if self._path is None:
            raise RuntimeError('You should specify a path before loading!')
        fname = os.path.basename(self._path)
        fmt = os.path.splitext(fname)[1].strip('.')
        if fmt.upper() == 'HDF5':
            self._load_hdf5(self._path, **kwargs)
        else:
            raise NotImplementedError('File format {} not supported.'.format(fmt))
        # Move states stored as attributes by `dump` back into the state.
        for state, value in self.state.as_dict().items():
            # isinstance guard: comparing an array state with a string is ambiguous in a boolean context.
            if isinstance(value, str) and value == 'base_component':
                value = getattr(self, state)
                if state == 'sample_attributes':
                    for k, v in value.items():
                        value[k] = list(v.astype('U'))
                self.set_state(**{state: value})
                delattr(self, state)
        return self

    @property
    def field(self):
        """Link to the parent field."""
        return self._field

    @field.setter
    def field(self, x):
        if x is not None and not isinstance(x, Field):
            raise ValueError('Can assign only instances of the class %s!' % str(Field))
        self._field = x

    @property
    def dataset(self):
        """Link to the parent dataset."""
        return self._dataset

    @dataset.setter
    def dataset(self, x):
        if x is not None and not isinstance(x, FieldDataset):
            raise ValueError('Can assign only instances of the class %s!\nGiven %s' % (str(FieldDataset), type(x)))
        self._dataset = x
        if x is not None:
            self.init_state(
                spatial=x.unravel_model,
                cropped_at_mask=None if x.unravel_model else 'ACTNUM'
            )
            try:
                self.init_state(dataset_statistics=self._nested_dicts_to_base_components(
                    'DATASET_STATISTICS', x.filtered_statistics
                ))
            except RuntimeError:
                # Statistics are not calculated for the dataset yet.
                pass

    def transformed(self, transforms, inplace=False):
        """Apply a set of transforms to the sample.

        Parameters
        ----------
        transforms: list, tuple, Compose, Transform
            Transform to apply
        inplace: bool

        Returns
        -------
        sample: FieldSample
            Transformed sample.
        """
        transforms = self._initialize_transform(transforms)
        return transforms(self, inplace=inplace)

    def at_wells(self, inplace=False):
        """Crop all the spatial arrays to the perforated cells. Ravel if needed.

        Parameters
        ----------
        inplace: bool

        Returns
        -------
        sample: FieldSample
            Cropped sample.
        """
        return self.as_ravel(inplace=inplace, crop_at_mask='WELL_MASK')

    @_decorators.without_batch_dimension
    def as_spatial(self, inplace=False):
        """Transform the sample's arrays to the spatial form.

        Parameters
        ----------
        inplace: bool

        Returns
        -------
        sample: FieldSample
        """
        raise NotImplementedError()

    @_decorators.without_batch_dimension
    def as_ravel(self, inplace=False, crop_at_mask='ACTNUM'):
        """Ravel the sample's arrays and crop them at a mask.

        Parameters
        ----------
        inplace: bool
        crop_at_mask: str
            Name of the mask from sample['MASKS'] to crop the raveled arrays at.

        Returns
        -------
        sample: FieldSample
        """
        if inplace:
            out = self
        elif self.state.spatial:
            # Every component is reassigned below, so an empty skeleton is enough.
            out = self.empty_like()
        else:
            # Nothing is reshaped for an already raveled sample: start from a full copy, otherwise
            # the cropping below would operate on an empty sample.
            out = self.copy()
        if self.state.spatial:
            for comp in self.keys():  # pylint: disable=consider-using-dict-items
                if comp.upper() == 'TABLES':
                    out[comp] = self[comp]
                    continue
                if comp.upper() in ('MASKS', 'GRID'):
                    for attr in self[comp].keys():
                        if attr.upper() in ('TIME', 'CONTROL_T'):
                            out[comp][attr] = self[comp][attr]
                            continue
                        if attr.upper() == 'NAMED_WELL_MASK':
                            for well in self[comp][attr].keys():
                                new_shape = (-1,) + tuple(self[comp][attr][well].shape[3:])
                                out[comp][attr][well] = \
                                    self[comp][attr].reshape(attr=well, newshape=new_shape, order='F', inplace=False)
                            continue
                        if attr.upper() in ('CF_MASK', 'PERF_MASK'):
                            # Time-dependent masks: the spatial axes are the last three.
                            new_shape = tuple(self[comp][attr].shape[:-3]) + (-1,)
                        else:
                            new_shape = (-1,) + tuple(self[comp][attr].shape[3:])
                        out[comp][attr] = self[comp].reshape(attr=attr, newshape=new_shape, order='F', inplace=False)
                else:
                    new_shape = tuple(self[comp].shape[:-3]) + (-1, )
                    out[comp] = self.reshape(attr=comp, newshape=new_shape, order='F', inplace=False)
            out.set_state(spatial=False)
        if crop_at_mask != self.state.cropped_at_mask:
            if self.state.cropped_at_mask is not None:
                out = self._uncrop_from_mask(out, self.state.cropped_at_mask)
            out = self._crop_at_mask(out, crop_at_mask)
        return out

    @staticmethod
    def _crop_at_mask(obj, mask_name):
        """Crop a sample at a given binary mask.

        Parameters
        ----------
        obj: FieldSample
            Sample to be cropped.
        mask_name: str
            Name of the mask from the sample['MASKS'].

        Returns
        -------
        obj: FieldSample
            Cropped sample.
        """
        assert not obj.state.spatial
        mask = obj.masks[mask_name]
        if isinstance(mask, torch.Tensor):
            mask = mask.bool()
        else:
            mask = mask.astype(bool)
        for comp in obj.keys():
            if comp.upper() == 'TABLES':
                continue
            if comp.upper() in ('MASKS', 'GRID'):
                for attr in obj[comp].keys():
                    if attr.upper() in ('CF_MASK', 'PERF_MASK'):
                        obj[comp][attr] = obj[comp][attr][..., mask]
                    elif attr.upper() == 'NAMED_WELL_MASK':
                        for well in obj[comp][attr].keys():
                            obj[comp][attr][well] = obj[comp][attr][well][..., mask]
                    elif attr.upper() not in (mask_name.upper(), 'TIME', 'CONTROL_T') and obj[comp][attr] is not None:
                        obj[comp][attr] = obj[comp][attr][mask]
            else:
                obj[comp] = obj[comp][..., mask]
        obj.set_state(cropped_at_mask=mask_name.upper())
        return obj

    @staticmethod
    def _uncrop_from_mask(obj, mask_name):
        """Reverse operation to the crop_at_mask.

        Parameters
        ----------
        obj: FieldSample
        mask_name: str

        Returns
        -------
        obj: FieldSample
        """
        raise NotImplementedError()

    def _initialize_transform(self, transforms):
        """Initialize transforms before application (transform classes are instantiated)."""
        if not isinstance(transforms, (list, tuple, Compose)):
            transforms = [transforms]
        initialized_transforms = []
        for t in transforms:
            if inspect.isclass(t):
                if issubclass(t, Normalize):
                    initialized_transforms.append(t(
                        mean=self.state.dataset_statistics.mean,
                        std=self.state.dataset_statistics.std,
                        unravel_model=self.state.spatial
                    ))
                else:
                    initialized_transforms.append(t())
            else:
                initialized_transforms.append(t)
        return Compose(initialized_transforms)

    @property
    def sample_attrs(self):
        """Attributes represented in the sample."""
        return self.state.sample_attributes

    @sample_attrs.setter
    def sample_attrs(self, x):
        x = None if x is None else {comp.upper(): [attr.upper() for attr in x[comp]] for comp in x.keys()}
        x = self._nested_dicts_to_base_components('SAMPLE_ATTRIBUTES', x)
        if hasattr(self.state, 'SAMPLE_ATTRIBUTES'):
            self.set_state(sample_attributes=x)
        else:
            self.init_state(sample_attributes=x)

    @property
    def device(self):
        """Get the sample's device (if it is in Torch format)

        The device of one stored array is reported; arrays are assumed to share a device.

        Returns
        -------
        device: torch.device
        """
        ref = None
        for _, value in self.items():
            if isinstance(value, BaseComponent):
                for _, arr in value.items():
                    if isinstance(arr, BaseComponent):
                        for _, mask in arr.items():
                            ref = mask
                            break
                    else:
                        ref = arr
                        break
            else:
                ref = value
        if ref is None:
            raise RuntimeError('The sample is empty!')
        if isinstance(ref, torch.Tensor):
            return ref.device
        raise RuntimeError('The sample should be in the PyTorch format! Found: %s' % type(ref))

    def to(self, device, inplace=True):
        """Change the sample's device (if it is in Torch format).

        Parameters
        ----------
        device: str, torch.device
        inplace: bool

        Returns
        -------
        sample: FieldSample
            Sample at the new device
        """
        if self.device == device:
            return self if inplace else self.copy()
        out = self if inplace else self.empty_like()
        for comp, value in self.items():
            if isinstance(value, BaseComponent):
                for attr, arr in value.items():
                    if isinstance(arr, BaseComponent):
                        for well, mask in arr.items():
                            out[comp][attr][well] = mask.to(device)
                    else:
                        out[comp][attr] = arr.to(device)
            else:
                out[comp] = value.to(device)
        return out
class SequenceSubset:
    """Baseclass for generating subsets of sequences."""

    def __init__(self, size, low, high, **kwargs):
        """
        Parameters
        ----------
        size: int
            Length of the generated sequences
        low: int
            Minimal possible timestep (inclusive)
        high: int
            Maximal possible timestep (exclusive)
        kwargs: optional
            Ignored; accepted so that subclasses can share the constructor signature.
        """
        self.size = size
        self.low = low
        self.high = high
        _ = kwargs

    def __call__(self):
        """Generate subset of timesteps."""
        raise NotImplementedError('Abstract method is not implemented.')
class UniformSequenceSubset(SequenceSubset):
    """Draw `size` distinct timesteps uniformly from [low, high) and return them in ascending order."""

    def __call__(self):
        candidates = np.arange(self.low, self.high)
        picked = np.random.choice(candidates, size=self.size, replace=False)
        return np.sort(picked)
class RandomSubsequence(SequenceSubset):
    """Draw a contiguous run of `size` timesteps starting at a random offset inside [low, high)."""

    def __call__(self):
        # The last admissible start keeps the whole run below `high`.
        start_bound = self.high - self.size + 1
        first = np.random.randint(low=self.low, high=start_bound)
        return np.sort(np.arange(first, first + self.size))