openghg_inversions.basis.operators
Basis operator classes
Design goals
Separate the partition/aggregation operator (basis functions) from any flux weighting.
- The basis operator represents a linear map from a grid (lat/lon) to a reduced state space.
- Flux weighting (multiplying by flux on the grid, interpolation to maps, covariance transforms) is handled by a separate wrapper class (planned: FluxWeightedBasis) so that:
  - sensitivity(fp_x_flux) does not require flux (since fp_x_flux is already precomputed),
  - but flux-aware operations remain available when needed.
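Because the basis operator is linear, multiplying by flux before the reduction (as in a precomputed fp_x_flux) gives the same state-space sensitivities as folding the flux into the operator, which is what a flux-weighted wrapper would do. The plain-Python sketch below illustrates this with dictionaries standing in for xarray arrays; the grid, labels and values are made up, and reduce_to_state is a hypothetical helper, not the library's code.

```python
# Plain-Python stand-ins for xarray arrays on a hypothetical 2x2 (lat, lon) grid.
cells = [(0, 0), (0, 1), (1, 0), (1, 1)]
basis_flat = {(0, 0): 0, (0, 1): 0, (1, 0): 1, (1, 1): 1}  # region label per cell
footprint = {(0, 0): 1.0, (0, 1): 2.0, (1, 0): 3.0, (1, 1): 4.0}
flux = {(0, 0): 10.0, (0, 1): 20.0, (1, 0): 30.0, (1, 1): 40.0}
N_STATE = 2


def reduce_to_state(field, weights=None):
    """Hypothetical helper: sum a gridded field into regions, optionally weighting each cell."""
    h = [0.0] * N_STATE
    for c in cells:
        w = 1.0 if weights is None else weights[c]
        h[basis_flat[c]] += field[c] * w
    return h


# Path 1: flux is already folded into fp_x_flux, so the operator never sees flux.
fp_x_flux = {c: footprint[c] * flux[c] for c in cells}
h_precomputed = reduce_to_state(fp_x_flux)

# Path 2: flux weighting applied by the operator to the bare footprint.
h_weighted = reduce_to_state(footprint, weights=flux)

print(h_precomputed)  # [50.0, 250.0]
print(h_weighted)     # [50.0, 250.0]
```

Keeping the weighting outside the operator means the common path (reducing a precomputed fp_x_flux) needs no flux at all.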
Canonical “state” dimension.
- Operators expose a single state dimension (default name: “state”).
- In multisource/multisector cases with ragged per-source region counts, the state coordinate becomes a MultiIndex over (source, region_in_source). This avoids padding with zeros.
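To see why a gathered MultiIndex avoids padding, the sketch below enumerates (source, region_in_source) pairs for two hypothetical sources with 3 and 2 regions. In xarray, such tuples would back a pandas.MultiIndex on the state coordinate (for example via pandas.MultiIndex.from_tuples); plain lists are used here so the example has no dependencies.

```python
# Hypothetical per-source region counts (ragged: 3 regions for "a", 2 for "b").
region_counts = {"a": 3, "b": 2}

# Gathered state index: one entry per (source, region_in_source) pair.
state_index = [
    (source, region)
    for source, n_regions in region_counts.items()
    for region in range(n_regions)
]

# A padded layout would need len(sources) * max(region counts) slots instead.
padded_size = len(region_counts) * max(region_counts.values())

print(state_index)  # [('a', 0), ('a', 1), ('a', 2), ('b', 0), ('b', 1)]
print(len(state_index), padded_size)  # 5 6
```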
Minimal metadata (BasisMeta).
- We only need to know which dims to dot over (grid_dims) and the state_dim name.
- Any special alignment hacks are implemented in concrete subclasses rather than inferred from metadata.
Serialization via xarray.DataTree.
- BasisOperator.to_datatree() returns a self-describing DataTree with schema/kind/version attrs.
- BasisOperator.decode_datatree(dt) dispatches to the correct registered subclass based on dt.attrs["kind"].
- For multisource operators, the canonical serialized representation stores per-source flat basis arrays under dt["basis_flat"][<source>], keeping storage compact and natural.
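The on-disk layout for a multisource operator can be pictured as the tree below. This is a dependency-free sketch using nested dicts in place of xarray.DataTree; the schema and version values are hypothetical placeholders (the actual strings are not given here), and node_paths is an illustrative helper. With xarray installed, an equivalent tree could be assembled with xarray.DataTree.from_dict, which takes a mapping from paths such as "/basis_flat/a" to datasets.

```python
# Nested dicts standing in for an xarray.DataTree. Attribute values are placeholders.
tree = {
    "attrs": {"schema": "<schema-name>", "kind": "<registered-kind>", "version": 1},
    "children": {
        "basis_flat": {
            "children": {
                # One flat (lat, lon) basis per source; region counts may differ.
                "a": {"data": [[0, 0], [1, 2]]},
                "b": {"data": [[0, 1], [1, 1]]},
            }
        }
    },
}


def node_paths(node, prefix=""):
    """Yield DataTree-style paths for every node, depth first."""
    yield prefix or "/"
    for name, child in node.get("children", {}).items():
        yield from node_paths(child, f"{prefix}/{name}")


print(list(node_paths(tree)))
# ['/', '/basis_flat', '/basis_flat/a', '/basis_flat/b']
```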
How to use
- Construct a basis operator:
from openghg_inversions.basis.operators import BucketBasisOperator, MultiSourceBucketBasisOperator
op = BucketBasisOperator(basis_flat)  # single-sector
op = MultiSourceBucketBasisOperator({"a": bf_a, "b": bf_b})  # ragged multisource
- Compute sensitivities:
H = op.sensitivity(fp_x_flux)
where fp_x_flux is an xarray.DataArray with at least the grid dims (lat, lon) and, typically, time. In multisource workflows, fp_x_flux often also has a separate “source” dimension. The multisource operator implements an alignment/broadcast hack so that this source dimension can be matched against the “source” level of the MultiIndex stored on the state coordinate.
- Serialize/deserialize:
from openghg_inversions.basis.operators import BasisOperator
dt = op.to_datatree()
op2 = BasisOperator.decode_datatree(dt)  # dispatches on dt.attrs["kind"]
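The multisource alignment described under “Compute sensitivities” can be sketched in plain Python: each state element carries a source label, and that label selects which slice of fp_x_flux is summed over the element's region. The grid is flattened to three cells, time is omitted, and all values are hypothetical; the real operator does this with xarray broadcasting rather than loops.

```python
# Per-source flat bases on a flattened 3-cell grid (ragged: "a" has 2 regions, "b" has 1).
basis_flat = {"a": [0, 0, 1], "b": [0, 0, 0]}
# fp_x_flux with a separate source dimension: one gridded field per source.
fp_x_flux = {"a": [1.0, 2.0, 4.0], "b": [10.0, 20.0, 30.0]}

# Gathered (source, region_in_source) state index.
state_index = [
    (source, region)
    for source, labels in basis_flat.items()
    for region in sorted(set(labels))
]


def multisource_sensitivity(fp_x_flux):
    """Sum each source's fp_x_flux over the cells of the matching (source, region)."""
    h = {}
    for source, region in state_index:
        # The state's "source" level picks the fp_x_flux slice to reduce.
        h[(source, region)] = sum(
            value
            for value, label in zip(fp_x_flux[source], basis_flat[source])
            if label == region
        )
    return h


print(multisource_sensitivity(fp_x_flux))
# {('a', 0): 3.0, ('a', 1): 4.0, ('b', 0): 60.0}
```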
Notes
Currently, basis operators cannot have a time dimension. If the input flat array has a time dimension with more than one coordinate value, an error is raised.
- class openghg_inversions.basis.operators.BasisMeta(grid_dims: tuple[str, ...] = ('lat', 'lon'), state_dim: str = 'state')
Bases: object
Metadata describing how to apply a basis operator.
The intent is to keep this minimal: we only store what is needed for the default implementations of BasisOperator.sensitivity and BasisOperator.interpolate.
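- Variables:
grid_dims – Names of the grid dimensions to dot over (default ('lat', 'lon')).
state_dim – Name of the reduced state dimension (default 'state').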
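A minimal sketch of what this metadata amounts to, assuming a frozen dataclass with the documented fields and defaults (frozen=True and the helper remaining_dims are assumptions for illustration). The helper shows how grid_dims and state_dim together determine which dimensions survive a sensitivity reduction.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BasisMeta:
    # Stand-in mirroring the documented fields and defaults; frozen=True is an assumption.
    grid_dims: tuple[str, ...] = ("lat", "lon")
    state_dim: str = "state"


def remaining_dims(input_dims, meta):
    """Hypothetical helper: dims left after dotting over the grid, plus the state dim.

    Dimension order in a real xarray result may differ.
    """
    kept = [d for d in input_dims if d not in meta.grid_dims]
    return tuple(kept) + (meta.state_dim,)


print(remaining_dims(("time", "lat", "lon"), BasisMeta()))  # ('time', 'state')
print(remaining_dims(("source", "time", "lat", "lon"), BasisMeta(state_dim="x")))
# ('source', 'time', 'x')
```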
- class openghg_inversions.basis.operators.BasisOperator
Bases: ABC
Abstract basis operator.
Concrete subclasses must define:
- meta (grid dims + state dim)
- basis_matrix: one-hot/dummy matrix with dims (*grid_dims, state_dim)
- to_datatree / from_datatree
The default sensitivity implementation assumes fp_x_flux has the grid dims; any extra dims (e.g. time, source) are preserved.
- abstract property basis_matrix: DataArray
Dummy matrix mapping grid -> state.
Expected dims: (*grid_dims, state_dim) (state_dim may be a MultiIndex coordinate)
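For a bucket basis, the dummy matrix has one row per grid cell and one column per state element, with a single 1 in each row. The sketch below builds it in plain Python from a flattened basis_flat; it is only an analogue of what get_xr_dummies is described as producing, not that function (its signature is not shown here), and ordering columns by sorted label is an assumption.

```python
def dummy_matrix(basis_flat):
    """One-hot matrix for a flattened bucket basis.

    Rows are grid cells; columns are the distinct region labels in sorted order.
    """
    labels = sorted(set(basis_flat))
    column = {label: j for j, label in enumerate(labels)}
    matrix = [
        [1 if column[label] == j else 0 for j in range(len(labels))]
        for label in basis_flat
    ]
    return matrix, labels


# Hypothetical flattened 2x2 grid with non-contiguous region labels.
matrix, labels = dummy_matrix([2, 2, 5, 7])
print(labels)  # [2, 5, 7]
print(matrix)  # [[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]
```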
- classmethod decode_datatree(dt: DataTree) → BasisOperator
Dispatches a DataTree to the correct registered operator subclass.
- Parameters:
dt – DataTree representation of a basis operator.
- Returns:
A concrete BasisOperator instance.
- Raises:
ValueError – If the schema or schema version is unsupported.
KeyError – If the kind is not registered.
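The dispatch logic can be sketched as: validate the schema attributes, then look up the kind in a registry and delegate to that class's from_datatree. Everything below is hypothetical: a dict with an "attrs" key stands in for the DataTree, and the schema name, supported versions and DemoBucket class are placeholders.

```python
SUPPORTED_SCHEMA = "<schema-name>"  # placeholder; the real schema string is not shown
SUPPORTED_VERSIONS = {1}            # placeholder
REGISTRY = {}


class DemoBucket:
    """Hypothetical registered operator class."""

    @classmethod
    def from_datatree(cls, dt):
        return cls()


REGISTRY["bucket"] = DemoBucket


def decode_datatree(dt):
    """Validate schema attrs, then dispatch on dt["attrs"]["kind"]."""
    attrs = dt["attrs"]
    if attrs.get("schema") != SUPPORTED_SCHEMA:
        raise ValueError(f"unsupported schema: {attrs.get('schema')!r}")
    if attrs.get("version") not in SUPPORTED_VERSIONS:
        raise ValueError(f"unsupported schema version: {attrs.get('version')!r}")
    kind = attrs.get("kind")
    if kind not in REGISTRY:
        raise KeyError(f"no basis operator registered for kind {kind!r}")
    return REGISTRY[kind].from_datatree(dt)


good = {"attrs": {"schema": SUPPORTED_SCHEMA, "version": 1, "kind": "bucket"}}
print(type(decode_datatree(good)).__name__)  # DemoBucket
```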
- abstractmethod classmethod from_datatree(dt: DataTree) → Self
Constructs an operator instance from an xarray.DataTree.
Concrete subclasses implement this to load whatever canonical representation they write in to_datatree().
- Parameters:
dt – DataTree created by to_datatree().
- Returns:
An instance of the operator.
- interpolate(state: DataArray, weights: DataArray | None = None) → DataArray
Interpolates/reconstructs a gridded field from a state vector.
This maps from the reduced basis space back to the grid by multiplying the basis dummy matrix by a state vector.
If weights is provided (e.g. a flux field on the grid), it is multiplied elementwise with the basis matrix before interpolation. This corresponds to using a flux-weighted interpolation operator.
- Parameters:
state – State vector with dimension meta.state_dim.
weights – Optional gridded weights with dimensions matching meta.grid_dims (and broadcastable to basis_matrix).
- Returns:
Reconstructed gridded field with dimensions including meta.grid_dims.
- Raises:
ValueError – If meta.state_dim is not a dimension of state.
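With a one-hot basis, multiplying the dummy matrix by the state vector simply copies each region's state value onto the cells of that region: x_grid[c] = Σ_s w[c] B[c, s] x[s], with w[c] = 1 when no weights are given. A plain-Python sketch on a flattened grid, with hypothetical values:

```python
def interpolate(state, basis_flat, weights=None):
    """Map a state vector back to a flattened grid.

    With a one-hot basis, cell c receives state[label of c], scaled by weights[c] if given.
    """
    return [
        state[label] * (1.0 if weights is None else weights[c])
        for c, label in enumerate(basis_flat)
    ]


basis_flat = [0, 0, 1, 1]        # flattened 2x2 grid, two regions (hypothetical)
state = [0.5, 2.0]               # e.g. one scaling factor per region
flux = [10.0, 20.0, 30.0, 40.0]  # hypothetical prior flux used as weights

print(interpolate(state, basis_flat))        # [0.5, 0.5, 2.0, 2.0]
print(interpolate(state, basis_flat, flux))  # [5.0, 10.0, 60.0, 80.0]
```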
- sensitivity(fp_x_flux: DataArray, fillna: bool = True) → DataArray
Computes the sensitivity matrix (“H”) by dotting over the grid.
This implements the common bucket-basis reduction:
- fp_x_flux is a gridded quantity whose dimensions include meta.grid_dims (typically lat and lon) and usually a time dimension.
- basis_matrix is a one-hot/dummy matrix that maps each grid cell to exactly one basis region/state.
- The returned array keeps all non-grid dimensions from fp_x_flux and includes the reduced state dimension meta.state_dim.
- Parameters:
fp_x_flux – Footprint x flux array to reduce. Must contain all meta.grid_dims.
fillna – If True, fill NaNs in fp_x_flux with 0.0.
- Returns:
Sensitivity matrix with dimension meta.state_dim and any remaining non-grid dimensions (e.g. time).
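In plain Python, the reduction amounts to summing fp_x_flux over the cells of each region while keeping the time axis. The sketch below (flattened grid, hypothetical values, loops instead of xarray.dot) also shows what fillna guards against: without it, a single NaN cell makes its whole region's sensitivity NaN.

```python
import math


def sensitivity(fp_x_flux, basis_flat, n_state, fillna=True):
    """H[t][s] = sum over cells c in region s of fp_x_flux[t][c].

    fp_x_flux is a list over time of flattened grid fields; the time axis is kept.
    """
    h = []
    for field in fp_x_flux:
        row = [0.0] * n_state
        for value, label in zip(field, basis_flat):
            if fillna and math.isnan(value):
                value = 0.0  # a missing cell contributes nothing
            row[label] += value
        h.append(row)
    return h


basis_flat = [0, 0, 1, 1]  # flattened 2x2 grid, two regions (hypothetical)
fp_x_flux = [
    [1.0, 2.0, 3.0, math.nan],  # time 0: one missing cell
    [0.5, 0.5, 1.0, 1.0],       # time 1
]

print(sensitivity(fp_x_flux, basis_flat, n_state=2))
# [[3.0, 3.0], [1.0, 2.0]]
print(sensitivity(fp_x_flux, basis_flat, n_state=2, fillna=False))
# [[3.0, nan], [1.0, 2.0]]
```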
- class openghg_inversions.basis.operators.BucketBasisOperator(basis_flat: DataArray, *, meta: BasisMeta | None = None, state_dim: str | None = None, region_labels: Literal['range0', 'range1', 'basis_values'] = 'range0', chunks: dict[str, int] | None = None)
Bases: BasisOperator
Single flat bucket basis: basis_flat(lat, lon) with integer region labels.
Stores basis_flat and constructs basis_matrix via get_xr_dummies.
- class openghg_inversions.basis.operators.MultiSourceBucketBasisOperator(basis_flat: dict[str, DataArray], *, meta: BasisMeta | None = None, source_dim: str = 'source', region_in_source_dim: str = 'region_in_source', state_dim: str | None = None, chunks: dict[str, int] | None = None)
Bases: BasisOperator
Multiple flat bases keyed by source, with potentially ragged region counts.
Canonical state_dim is a gathered MultiIndex over (source, region_in_source).
- classmethod from_datatree(dt: DataTree) → Self
Deserialises a MultiSourceBucketBasisOperator from a DataTree.
- Parameters:
dt – DataTree produced by MultiSourceBucketBasisOperator.to_datatree().
- Returns:
A reconstructed MultiSourceBucketBasisOperator.
- interpolate(state: DataArray, weights: DataArray | None = None) → DataArray
Interpolate/reconstruct a gridded field from a state vector.
For MultiSourceBucketBasisOperator, weights may include a source_dim that is also a level name in the gathered MultiIndex on meta.state_dim. In that case we broadcast weights along the gathered state axis (repeating per-source weights across all regions within that source) by replacing source_dim with meta.state_dim.
The state vector itself is expected to be defined on meta.state_dim and should not include a separate coordinate named like a MultiIndex level (e.g. source_dim).
- Parameters:
state – State vector defined on meta.state_dim.
weights – Optional gridded weights (e.g. prior fluxes) on meta.grid_dims. May optionally include source_dim for per-source weights.
- Returns:
Gridded reconstructed field on meta.grid_dims.
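The broadcasting rule above can be sketched as follows: each state element looks up its source level and uses that source's weights, so per-source weights are effectively repeated across all regions of the source. In this plain-Python sketch, contributions from every source are summed into a single grid, as a dot product over the gathered state axis would do; the grid, bases, weights and state values are all hypothetical.

```python
# Gathered state index and per-source bases on a flattened 3-cell grid (hypothetical).
state_index = [("a", 0), ("a", 1), ("b", 0)]
basis_flat = {"a": [0, 0, 1], "b": [0, 0, 0]}
N_CELLS = 3


def interpolate_multisource(state, weights=None):
    """Reconstruct a gridded field from a state defined on the gathered index."""
    grid = [0.0] * N_CELLS
    for x, (source, region) in zip(state, state_index):
        for c in range(N_CELLS):
            if basis_flat[source][c] == region:
                # Per-source weights are reused for every region of that source.
                w = 1.0 if weights is None else weights[source][c]
                grid[c] += w * x
    return grid


state = [2.0, 0.5, 1.5]  # one value per (source, region) element
weights = {"a": [1.0, 2.0, 3.0], "b": [10.0, 10.0, 10.0]}  # e.g. per-source prior flux

print(interpolate_multisource(state))           # [3.5, 3.5, 2.0]
print(interpolate_multisource(state, weights))  # [17.0, 19.0, 16.5]
```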
- openghg_inversions.basis.operators.drop_singleton_time(da: DataArray, *, name: str = 'basis_flat') → DataArray
Drop a singleton time dimension if present; raise if time is present with a length other than 1.
This is a strict helper intended for basis operators that assume a 2D basis over the grid dims. It avoids silently discarding time-varying basis information.
- Parameters:
da – Input DataArray which may or may not have a time dimension.
name – Label used in error messages to identify what is being checked.
- Returns:
da with time removed if it exists and has length 1, otherwise da unchanged.
- Raises:
ValueError – If time exists and has length not equal to 1.
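The strict behaviour can be sketched without xarray by tracking dimension sizes explicitly. In this hypothetical version, dims maps each dimension name to its length and data is a nested list with time (if present) as the leading axis; with a real DataArray, the drop itself could be done with da.squeeze("time", drop=True) once the length check has passed.

```python
def drop_singleton_time(dims, data, *, name="basis_flat"):
    """Return (dims, data) with a length-1 time axis removed; raise if time has length != 1.

    Sketch only: assumes `dims` lists dimensions in axis order and that time, if present, is axis 0.
    """
    if "time" not in dims:
        return dims, data  # no time dimension: unchanged
    if dims["time"] != 1:
        raise ValueError(
            f"{name} has a time dimension of length {dims['time']}; "
            "time-varying bases are not supported"
        )
    new_dims = {d: n for d, n in dims.items() if d != "time"}
    return new_dims, data[0]  # squeeze the singleton leading axis


dims, data = drop_singleton_time({"time": 1, "lat": 2, "lon": 2}, [[[0, 1], [1, 1]]])
print(dims, data)  # {'lat': 2, 'lon': 2} [[0, 1], [1, 1]]
```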
- openghg_inversions.basis.operators.get_basis_operator_class(kind: str) → type[BasisOperator]
Looks up a registered BasisOperator subclass.
- Parameters:
kind – Registry key for the operator type (e.g. “bucket”).
- Returns:
The registered BasisOperator subclass.
- Raises:
KeyError – If kind is not registered.
- openghg_inversions.basis.operators.register_basis_operator(kind: str)
Registers a BasisOperator subclass for DataTree deserialisation.
This decorator builds a small module-level registry mapping a stable kind string (stored in dt.attrs["kind"]) to a concrete BasisOperator subclass.
- Parameters:
kind – Stable key identifying the operator type on disk. This is written to dt.attrs["kind"] by BasisOperator.to_datatree() and used by BasisOperator.decode_datatree().
- Returns:
A class decorator that registers the decorated class under kind.
- Raises:
ValueError – If kind is already registered to a different class.
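A minimal sketch of such a registry and its lookup, assuming a module-level dict; the names _BASIS_OPERATORS and DemoBucketOperator are hypothetical, and the library's own implementation may differ in detail. Re-registering the same class is allowed, while registering a different class under an existing kind raises ValueError, matching the behaviour described above.

```python
from typing import Callable, TypeVar

T = TypeVar("T", bound=type)

_BASIS_OPERATORS: dict[str, type] = {}  # hypothetical module-level registry


def register_basis_operator(kind: str) -> Callable[[T], T]:
    """Class decorator that records `cls` under `kind`."""

    def decorator(cls: T) -> T:
        existing = _BASIS_OPERATORS.get(kind)
        if existing is not None and existing is not cls:
            raise ValueError(f"kind {kind!r} is already registered to {existing.__name__}")
        _BASIS_OPERATORS[kind] = cls
        return cls

    return decorator


def get_basis_operator_class(kind: str) -> type:
    """Look up a registered class, raising KeyError for unknown kinds."""
    try:
        return _BASIS_OPERATORS[kind]
    except KeyError:
        raise KeyError(f"no basis operator registered for kind {kind!r}") from None


@register_basis_operator("bucket")
class DemoBucketOperator:
    pass


print(get_basis_operator_class("bucket").__name__)  # DemoBucketOperator
```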