Source code for pymt.framework.bmi_mapper

#! /usr/bin/env python
import warnings

import numpy as np

try:
    import ESMF as esmf
except ImportError:
    esmf = None


# Lookup tables translating the option strings accepted by run_regridding
# into ESMF enumeration values. Both stay empty when ESMPy could not be
# imported, so any lookup in them fails.
if esmf is None:
    REGRID_METHODS = {}
    UNMAPPED_ACTIONS = {}
else:
    REGRID_METHODS = {
        "bilinear": esmf.RegridMethod.BILINEAR,
        "nearest": esmf.RegridMethod.NEAREST_STOD,
        "conserve": esmf.RegridMethod.CONSERVE,
    }
    UNMAPPED_ACTIONS = {
        "pass": esmf.UnmappedAction.IGNORE,
        "raise": esmf.UnmappedAction.ERROR,
    }


def ravel_jaggedarray(array):
    """Flatten a padded 2D array, dropping negative (padding) entries.

    Parameters
    ----------
    array : array_like
        2D array in which each row holds the values of one item; negative
        entries are treated as padding and discarded.

    Returns
    -------
    tuple of ndarray
        ``(raveled, values_per_row)`` where *raveled* holds the non-negative
        entries in row-major order (same dtype as *array*) and
        *values_per_row* is the number of such entries in each row.

    Examples
    --------
    >>> raveled, counts = ravel_jaggedarray(np.array([[0, 1, 2], [3, 4, -1]]))
    >>> raveled.tolist(), counts.tolist()
    ([0, 1, 2, 3, 4], [3, 2])
    """
    array = np.asarray(array)
    is_value = array >= 0
    # Boolean-mask indexing of a 2D array returns the selected entries in
    # row-major order, i.e. exactly the concatenation of each row's values,
    # without a Python-level loop over the rows.
    return array[is_value], np.sum(is_value, axis=1)
def bmi_as_esmf_mesh(bmi_grid):
    """Convert a BMI grid description into an ESMF mesh.

    Parameters
    ----------
    bmi_grid : dataset-like
        Object exposing ``node_x`` and ``node_y`` (each with a ``.values``
        array) and supporting ``in`` membership tests. When it contains
        ``face_node_connectivity`` it must also provide ``face_node_offset``.

    Returns
    -------
    ESMF.Mesh
        The mesh built by :func:`as_esmf_mesh`; it has no connectivity when
        the grid has no ``face_node_connectivity``.
    """
    x_of_node = bmi_grid.node_x.values
    y_of_node = bmi_grid.node_y.values
    # One row per node; .copy() turns the transposed stack into a fresh
    # C-ordered array.
    xy_at_node = np.vstack((x_of_node, y_of_node)).T.copy()

    if "face_node_connectivity" not in bmi_grid:
        return as_esmf_mesh(xy_at_node, None, None)

    nodes_at_patch = bmi_grid.face_node_connectivity.values
    # Prepend a zero and difference consecutive offsets to get the node count
    # of each face (assumes face_node_offset holds cumulative end offsets).
    offsets = np.concatenate(([0], bmi_grid.face_node_offset.values))
    nodes_per_patch = np.diff(offsets)
    return as_esmf_mesh(xy_at_node, nodes_at_patch, nodes_per_patch)
def as_esmf_mesh(xy_of_node, nodes_at_patch=None, nodes_per_patch=None):
    """Create a 2D ESMF mesh from node coordinates and optional connectivity.

    Parameters
    ----------
    xy_of_node : ndarray
        Node coordinates, one entry per node (``len`` gives the node count).
    nodes_at_patch : ndarray, optional
        Face-to-node connectivity. Either a 2D array with one row per face,
        padded with negative values, or a flat 1D array holding the
        concatenated connectivity of all faces.
    nodes_per_patch : ndarray, optional
        Number of nodes of each face. Required when *nodes_at_patch* is 1D;
        ignored when it is 2D.

    Returns
    -------
    ESMF.Mesh
        The mesh. Faces are added only if every face has 3 or 4 nodes;
        otherwise a warning is issued and the mesh holds nodes only.

    Raises
    ------
    ValueError
        If *nodes_at_patch* is 1D and *nodes_per_patch* is not given.
    """
    # Resolve and validate connectivity before touching ESMF, so bad input
    # fails with a clear message instead of after a mesh has been created.
    face_conn = nodes_per_face = None
    if nodes_at_patch is not None:
        if nodes_at_patch.ndim == 2:
            # Padded 2D connectivity: negative entries mark missing nodes.
            face_conn, nodes_per_face = ravel_jaggedarray(nodes_at_patch)
        elif nodes_per_patch is None:
            raise ValueError(
                "nodes_per_patch is required when nodes_at_patch is one-dimensional"
            )
        else:
            face_conn, nodes_per_face = nodes_at_patch, nodes_per_patch
        # ESMF is handed 32-bit integer index arrays.
        face_conn = np.asarray(face_conn, dtype=np.int32)
        nodes_per_face = np.asarray(nodes_per_face, dtype=np.int32)

    mesh = esmf.Mesh(parametric_dim=2, spatial_dim=2)

    n_nodes = len(xy_of_node)
    # Node ids are 1-based; every node gets owner 0.
    node_ids = np.arange(1, n_nodes + 1, dtype=np.int32)
    node_owner = np.zeros(n_nodes, dtype=np.int32)
    mesh.add_nodes(n_nodes, node_ids, xy_of_node, node_owner)

    if face_conn is None:
        return mesh

    if np.all((nodes_per_face == 3) | (nodes_per_face == 4)):
        n_faces = len(nodes_per_face)
        face_ids = np.arange(1, n_faces + 1, dtype=np.int32)
        face_types = np.full(n_faces, -1, dtype=np.int32)
        face_types[nodes_per_face == 3] = esmf.MeshElemType.TRI
        face_types[nodes_per_face == 4] = esmf.MeshElemType.QUAD
        mesh.add_elements(n_faces, face_ids, face_types, face_conn)
    else:
        warnings.warn("mesh contains non-triangular or quadrilateral elements")
    return mesh
