Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions test/core/test_vector_calculus.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,79 @@ def test_divergence_basic(self, gridpath, datasetpath):
assert np.isfinite(div_field.values).any()


class TestScalarDotGradientMPASOcean:

def test_scalardotgradient_uses_known_gradient_components(
self, gridpath, datasetpath, monkeypatch
):
"""Test scalar dot gradient against independently supplied gradients."""
uxds = ux.open_dataset(
gridpath("mpas", "QU", "480", "grid.nc"),
datasetpath("mpas", "QU", "480", "data.nc"),
)

n_face = uxds.uxgrid.n_face
dims = ["n_face"]
scalar = ux.UxDataArray(
np.zeros(n_face), dims=dims, uxgrid=uxds.uxgrid, name="scalar"
)
u_component = ux.UxDataArray(
np.full(n_face, 2.0), dims=dims, uxgrid=uxds.uxgrid, name="u"
)
v_component = ux.UxDataArray(
np.full(n_face, -0.5), dims=dims, uxgrid=uxds.uxgrid, name="v"
)

def mock_gradient(self):
return ux.UxDataset(
{
"zonal_gradient": ux.UxDataArray(
np.full(n_face, 3.0), dims=dims, uxgrid=self.uxgrid
),
"meridional_gradient": ux.UxDataArray(
np.full(n_face, -4.0), dims=dims, uxgrid=self.uxgrid
),
},
uxgrid=self.uxgrid,
)

monkeypatch.setattr(ux.UxDataArray, "gradient", mock_gradient)

result = u_component.scalardotgradient(v_component, scalar)

expected = np.full(n_face, 8.0)
nt.assert_allclose(result.values, expected, rtol=0.0, atol=0.0)

assert isinstance(result, ux.UxDataArray)
assert result.name == "scalar_dot_gradient"
assert result.attrs["long_name"] == "scalar dot gradient"
assert result.sizes == u_component.sizes

def test_scalardotgradient_rejects_misaligned_indexes(self, gridpath, datasetpath):
"""Test scalar dot gradient fails instead of silently realigning faces."""
uxds = ux.open_dataset(
gridpath("mpas", "QU", "480", "grid.nc"),
datasetpath("mpas", "QU", "480", "data.nc"),
)

scalar = uxds["bottomDepth"]
u_component = ux.UxDataArray(
np.ones(scalar.size),
dims=["n_face"],
coords={"n_face": np.arange(scalar.size)},
uxgrid=uxds.uxgrid,
)
v_component = ux.UxDataArray(
np.ones(scalar.size),
dims=["n_face"],
coords={"n_face": np.arange(scalar.size) + 1},
uxgrid=uxds.uxgrid,
)

with pytest.raises(ValueError):
u_component.scalardotgradient(v_component, scalar)


class TestDivergenceDyamondSubset:

def test_divergence_constant_field(self, gridpath, datasetpath):
Expand Down
60 changes: 60 additions & 0 deletions uxarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1614,6 +1614,66 @@ def divergence(self, other: "UxDataArray", **kwargs) -> "UxDataArray":

return divergence_da

def scalardotgradient(self, v: "UxDataArray", q: "UxDataArray") -> "UxDataArray":
"""
Compute the dot product between a vector field and the gradient of a scalar field.

Parameters
----------
v : UxDataArray
The meridional component of the vector field. ``self`` is treated as
the zonal component.
q : UxDataArray
Scalar field whose gradient is dotted with the vector field.

Returns
-------
scalar_dot_gradient : UxDataArray
Dot product ``self * dq/dx + v * dq/dy``.
"""
if not isinstance(v, UxDataArray):
raise TypeError("v must be a UxDataArray")

if not isinstance(q, UxDataArray):
raise TypeError("q must be a UxDataArray")

if self.uxgrid != v.uxgrid or self.uxgrid != q.uxgrid:
raise ValueError("All UxDataArrays must have the same grid")

if self.dims != v.dims or self.dims != q.dims:
raise ValueError("All UxDataArrays must have the same dimensions")

if self.ndim > 1:
raise ValueError(
"Scalar dot gradient currently requires 1D face-centered data. "
"Consider selecting a single slice before computing."
)

if not (self._face_centered() and v._face_centered() and q._face_centered()):
raise ValueError(
"Computing the scalar dot gradient is only supported for face-centered data variables."
)

u = self

q_gradient = q.gradient()
q_zonal = q_gradient["zonal_gradient"]
q_meridional = q_gradient["meridional_gradient"]

u_aligned, v_aligned, q_zonal, q_meridional = xr.align(
u, v, q_zonal, q_meridional, join="exact", copy=False
)
scalar_dot_gradient = (u_aligned * q_zonal) + (v_aligned * q_meridional)
scalar_dot_gradient.name = "scalar_dot_gradient"
scalar_dot_gradient.attrs.update(
{
"long_name": "scalar dot gradient",
"description": "Dot product u * (dq/dx) + v * (dq/dy).",
}
)

return UxDataArray(scalar_dot_gradient, uxgrid=self.uxgrid)

def difference(self, destination: str | None = "edge"):
"""Computes the absolute difference of a data variable.

Expand Down
Loading