Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 65 additions & 16 deletions src/pyrecest/filters/out_of_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,16 @@ def latest_at_or_before(self, time):
def items_after(self, time):
"""Return all records with timestamp strictly greater than ``time``."""
query_time = _coerce_time(time)
return tuple(self._copy_item(item) for item in self._items if item.time > query_time)
return tuple(
self._copy_item(item) for item in self._items if item.time > query_time
)

def items_at_or_after(self, time):
"""Return all records with timestamp greater than or equal to ``time``."""
query_time = _coerce_time(time)
return tuple(self._copy_item(item) for item in self._items if item.time >= query_time)
return tuple(
self._copy_item(item) for item in self._items if item.time >= query_time
)

def _copy_item(self, item):
return TimestampedItem(
Expand Down Expand Up @@ -169,7 +173,9 @@ class MeasurementTimeBuffer:
"""Small helper for tracking measurement timestamps and OOSM status."""

def __init__(self, max_lag=None, maxlen=None, *, copy_values=True):
self._buffer = FixedLagBuffer(max_lag=max_lag, maxlen=maxlen, copy_values=copy_values)
self._buffer = FixedLagBuffer(
max_lag=max_lag, maxlen=maxlen, copy_values=copy_values
)

def __len__(self):
return len(self._buffer)
Expand Down Expand Up @@ -301,15 +307,19 @@ def _insert_and_apply(self, time, method_name, args=(), kwargs=None):

if out_of_sequence:
captured_result = self._replay(capture_event=event)
replayed_event_count = len([item for item in self._events if item.time >= event.time])
replayed_event_count = len(
[item for item in self._events if item.time >= event.time]
)
else:
captured_result = self._apply_event(event)
self._latest_time = max(self._latest_time, event_time)
replayed_event_count = 0

self._trim_to_lag()
diagnostics = captured_result if isinstance(captured_result, dict) else None
accepted = True if diagnostics is None else bool(diagnostics.get("accepted", True))
accepted = (
True if diagnostics is None else bool(diagnostics.get("accepted", True))
)
return OutOfSequenceResult(
time=event_time,
final_time=self._latest_time,
Expand Down Expand Up @@ -362,12 +372,18 @@ class OutOfSequenceKalmanUpdater(_EventReplayMixin):
"""Fixed-lag OOSM processor for :class:`KalmanFilter`."""

def __init__(self, kalman_filter, initial_time=0.0, max_lag=None):
filter_object = kalman_filter if isinstance(kalman_filter, KalmanFilter) else KalmanFilter(kalman_filter)
filter_object = (
kalman_filter
if isinstance(kalman_filter, KalmanFilter)
else KalmanFilter(kalman_filter)
)
self._setup_replay(filter_object, initial_time=initial_time, max_lag=max_lag)

def predict_linear(self, time, system_matrix, sys_noise_cov, sys_input=None):
"""Record/apply a timestamped linear-Gaussian prediction."""
return self._insert_and_apply(time, "predict_linear", (system_matrix, sys_noise_cov, sys_input))
return self._insert_and_apply(
time, "predict_linear", (system_matrix, sys_noise_cov, sys_input)
)

def predict_model(self, time, transition_model):
"""Record/apply a timestamped structural transition-model prediction."""
Expand All @@ -389,7 +405,11 @@ def update_linear(
time,
"update_linear",
(measurement, measurement_matrix, meas_noise),
{"return_diagnostics": return_diagnostics, "scale": scale, "action": action},
{
"return_diagnostics": return_diagnostics,
"scale": scale,
"action": action,
},
)

def update_linear_robust(
Expand Down Expand Up @@ -421,18 +441,33 @@ def update_linear_robust(
},
)

def update_model(self, time, measurement_model, measurement, *, return_diagnostics=False, scale=1.0, action="updated"):
def update_model(
self,
time,
measurement_model,
measurement,
*,
return_diagnostics=False,
scale=1.0,
action="updated",
):
"""Record/apply a timestamped structural measurement-model update."""
return self._insert_and_apply(
time,
"update_model",
(measurement_model, measurement),
{"return_diagnostics": return_diagnostics, "scale": scale, "action": action},
{
"return_diagnostics": return_diagnostics,
"scale": scale,
"action": action,
},
)

def update_model_robust(self, time, measurement_model, measurement, **kwargs):
"""Record/apply a timestamped robust structural measurement-model update."""
return self._insert_and_apply(time, "update_model_robust", (measurement_model, measurement), kwargs)
return self._insert_and_apply(
time, "update_model_robust", (measurement_model, measurement), kwargs
)


class OutOfSequenceParticleUpdater(_EventReplayMixin):
Expand Down Expand Up @@ -461,18 +496,28 @@ def predict_nonlinear(
shift_instead_of_add=None,
):
"""Record/apply a timestamped nonlinear particle prediction."""
kwargs = {"noise_distribution": noise_distribution, "function_is_vectorized": function_is_vectorized}
kwargs = {
"noise_distribution": noise_distribution,
"function_is_vectorized": function_is_vectorized,
}
if shift_instead_of_add is not None:
kwargs["shift_instead_of_add"] = shift_instead_of_add
return self._insert_and_apply(time, "predict_nonlinear", (f,), kwargs)

def update_model(self, time, measurement_model, measurement=None):
"""Record/apply a timestamped particle measurement-model update."""
return self._insert_and_apply(time, "update_model", (measurement_model,), {"measurement": measurement})
return self._insert_and_apply(
time, "update_model", (measurement_model,), {"measurement": measurement}
)

def update_nonlinear_using_likelihood(self, time, likelihood, measurement=None):
"""Record/apply a timestamped likelihood-based particle update."""
return self._insert_and_apply(time, "update_nonlinear_using_likelihood", (likelihood,), {"measurement": measurement})
return self._insert_and_apply(
time,
"update_nonlinear_using_likelihood",
(likelihood,),
{"measurement": measurement},
)

update_with_likelihood = update_nonlinear_using_likelihood

Expand Down Expand Up @@ -527,7 +572,9 @@ def retrodict_linear_gaussian(
effective_covariance = covariance
if remove_process_noise:
if sys_noise_cov is None:
raise ValueError("sys_noise_cov is required when remove_process_noise is true")
raise ValueError(
"sys_noise_cov is required when remove_process_noise is true"
)
sys_noise_cov = _as_matrix(sys_noise_cov, "sys_noise_cov")
if sys_noise_cov.shape != covariance.shape:
raise ValueError("sys_noise_cov has incompatible shape")
Expand Down Expand Up @@ -559,4 +606,6 @@ def retrodict_linear_gaussian_state(
sys_noise_cov=sys_noise_cov,
remove_process_noise=remove_process_noise,
)
return GaussianDistribution(previous_mean, previous_covariance, check_validity=False)
return GaussianDistribution(
previous_mean, previous_covariance, check_validity=False
)
Loading
Loading