def as_esmf_field(mesh, field_name, data=None, at="node"):
    """Create an ESMF field on *mesh*, optionally filled with *data*.

    Parameters
    ----------
    mesh : ESMF.Mesh
        Mesh on which the field is defined.
    field_name : str
        Name given to the ESMF field.
    data : ndarray, optional
        Values copied into the field; they are reshaped to the field's own
        data shape first.
    at : {'node', 'cell'}, optional
        Location of the values: mesh nodes or mesh elements.

    Returns
    -------
    ESMF.Field
        The newly created field.

    Raises
    ------
    ValueError
        If *at* is neither ``'node'`` nor ``'cell'``.
    """
    if at not in ("node", "cell"):
        raise ValueError("'at' location not understood (must be 'cell' or 'node')")
    location = esmf.MeshLoc.NODE if at == "node" else esmf.MeshLoc.ELEMENT

    field = esmf.Field(mesh, field_name, meshloc=location)
    if data is None:
        return field
    np.copyto(field.data, data.reshape(field.data.shape))
    return field
def graph_as_esmf(graph, field_name, data=None, at="node"):
    """Build an ESMF field on the mesh described by *graph*.

    Parameters
    ----------
    graph : graph-like
        Object providing ``xy_of_node`` and ``nodes_at_patch``.
    field_name : str
        Name given to the ESMF field.
    data : ndarray, optional
        Initial field values.
    at : {'node', 'cell'}, optional
        Location of the values on the mesh.

    Returns
    -------
    ESMF.Field
        Field created by :func:`as_esmf_field` on the graph's mesh.
    """
    return as_esmf_field(
        as_esmf_mesh(graph.xy_of_node, graph.nodes_at_patch),
        field_name,
        data=data,
        at=at,
    )
def run_regridding(srcfield, dstfield, method="nearest", unmapped="pass"):
    """Regrid the data of *srcfield* onto *dstfield* with ESMF.

    Parameters
    ----------
    srcfield : ESMF.Field
        Field holding the values to regrid.
    dstfield : ESMF.Field
        Field that receives the regridded values.
    method : {'nearest', 'bilinear', 'conserve'}, optional
        Regridding method; a key of ``REGRID_METHODS``.
    unmapped : {'pass', 'raise'}, optional
        Handling of destination points that cannot be mapped: ignore them
        (``'pass'``) or have ESMF report an error (``'raise'``); a key of
        ``UNMAPPED_ACTIONS``.

    Returns
    -------
    ESMF.Field
        Whatever the ESMF regrid call returns (the destination field).

    Raises
    ------
    ValueError
        If *method* or *unmapped* is not recognized. This is always the case
        when ESMPy is not installed, since the lookup tables are then empty.
    """
    try:
        method = REGRID_METHODS[method]
    except KeyError:
        raise ValueError("regrid method not understood") from None
    try:
        unmapped = UNMAPPED_ACTIONS[unmapped]
    except KeyError:
        raise ValueError("unmapped action not understood") from None

    # The same mask value (-9999.0) is passed for source and destination.
    masked_values = np.array([-9999.0])
    regridder = esmf.Regrid(
        srcfield,
        dstfield,
        regrid_method=method,
        unmapped_action=unmapped,
        src_mask_values=masked_values,
        dst_mask_values=masked_values,
    )
    return regridder(srcfield, dstfield)
