Source code for pymt.mappers.esmp

import numpy as np

from ..grids.esmp import EsmpUnstructuredField
from .imapper import IGridMapper
from .mapper import IncompatibleGridError

# ESMF is an optional dependency: keep the module importable without it and
# record its absence as ``esmf = None``.
try:
    import ESMF as esmf
except ImportError:
    esmf = None

# Default regridding options used by ``EsmpMapper.initialize``; both are
# ``None`` when ESMF could not be imported.
REGRID_METHOD = None if esmf is None else esmf.RegridMethod.CONSERVE
UNMAPPED_ACTION = None if esmf is None else esmf.UnmappedAction.ERROR


class EsmpMapper(IGridMapper):
    """Base class for grid mappers that regrid values with ESMF.

    Subclasses set ``_name`` and implement :meth:`test` (grid compatibility
    check) and :meth:`init_fields` (creation of the ``"src"`` and ``"dst"``
    fields on the wrapped grids).
    """

    _name = None

    @property
    def name(self):
        """Name of the mapper, taken from the class-level ``_name``."""
        return self._name

    @staticmethod
    def test(dst_grid, src_grid):
        """Return whether the two grids can be mapped; subclasses override."""
        raise NotImplementedError("test")

    def init_fields(self):
        """Create the ``"src"`` and ``"dst"`` fields; subclasses override."""
        raise NotImplementedError("init_fields")

    def initialize(self, dest_grid, src_grid, **kwds):
        """Wrap both grids as ESMF fields and build the regridder.

        Parameters
        ----------
        dest_grid, src_grid
            Grids providing ``get_x``, ``get_y``, ``get_connectivity``,
            ``get_offset`` and a ``name`` attribute.
        method : optional keyword
            Passed to ``esmf.Regrid`` as ``regrid_method``
            (default ``REGRID_METHOD``).
        unmapped : optional keyword
            Passed to ``esmf.Regrid`` as ``unmapped_action``
            (default ``UNMAPPED_ACTION``).

        Raises
        ------
        IncompatibleGridError
            If ``self.test(dest_grid, src_grid)`` is false.
        """
        method = kwds.get("method", REGRID_METHOD)
        unmapped = kwds.get("unmapped", UNMAPPED_ACTION)

        # Dispatch through the instance so the concrete subclass's check is
        # used. Calling ``EsmpMapper.test`` directly always hit the abstract
        # base implementation and raised NotImplementedError.
        if not self.test(dest_grid, src_grid):
            raise IncompatibleGridError(dest_grid.name, src_grid.name)

        self._src = EsmpUnstructuredField(
            src_grid.get_x(),
            src_grid.get_y(),
            src_grid.get_connectivity(),
            src_grid.get_offset(),
        )
        self._dst = EsmpUnstructuredField(
            dest_grid.get_x(),
            dest_grid.get_y(),
            dest_grid.get_connectivity(),
            dest_grid.get_offset(),
        )

        # Fields must exist before the regridder is built from them.
        self.init_fields()

        self._regridder = esmf.Regrid(
            self.get_source_field(),
            self.get_dest_field(),
            regrid_method=method,
            unmapped_action=unmapped,
        )

    def run(self, src_values, **kwds):
        """Map *src_values* from the source grid onto the destination grid.

        Parameters
        ----------
        src_values
            Values copied into the source field before regridding.
        dest_values : optional keyword
            Initial values for the destination field. When omitted the
            destination field is zero-filled first.

        Returns
        -------
        ndarray
            The regridded destination values. If *dest_values* is an array
            with the same shape as the destination field it is updated in
            place and returned; otherwise a new array is returned.
        """
        dest_values = kwds.get("dest_values", None)

        src_data = self.get_source_data()
        src_data[:] = src_values

        dst_data = self.get_dest_data()
        if dest_values is not None:
            dst_data[:] = dest_values
        else:
            dst_data.fill(0.0)

        self._regridder(self.get_source_field(), self.get_dest_field())

        # Re-read the destination field after regridding rather than relying
        # on ``dst_data`` still aliasing it. Previously the caller's untouched
        # input (or None) was returned and the result was lost.
        result = self.get_dest_data()
        if isinstance(dest_values, np.ndarray) and dest_values.shape == result.shape:
            dest_values[...] = result
            return dest_values
        # Copy so the caller does not hold a reference to the field's buffer.
        return np.array(result)

    def finalize(self):
        """Release mapper resources (currently nothing to do)."""
        pass

    def get_source_field(self):
        """Return the ``"src"`` field of the wrapped source grid."""
        return self._src.as_esmp("src")

    def get_dest_field(self):
        """Return the ``"dst"`` field of the wrapped destination grid."""
        return self._dst.as_esmp("dst")

    def get_source_data(self):
        """Return the ``data`` attribute of the ``"src"`` field."""
        return self._src.as_esmp("src").data

    def get_dest_data(self):
        """Return the ``data`` attribute of the ``"dst"`` field."""
        return self._dst.as_esmp("dst").data
class EsmpCellToCell(EsmpMapper):
    """ESMF mapper whose fields are sized by cell count (zonal centering)."""

    _name = "CellToCell"

    def init_fields(self):
        """Add an uninitialized float64 field to each wrapped grid.

        The source grid gets a field named ``"src"`` and the destination
        grid one named ``"dst"``; each holds one value per cell.
        """
        for mesh, label in ((self._src, "src"), (self._dst, "dst")):
            values = np.empty(mesh.get_cell_count(), dtype=np.float64)
            mesh.add_field(label, values, centering="zonal")

    @staticmethod
    def test(dst_grid, src_grid):
        """Return True if both grids pass the offset-spacing check.

        A grid passes when every difference between consecutive entries of
        ``get_offset()`` is greater than 2 (presumably: each cell has at
        least three vertices -- confirm against the grid interface). The
        destination grid is checked first; the source grid is only queried
        if the destination passes.
        """
        for grid in (dst_grid, src_grid):
            if not all(np.diff(grid.get_offset()) > 2):
                return False
        return True
class EsmpPointToPoint(EsmpMapper):
    """ESMF mapper whose fields are sized by point count (point centering)."""

    _name = "PointToPoint"

    def init_fields(self):
        """Add an uninitialized float64 field to each wrapped grid.

        The source grid gets a field named ``"src"`` and the destination
        grid one named ``"dst"``; each holds one value per point.
        """
        for mesh, label in ((self._src, "src"), (self._dst, "dst")):
            values = np.empty(mesh.get_point_count(), dtype=np.float64)
            mesh.add_field(label, values, centering="point")

    @staticmethod
    def test(dst_grid, src_grid):
        """Return True when neither grid is ``None``."""
        return not (dst_grid is None or src_grid is None)