Source code for pymt.services.gridreader.gridreader

import warnings

import yaml

from ...printers.nc.read import field_fromfile
from .interpolate import create_interpolators
from .time_series_names import get_time_series_names


[docs]class TimeInterpolator: def __init__(self): self._shape = () self._spacing = () self._origin = () self._input_exchange_items = set() self._output_exchange_items = set() self._start_time = 0.0 self._end_time = 0.0 self._time = 0.0 self._interpolators = {}
[docs] def initialize(self, source): config = read_configuration(source) (fields, times) = field_fromfile(config["input_file"], fmt="NETCDF4") self._shape = fields[0].get_shape() self._spacing = fields[0].get_spacing() self._origin = fields[0].get_origin() self._input_exchange_items = set() self._output_exchange_items = get_time_series_names(fields.keys()) self._start_time = times[0] self._end_time = times[-1] self._time = times[0] self._interpolators = create_interpolators( times, fields, kind=config["interpolation"] )
[docs] def update_until(self, time): if time < self._start_time or time > self._endtime: raise ValueError( "time is outside of start/end time ({} not in [{}, {}])".format( time, self._start_time, self._end_time ) ) self._time = time
[docs] def finalize(self): pass
[docs] def get_start_time(self): return self._start_time
[docs] def get_end_time(self): return self._end_time
[docs] def get_current_time(self): return self._time
[docs] def get_input_var_names(self): return self._input_exchange_items
[docs] def get_output_var_names(self): return self._output_exchange_items
[docs] def get_grid_shape(self, grid_id): if grid_id in (None, ""): warnings.warn("grid identifier is ignored") return self._shape
[docs] def get_grid_spacing(self, grid_id): if grid_id in (None, ""): warnings.warn("grid identifier is ignored") return self._spacing
[docs] def get_grid_origin(self, grid_id): if grid_id in (None, ""): warnings.warn("grid identifier is ignored") return self._origin
[docs] def get_value(self, name): return self._interpolators[name](self._time)
# if name.endswith('_increment'): # name = name[:- len('_increment')] # return (self._interpolators[name](self._time) - # self._interpolators[name](self._last_time)) # else: # return self._interpolators[name](self._time)
[docs]def get_abspath_or_url(filename, prefix=""): import os from urlparse import urlparse, urlunparse parts = urlparse(filename) if parts.scheme in ["file", ""]: return os.path.join(prefix, parts.path) else: return urlunparse(parts)
[docs]def read_configuration(source): config = yaml.safe_load(source) input_file = config.get("input_file") input_dir = config.get("input_dir", "") kind = config.get("interpolation", "linear") return { "input_file": get_abspath_or_url(input_file, prefix=input_dir), "interpolation": kind, }
if __name__ == "__main__": import doctest doctest.testmod()