Source code for herbie.accessors

## Brian Blaylock
## April 23, 2021

"""Herbie's xarray accessors.

Other useful packages

- salem: mask data by a geographic region: https://github.com/fmaussion/salem
- xoak: xarray nearest neighbor https://github.com/xarray-contrib/xoak
"""

import functools
import pickle
import re
import warnings
from pathlib import Path
from typing import Literal, Optional, Union

import numpy as np
import pandas as pd
import xarray as xr
from pyproj import CRS

import herbie
from .pick_points import GridPointPicker

_level_units = dict(
    adiabaticCondensation="adiabatic condensation",
    atmosphere="atmosphere",
    atmosphereSingleLayer="atmosphere single layer",
    boundaryLayerCloudLayer="boundary layer cloud layer",
    cloudBase="cloud base",
    cloudCeiling="cloud ceiling",
    cloudTop="cloud top",
    depthBelowLand="m",
    equilibrium="equilibrium",
    heightAboveGround="m",
    heightAboveGroundLayer="m",
    highCloudLayer="high cloud layer",
    highestTroposphericFreezing="highest tropospheric freezing",
    isobaricInhPa="hPa",
    isobaricLayer="hPa",
    isothermZero="0 C",
    isothermal="K",
    level="m",
    lowCloudLayer="low cloud layer",
    meanSea="MSLP",
    middleCloudLayer="middle cloud layer",
    nominalTop="nominal top",
    pressureFromGroundLayer="Pa",
    sigma="sigma",
    sigmaLayer="sigmaLayer",
    surface="surface",
)


def add_proj_info(ds: xr.Dataset):
    """Add projection info to a Dataset."""
    raise NotImplementedError("This function `add_proj_info` is not yet implemented.")

    # TODO: remove pyproj dependency

    match = re.search(r'"source": "(.*?)"', ds.history)
    FILE = Path(match.group(1))

    # Get CF grid projection information with pygrib and pyproj because
    # this is something cfgrib doesn't do (https://github.com/ecmwf/cfgrib/issues/251)
    # NOTE: Assumes the projection is the same for all variables
    with pygrib.open(str(FILE)) as grb:
        msg = grb.message(1)
        cf_params = CRS(msg.projparams).to_cf()

    # Funny stuff with polar stereographic (https://github.com/pyproj4/pyproj/issues/856)
    # TODO: Is there a better way to handle this? What about south pole?
    if cf_params["grid_mapping_name"] == "polar_stereographic":
        cf_params["latitude_of_projection_origin"] = cf_params.get(
            "latitude_of_projection_origin", 90
        )

    # ----------------------
    # Attach CF grid mapping
    # ----------------------
    # http://cfconventions.org/Data/cf-conventions/cf-conventions-1.8/cf-conventions.html#appendix-grid-mappings
    ds["gribfile_projection"] = None
    ds["gribfile_projection"].attrs = cf_params
    ds["gribfile_projection"].attrs["long_name"] = "model grid projection"

    # Assign this grid_mapping for all variables
    for var in list(ds):
        if var == "gribfile_projection":
            continue
        ds[var].attrs["grid_mapping"] = "gribfile_projection"