class GridMapperMixIn(object):
    """Mix-in that maps variable values between grids of BMI-like objects.

    The host class is expected to provide ``get_value``, a ``set_value``
    reachable through ``super()``, ``grid`` (indexable by grid id) and
    ``var`` (indexable by variable name, items exposing ``.grid``). ESMF
    meshes and fields are created lazily and cached per instance.
    """

    def _esmf_mesh_by_id(self, gid):
        """Return the cached ESMF mesh for grid *gid*, building it on first use."""
        try:
            meshes = self._esmf_mesh
        except AttributeError:
            meshes = self._esmf_mesh = {}
        if gid not in meshes:
            meshes[gid] = bmi_as_esmf_mesh(self.grid[gid])
        return meshes[gid]

    def _esmf_field_by_id(self, gid, name=None, at="node"):
        """Return the cached ESMF field *name* located *at* on grid *gid*.

        Fields are keyed by grid id, name (default ``"generic"``) and
        location, so callers requesting the same triple share one field.
        """
        name = name or "generic"
        try:
            fields = self._esmf_field
        except AttributeError:
            fields = self._esmf_field = {}
        key = "{gid}.{name}@{at}".format(gid=gid, name=name, at=at)
        if key not in fields:
            fields[key] = as_esmf_field(self._esmf_mesh_by_id(gid), name, at=at)
        return fields[key]

    def regrid(self, name, **kwds):
        """Regrid values from one grid to another.

        Parameters
        ----------
        name : str
            Name of the values to regrid.
        to : bmi_like, optional
            BMI object onto which to map values. If not provided, map values
            onto one of the object's own grids.
        to_name : str, optional
            Name of the value to map onto. If not provided, use *name*.

        Remaining keywords are passed to ``get_value``.

        Returns
        -------
        ndarray
            The regridded values, as a new array. If ESMPy is not available,
            the values from ``get_value`` are returned unmapped.
        """
        dst = kwds.pop("to", self)
        dst_name = kwds.pop("to_name", name)

        data = self.get_value(name, **kwds)
        if esmf is None:
            return data

        # Use a cache entry for the source that differs from the destination
        # key; otherwise regridding onto this object's own grid would read
        # from and write to the very same ESMF field.
        src_field = self._esmf_field_by_id(
            self.var[name].grid, name="source", at="node"
        )
        dst_field = dst._esmf_field_by_id(dst.var[dst_name].grid, at="node")
        np.copyto(src_field.data, data.reshape(src_field.data.shape))
        run_regridding(src_field, dst_field)
        # dst_field is cached and reused by later calls; return a copy so a
        # subsequent regrid does not overwrite the caller's array.
        return dst_field.data.copy()

    def map_to(self, name, **kwds):
        """Map values to another grid and store them there.

        Parameters
        ----------
        name : str
            Name of values to push.
        destination : bmi_like, optional
            Object receiving the values. If not provided, use *self*.
        at : str, optional
            Name of the destination variable. If not provided, use *name*.

        Remaining keywords are passed to :meth:`regrid`.
        """
        destination = kwds.pop("destination", self)
        at = kwds.pop("at", name)
        data = self.regrid(name, to=destination, to_name=at, **kwds)
        destination.set_value(at, data)

    def _mapped_value(self, name, **kwds):
        """Return values mapped onto variable *name*, honouring mapfrom/nomap.

        Shared by :meth:`set_value` and :meth:`map_value`; remaining keywords
        are passed to the source's :meth:`regrid`.
        """
        mapfrom = kwds.pop("mapfrom", self)
        nomap = kwds.pop("nomap", None)

        # mapfrom is either a (name, bmi) pair or a bmi-like object.
        try:
            value, source = mapfrom
        except TypeError:
            value, source = name, mapfrom

        if nomap is not None:
            orig = self.get_value(name)
        data = source.regrid(value, to=self, to_name=name, **kwds)
        if nomap is not None:
            # Work on a copy: without ESMF, regrid hands back the source's
            # get_value result, which the masked assignment must not modify.
            data = data.copy()
            data[nomap] = orig[nomap]
        return data

    def set_value(self, name, *args, **kwds):
        """Set values for a variable.

        set_value(name, value)
        set_value(name, mapfrom=self, nomap=None)

        Parameters
        ----------
        name : str
            Name of the destination values.
        mapfrom : tuple or bmi_like, optional
            Source of mapped values; see :meth:`map_value`.
        nomap : ndarray of bool, optional
            Destination values that keep their current value.
        """
        if len(args) == 1:
            return super(GridMapperMixIn, self).set_value(name, *args)
        data = self._mapped_value(name, **kwds)
        super(GridMapperMixIn, self).set_value(name, data)

    def map_value(self, name, **kwds):
        """Map values from another grid.

        Parameters
        ----------
        name : str
            Name of values to map to.
        mapfrom : tuple or bmi_like, optional
            BMI object from which values are mapped. This can also be a
            tuple of *(name, bmi)*, where *name* is the variable of the
            source grid and *bmi* is the bmi-like source. If not provided,
            use *self*.
        nomap : ndarray of bool, optional
            Values in the destination grid to not map.
        """
        data = self._mapped_value(name, **kwds)
        self.set_value(name, data)