ENH: cov: expose correction and weights parameters #690
base: main
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -281,31 +281,79 @@ def broadcast_shapes(*shapes: tuple[float | None, ...]) -> tuple[int | None, ... | |
| return tuple(out) | ||
|
|
||
|
|
||
| def cov(m: Array, /, *, xp: ModuleType) -> Array: # numpydoc ignore=PR01,RT01 | ||
| def cov( | ||
| m: Array, | ||
| /, | ||
| *, | ||
| correction: int | float = 1, | ||
| fweights: Array | None = None, | ||
| aweights: Array | None = None, | ||
| xp: ModuleType, | ||
| ) -> Array: # numpydoc ignore=PR01,RT01 | ||
| """See docstring in array_api_extra._delegation.""" | ||
| m = xp.asarray(m, copy=True) | ||
| m = xp.asarray(m) | ||
|
Comment on lines -286 to +294
Member
Why do we drop this?
Author
It was mostly a small optimization on my part, since the new covariance code no longer mutates `m` in place. If I understand correctly, the copy was only needed because of the in-place update in the old code. Now we compute:
    m_c = m - avg
    m_w = m_c if w is None else m_c * w
    m_cT = xp.matrix_transpose(m_c)
    c = (m_w @ m_cT) / fact
I noticed this by accident while running a speed test, which showed a small regression. I think it's worth dropping the copy.
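The point about copying can be illustrated with a small sketch. It uses NumPy as a stand-in for the generic `xp` namespace, and the two helper functions are hypothetical names introduced only for this illustration; they are not part of the PR.

```python
import numpy as np


def center_in_place(m):
    """Old pattern: subtract the mean from `m` itself (mutates the argument)."""
    m -= m.mean(axis=-1, keepdims=True)
    return m


def center_out_of_place(m):
    """New pattern: `m - avg` allocates a fresh array and leaves `m` alone."""
    return m - m.mean(axis=-1, keepdims=True)


original = np.array([[1.0, 2.0, 3.0], [4.0, 6.0, 8.0]])

a = original.copy()
center_out_of_place(a)
unchanged = np.array_equal(a, original)  # the caller's array is untouched

b = original.copy()
center_in_place(b)
mutated = not np.array_equal(b, original)  # the caller's array was overwritten
```

With the out-of-place form, a defensive copy of the input is not needed to protect the caller's data, which is the reasoning given in the reply above.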
||
| dtype = ( | ||
| xp.float64 if xp.isdtype(m.dtype, "integral") else xp.result_type(m, xp.float64) | ||
| ) | ||
|
|
||
| m = atleast_nd(m, ndim=2, xp=xp) | ||
| m = xp.astype(m, dtype) | ||
|
|
||
| avg = xp.mean(m, axis=-1, keepdims=True) | ||
| # Validate weight shapes (eager metadata, lazy-safe). Native backends | ||
| # validate themselves; this covers the generic path (array-api-strict, | ||
| # sparse, and the dask+weights fallback where the native check is | ||
| # bypassed to preserve laziness). | ||
| n_obs = m.shape[-1] | ||
| for name, w_in in (("fweights", fweights), ("aweights", aweights)): | ||
| if w_in is None: | ||
| continue | ||
| if w_in.ndim != 1: | ||
| msg = f"`{name}` must be 1-D, got ndim={w_in.ndim}" | ||
| raise ValueError(msg) | ||
| if w_in.shape[0] != n_obs: | ||
| msg = ( | ||
| f"`{name}` has length {w_in.shape[0]} but `m` has {n_obs} observations" | ||
| ) | ||
| raise ValueError(msg) | ||
|
|
||
| fw = None | ||
| if fweights is not None: | ||
| fw = xp.astype(xp.asarray(fweights), dtype) | ||
| aw = None | ||
| if aweights is not None: | ||
| aw = xp.astype(xp.asarray(aweights), dtype) | ||
| if fw is None and aw is None: | ||
| w = None | ||
| elif fw is None: | ||
| w = aw | ||
| elif aw is None: | ||
| w = fw | ||
| else: | ||
| w = fw * aw | ||
|
|
||
| m_shape = eager_shape(m) | ||
| fact = m_shape[-1] - 1 | ||
|
|
||
| if fact <= 0: | ||
| warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2) | ||
| fact = 0 | ||
|
|
||
| m -= avg | ||
| m_transpose = xp.matrix_transpose(m) | ||
| if xp.isdtype(m_transpose.dtype, "complex floating"): | ||
| m_transpose = xp.conj(m_transpose) | ||
| c = xp.matmul(m, m_transpose) | ||
| c /= fact | ||
| if w is None: | ||
| avg = xp.mean(m, axis=-1, keepdims=True) | ||
| fact = m_shape[-1] - correction | ||
| if fact <= 0: | ||
| warnings.warn( | ||
| "Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2 | ||
| ) | ||
| fact = 0 | ||
| else: | ||
| v1 = xp.sum(w, axis=-1) | ||
| avg = xp.sum(m * w, axis=-1, keepdims=True) / v1 | ||
| if aw is None: | ||
| fact = v1 - correction | ||
| else: | ||
| fact = v1 - correction * xp.sum(w * aw, axis=-1) / v1 | ||
|
|
||
| m_c = m - avg | ||
| m_w = m_c if w is None else m_c * w | ||
| m_cT = xp.matrix_transpose(m_c) | ||
|
lucascolley marked this conversation as resolved.
|
||
| if xp.isdtype(m_cT.dtype, "complex floating"): | ||
| m_cT = xp.conj(m_cT) | ||
| c = m_w @ m_cT / fact | ||
| axes = tuple(axis for axis, length in enumerate(c.shape) if length == 1) | ||
| return xp.squeeze(c, axis=axes) | ||
|
|
||
|
|
||
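As a sanity check on the weighted branch of the diff above, here is a minimal pure-Python sketch of the same arithmetic: the weighted mean `avg`, the weight sum `v1`, and the normalisation `fact = v1 - correction * sum(w * aw) / v1`. The function `weighted_cov` and the helper `matrices_close` are hypothetical names for this illustration, and the sketch assumes real-valued 2-D input given as a list of rows (variables by observations), so it omits the complex-conjugate step.

```python
import math


def weighted_cov(rows, fweights=None, aweights=None, correction=1):
    """Covariance of `rows` (variables x observations), real-valued only."""
    n_obs = len(rows[0])
    fw = [1.0] * n_obs if fweights is None else [float(f) for f in fweights]
    aw = [1.0] * n_obs if aweights is None else [float(a) for a in aweights]
    w = [f * a for f, a in zip(fw, aw)]  # combined weights

    v1 = sum(w)
    # Weighted mean of each variable.
    avg = [sum(x * wi for x, wi in zip(row, w)) / v1 for row in rows]
    # With aw all ones this reduces to v1 - correction, and with no
    # weights at all it reduces to n_obs - correction.
    fact = v1 - correction * sum(wi * a for wi, a in zip(w, aw)) / v1

    centred = [[x - mu for x in row] for row, mu in zip(rows, avg)]
    return [
        [sum(wi * p * q for wi, p, q in zip(w, ri, rj)) / fact for rj in centred]
        for ri in centred
    ]


def matrices_close(a, b, tol=1e-12):
    """Element-wise comparison of two nested-list matrices."""
    return all(
        math.isclose(x, y, rel_tol=tol, abs_tol=tol)
        for row_a, row_b in zip(a, b)
        for x, y in zip(row_a, row_b)
    )


rows = [[1.0, 2.0, 4.0], [1.0, 3.0, 2.0]]
weighted = weighted_cov(rows, fweights=[1, 2, 3])
# Integer frequency weights should match physically repeating observations.
repeated = weighted_cov(
    [[1.0, 2.0, 2.0, 4.0, 4.0, 4.0], [1.0, 3.0, 3.0, 2.0, 2.0, 2.0]]
)
```

Two properties are useful when testing an implementation like this: integer `fweights` should agree with repeating each observation that many times, and rescaling `aweights` by a constant should leave the result unchanged.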
Just a thought: maybe some common logic can be extracted from this and array-api-extra/src/array_api_extra/_delegation.py (lines 313 to 316 in af12cd5). That is for a tuple of axes though, so maybe not.
I couldn't think of a way to simplify this. The only thing I could think of is `normalize_axis_index` from NumPy, but then we would need another PR to introduce it in this library ;p
https://numpy.org/doc/2.1/reference/generated/numpy.lib.array_utils.normalize_axis_index.html#numpy.lib.array_utils.normalize_axis_index
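For context, a rough sketch of the two pieces of axis logic under discussion. `normalize_axis` is a hypothetical stand-in that only mimics the basic behaviour of NumPy's `normalize_axis_index` (NumPy raises its own `AxisError`; this sketch raises a plain `IndexError` as a simplification), and `length_one_axes` just wraps the selection that the new `cov` passes to `xp.squeeze`. Neither function exists in the library under these names.

```python
from __future__ import annotations


def normalize_axis(axis: int, ndim: int) -> int:
    """Hypothetical helper: map an axis in [-ndim, ndim) to [0, ndim)."""
    if not -ndim <= axis < ndim:
        msg = f"axis {axis} is out of bounds for array of dimension {ndim}"
        raise IndexError(msg)
    return axis + ndim if axis < 0 else axis


def length_one_axes(shape: tuple[int, ...]) -> tuple[int, ...]:
    """Indices of all length-1 axes, as selected before `xp.squeeze` in `cov`."""
    return tuple(axis for axis, length in enumerate(shape) if length == 1)


last_axis = normalize_axis(-1, 3)
scalar_like = length_one_axes((1, 1))
full_matrix = length_one_axes((3, 3))
```

The two helpers answer different questions (validating one axis index versus finding axes to squeeze), which is consistent with the conclusion above that there is little shared logic to extract.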