[docs] @xr.register_dataset_accessor("herbie") class HerbieAccessor: """Accessor for xarray Datasets opened with Herbie."""
[docs] def __init__(self, xarray_obj): self._obj = xarray_obj self._center = None
@property def center(self) -> tuple[float, float]: """Return the geographic center point of this dataset.""" if self._center is None: lat = self._obj.latitude lon = self._obj.longitude self._center = (float(lon.mean()), float(lat.mean())) return self._center
[docs] def to_180(self) -> xr.Dataset: """Wrap longitude coordinates as range [-180,180].""" ds = self._obj ds["longitude"] = (ds["longitude"] + 180) % 360 - 180 return ds
[docs] def to_360(self) -> xr.Dataset: """Wrap longitude coordinates as range [0,360].""" ds = self._obj ds["longitude"] = (ds["longitude"] - 360) % 360 return ds
[docs] @functools.cached_property def crs(self): """ Cartopy coordinate reference system (crs) from a cfgrib Dataset. Projection information is from the grib2 message for each variable. Parameters ---------- ds : xarray.Dataset An xarray.Dataset from a GRIB2 file opened by the cfgrib engine. """ try: import metpy except ModuleNotFoundError: raise ModuleNotFoundError( "metpy is an 'extra' requirement, please use " "`pip install 'herbie-data[extras]'` for the full functionality." ) ds = self._obj # Get variables that have dimensions # (this filters out the gribfile_projection variable) variables = [i for i in list(ds) if len(ds[i].dims) > 0] ds = ds.metpy.parse_cf(varname=variables) crs = ds.metpy_crs.item().to_cartopy() return crs
[docs] @functools.cached_property def polygon(self): """Get a polygon of the domain boundary.""" try: import cartopy.crs as ccrs from shapely.geometry import Polygon except ModuleNotFoundError: raise ModuleNotFoundError( "cartopy is an 'extra' requirements, please use " "`pip install 'herbie-data[extras]'` for the full functionality." ) ds = self._obj LON = ds.longitude.data LAT = ds.latitude.data # Path of array outside border starting from the lower left corner # and going around the array counter-clockwise. outside = ( list(zip(LON[0, :], LAT[0, :])) + list(zip(LON[:, -1], LAT[:, -1])) + list(zip(LON[-1, ::-1], LAT[-1, ::-1])) + list(zip(LON[::-1, 0], LAT[::-1, 0])) ) outside = np.array(outside) ############################### # Polygon in Lat/Lon coordinates x = outside[:, 0] y = outside[:, 1] domain_polygon_latlon = Polygon(zip(x, y)) ################################### # Polygon in projection coordinates transform = self.crs.transform_points(ccrs.PlateCarree(), x, y) # Remove any points that run off the projection map (i.e., point's value is `inf`). transform = transform[~np.isinf(transform).any(axis=1)] x = transform[:, 0] y = transform[:, 1] domain_polygon = Polygon(zip(x, y)) return domain_polygon, domain_polygon_latlon
[docs] def with_wind( self, which: Literal["both", "speed", "direction"] = "both" ) -> xr.Dataset: """Return Dataset with calculated wind speed and/or direction. Consistent with the eccodes GRIB parameter database, variables names are assigned as follows: - "si10" : 10 metre wind speed (note this is not ws10 as you might expect) - "wdir10" : 10 metre wind direction - "ws" : wind speed - "wdir" : wind direction Refer to the eccodes database <https://codes.ecmwf.int/grib/param-db/>. Parameters ---------- which : {'both', 'speed', 'direction'} Specify which wind quantity to compute. """ ds = self._obj n_computed = 0 if which in ("speed", "both"): if {"u10", "v10"}.issubset(ds): ds["si10"] = np.sqrt(ds.u10**2 + ds.v10**2) ds["si10"].attrs["GRIB_paramId"] = 207 ds["si10"].attrs["long_name"] = "10 metre wind speed" ds["si10"].attrs["units"] = "m s**-1" ds["si10"].attrs["standard_name"] = "wind_speed" ds["si10"].attrs["grid_mapping"] = ds.u10.attrs.get("grid_mapping") n_computed += 1 if {"u100", "v100"}.issubset(ds): ds["si100"] = np.sqrt(ds.u100**2 + ds.v100**2) ds["si100"].attrs["GRIB_paramId"] = 228249 ds["si100"].attrs["long_name"] = "100 metre wind speed" ds["si100"].attrs["units"] = "m s**-1" ds["si100"].attrs["standard_name"] = "wind_speed" ds["si100"].attrs["grid_mapping"] = ds.u100.attrs.get("grid_mapping") n_computed += 1 if {"u80", "v80"}.issubset(ds): ds["si80"] = np.sqrt(ds.u80**2 + ds.v80**2) ds["si80"].attrs["long_name"] = "80 metre wind speed" ds["si80"].attrs["units"] = "m s**-1" ds["si80"].attrs["standard_name"] = "wind_speed" ds["si80"].attrs["grid_mapping"] = ds.u80.attrs.get("grid_mapping") n_computed += 1 if {"u", "v"}.issubset(ds): ds["ws"] = np.sqrt(ds.u**2 + ds.v**2) ds["ws"].attrs["GRIB_paramId"] = 10 ds["ws"].attrs["long_name"] = "wind speed" ds["ws"].attrs["units"] = "m s**-1" ds["ws"].attrs["standard_name"] = "wind_speed" ds["ws"].attrs["grid_mapping"] = ds.u.attrs.get("grid_mapping") n_computed += 1 if which in ("direction", "both"): if {"u10", "v10"}.issubset(ds): ds["wdir10"] = ( (270 - np.rad2deg(np.arctan2(ds.v10, ds.u10))) % 360 ).where((ds.u10 != 0) | (ds.v10 != 0)) ds["wdir10"].attrs["GRIB_paramId"] = 260260 ds["wdir10"].attrs["long_name"] = "10 metre wind direction" ds["wdir10"].attrs["units"] = "degree" ds["wdir10"].attrs["standard_name"] = "wind_from_direction" ds["wdir10"].attrs["grid_mapping"] = ds.u10.attrs.get("grid_mapping") n_computed += 1 if {"u100", "v100"}.issubset(ds): ds["wdir100"] = ( (270 - np.rad2deg(np.arctan2(ds.v100, ds.u100))) % 360 ).where((ds.u100 != 0) | (ds.v100 != 0)) ds["wdir100"].attrs["long_name"] = "100 metre wind direction" ds["wdir100"].attrs["units"] = "degree" ds["wdir100"].attrs["standard_name"] = "wind_from_direction" ds["wdir100"].attrs["grid_mapping"] = ds.u100.attrs.get("grid_mapping") n_computed += 1 if {"u80", "v80"}.issubset(ds): ds["wdir80"] = ( (270 - np.rad2deg(np.arctan2(ds.v80, ds.u80))) % 360 ).where((ds.u80 != 0) | (ds.v80 != 0)) ds["wdir80"].attrs["long_name"] = "80 metre wind direction" ds["wdir80"].attrs["units"] = "degree" ds["wdir80"].attrs["standard_name"] = "wind_from_direction" ds["wdir80"].attrs["grid_mapping"] = ds.u80.attrs.get("grid_mapping") n_computed += 1 if {"u", "v"}.issubset(ds): ds["wdir"] = ((270 - np.rad2deg(np.arctan2(ds.v, ds.u))) % 360).where( (ds.u != 0) | (ds.v != 0) ) ds["wdir"].attrs["GRIB_paramId"] = 3031 ds["wdir"].attrs["long_name"] = "wind direction" ds["wdir"].attrs["units"] = "degree" ds["wdir"].attrs["standard_name"] = "wind_from_direction" ds["wdir"].attrs["grid_mapping"] = ds.u.attrs.get("grid_mapping") n_computed += 1 if n_computed == 0: warnings.warn("`with_wind()` did not do anything.") return ds
[docs] def pick_points( self, points: pd.DataFrame, method: Literal["nearest", "weighted"] = "nearest", *, k: Optional[int] = None, max_distance: Union[int, float] = 500, use_cached_tree: Union[bool, Literal["replant"]] = True, tree_name: Optional[str] = None, verbose: bool = False, ) -> xr.Dataset: """Pick nearest neighbor grid values at selected points. Parameters ---------- points : pd.DataFrame DataFrame with 'latitude' and 'longitude' columns method : {'nearest', 'weighted'} Method for point picking: - 'nearest': Get value at nearest grid point - 'weighted': Inverse-distance weighted mean of k nearest points k : int, optional Number of nearest neighbors. Defaults to 1 for 'nearest', 4 for 'weighted' max_distance : float Maximum distance in km for valid neighbors (default: 500) use_cached_tree : bool or 'replant' Whether to cache the BallTree spatial index: - True: Use cached tree if exists, create and cache if not - False: Always create new tree, don't cache - 'replant': Force create new tree and cache it tree_name : str, optional Custom name for BallTree cache file. If None, uses ds.model verbose : bool Print verbose output (currently unused) Returns ------- xr.Dataset Dataset with values at requested points. Includes coordinates: - point_latitude, point_longitude: requested coordinates - point_grid_distance: distance to nearest grid point (km) - Any other columns from points DataFrame Examples -------- >>> H = Herbie("2024-03-28 00:00", model="hrrr") >>> ds = H.xarray("TMP:[5,6,7,8,9][0,5]0 mb") >>> points = pd.DataFrame({ ... "longitude": [-100, -105, -98.4], ... "latitude": [40, 29, 42.3], ... "stid": ["aa", "bb", "cc"], ... }) Get nearest neighbor values: >>> dsp = ds.herbie.pick_points(points, method="nearest") Get distance-weighted mean of 4 nearest points: >>> dsp = ds.herbie.pick_points(points, method="weighted") Convert to DataFrame: >>> df = dsp.to_dataframe() Index by station ID: >>> dsp = dsp.swap_dims({"point": "point_stid"}) """ cache_dir = herbie.config["default"]["save_dir"] picker = GridPointPicker( ds=self._obj, cache_dir=cache_dir, tree_name=tree_name, ) return picker.pick_points( points=points, method=method, k=k, max_distance=max_distance, use_cached_tree=use_cached_tree, )
[docs] def nearest_points(self, points, names=None, verbose=True): """Legacy point extractor.""" from herbie.nearest_points import nearest_points return nearest_points( ds=self._obj, points=points, names=names, verbose=verbose, )
[docs] def plot(self, ax=None, common_features_kw={}, vars=None, **kwargs): """Plot data on a map. Parameters ---------- vars : list List of variables to plot. Default None will plot all variables in the DataSet. """ raise NotImplementedError( "Plotting functionality is not working right now. If you have ideas, please open a PR." )