Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 100 additions & 29 deletions src/pyrecest/filters/association_hypotheses.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ def missed_detection_hypothesis(
)


def hypothesis_cost(hypothesis: AssociationHypothesis, *, missing_cost: float = np.inf) -> float:
def hypothesis_cost(
hypothesis: AssociationHypothesis, *, missing_cost: float = np.inf
) -> float:
"""Return a scalar minimization cost for a hypothesis."""
if hypothesis.cost is not None:
return float(hypothesis.cost)
Expand All @@ -110,7 +112,13 @@ def hypothesis_cost(hypothesis: AssociationHypothesis, *, missing_cost: float =
class NISGate:
"""Gate association hypotheses by normalized innovation squared."""

def __init__(self, threshold: float | None = None, *, measurement_dim: int | None = None, confidence: float | None = None):
def __init__(
self,
threshold: float | None = None,
*,
measurement_dim: int | None = None,
confidence: float | None = None,
):
if threshold is None:
if measurement_dim is None or confidence is None:
raise ValueError(
Expand Down Expand Up @@ -143,7 +151,10 @@ def __init__(self, threshold: float, *, missing_cost: float = np.inf):
def accepts(self, hypothesis: AssociationHypothesis) -> bool:
if hypothesis.is_missed_detection:
return True
return hypothesis_cost(hypothesis, missing_cost=self.missing_cost) <= self.threshold
return (
hypothesis_cost(hypothesis, missing_cost=self.missing_cost)
<= self.threshold
)

def __call__(self, hypothesis: AssociationHypothesis) -> bool:
return self.accepts(hypothesis)
Expand Down Expand Up @@ -176,7 +187,9 @@ def __call__(self, hypothesis: AssociationHypothesis) -> bool:
class TopKGate:
"""Keep the best ``k`` hypotheses per track or per measurement."""

def __init__(self, k: int, *, mode: GateMode = "track", missing_cost: float = np.inf):
def __init__(
self, k: int, *, mode: GateMode = "track", missing_cost: float = np.inf
):
self.k = int(k)
if self.k <= 0:
raise ValueError("k must be positive")
Expand All @@ -185,7 +198,9 @@ def __init__(self, k: int, *, mode: GateMode = "track", missing_cost: float = np
self.mode = mode
self.missing_cost = float(missing_cost)

def filter(self, hypotheses: Sequence[AssociationHypothesis]) -> list[AssociationHypothesis]:
def filter(
self, hypotheses: Sequence[AssociationHypothesis]
) -> list[AssociationHypothesis]:
"""Return hypotheses accepted by the top-k rule."""
accepted_keys = set()
grouped: dict[int, list[AssociationHypothesis]] = defaultdict(list)
Expand All @@ -194,23 +209,34 @@ def filter(self, hypotheses: Sequence[AssociationHypothesis]) -> list[Associatio
if hypothesis.is_missed_detection:
missed.append(hypothesis)
continue
key = hypothesis.track_index if self.mode == "track" else _measurement_index(hypothesis)
key = (
hypothesis.track_index
if self.mode == "track"
else _measurement_index(hypothesis)
)
grouped[key].append(hypothesis)

for group in grouped.values():
sorted_group = sorted(
group,
key=lambda hypothesis: hypothesis_cost(hypothesis, missing_cost=self.missing_cost),
key=lambda hypothesis: hypothesis_cost(
hypothesis, missing_cost=self.missing_cost
),
)
for hypothesis in sorted_group[: self.k]:
accepted_keys.add((hypothesis.track_index, hypothesis.measurement_index))
accepted_keys.add(
(hypothesis.track_index, hypothesis.measurement_index)
)

result = []
for hypothesis in hypotheses:
if hypothesis.is_missed_detection:
result.append(hypothesis)
continue
accepted = (hypothesis.track_index, hypothesis.measurement_index) in accepted_keys
accepted = (
hypothesis.track_index,
hypothesis.measurement_index,
) in accepted_keys
result.append(
hypothesis.with_acceptance(
accepted,
Expand All @@ -222,7 +248,9 @@ def filter(self, hypotheses: Sequence[AssociationHypothesis]) -> list[Associatio
def accepts(self, hypothesis: AssociationHypothesis) -> bool:
raise TypeError("TopKGate operates on a collection; use filter(...)")

def __call__(self, hypotheses: Sequence[AssociationHypothesis]) -> list[AssociationHypothesis]:
def __call__(
self, hypotheses: Sequence[AssociationHypothesis]
) -> list[AssociationHypothesis]:
return self.filter(hypotheses)


Expand All @@ -242,7 +270,9 @@ def gate_hypotheses(

result = []
for hypothesis in hypotheses:
accepted = bool(gate(hypothesis) if callable(gate) else gate.accepts(hypothesis))
accepted = bool(
gate(hypothesis) if callable(gate) else gate.accepts(hypothesis)
)
reason = None if accepted else reject_reason or type(gate).__name__
result.append(hypothesis.with_acceptance(accepted, reason))
return result
Expand All @@ -257,7 +287,11 @@ def filter_hypotheses(
"""Apply one or more gates and optionally drop rejected hypotheses."""
result = list(hypotheses)
if gates is None:
return [hypothesis for hypothesis in result if hypothesis.accepted] if accepted_only else result
return (
[hypothesis for hypothesis in result if hypothesis.accepted]
if accepted_only
else result
)
if not isinstance(gates, (list, tuple)):
gates = [gates]
for gate in gates:
Expand All @@ -279,9 +313,13 @@ def linear_gaussian_association_hypotheses(
metadata_builder: Callable[..., dict[str, Any] | None] | None = None,
) -> list[AssociationHypothesis]:
"""Build Gaussian innovation hypotheses for tracks and measurements."""
measurement_vectors = _coerce_measurements(measurements, measurement_axis=measurement_axis)
measurement_vectors = _coerce_measurements(
measurements, measurement_axis=measurement_axis
)
measurement_matrix = np.asarray(measurement_matrix, dtype=float)
covariance_stack = _coerce_measurement_covariances(meas_noise, len(measurement_vectors))
covariance_stack = _coerce_measurement_covariances(
meas_noise, len(measurement_vectors)
)

hypotheses = []
for track_index, track in enumerate(tracks):
Expand All @@ -295,7 +333,9 @@ def linear_gaussian_association_hypotheses(
measurement_matrix,
covariance,
)
nis = float(normalized_innovation_squared(innovation, innovation_covariance))
nis = float(
normalized_innovation_squared(innovation, innovation_covariance)
)
measurement_dim = int(measurement_matrix.shape[0])
log_likelihood = _linear_gaussian_log_likelihood(
innovation,
Expand Down Expand Up @@ -329,7 +369,9 @@ def linear_gaussian_association_hypotheses(
)

if gates is not None:
hypotheses = filter_hypotheses(hypotheses, gates, accepted_only=not include_rejected)
hypotheses = filter_hypotheses(
hypotheses, gates, accepted_only=not include_rejected
)
elif not include_rejected:
hypotheses = [hypothesis for hypothesis in hypotheses if hypothesis.accepted]
return hypotheses
Expand Down Expand Up @@ -435,7 +477,9 @@ def association_result_from_hypotheses(
unassigned_measurement_cost: float | Sequence[float] | None = None,
):
"""Solve GNN assignment from hypotheses and return ``AssociationResult``."""
from .track_manager import solve_global_nearest_neighbor # pylint: disable=import-outside-toplevel
from .track_manager import ( # pylint: disable=import-outside-toplevel
solve_global_nearest_neighbor,
)

cost_matrix = hypotheses_to_cost_matrix(
hypotheses,
Expand Down Expand Up @@ -474,10 +518,19 @@ def associator(tracks, measurements, **kwargs):
return association_result_from_hypotheses(
hypotheses,
num_tracks=len(tracks),
num_measurements=len(_coerce_measurements(measurements, measurement_axis=kwargs.get("measurement_axis", measurement_axis))),
num_measurements=len(
_coerce_measurements(
measurements,
measurement_axis=kwargs.get("measurement_axis", measurement_axis),
)
),
missing_cost=kwargs.get("missing_cost", missing_cost),
unassigned_track_cost=kwargs.get("unassigned_track_cost", unassigned_track_cost),
unassigned_measurement_cost=kwargs.get("unassigned_measurement_cost", unassigned_measurement_cost),
unassigned_track_cost=kwargs.get(
"unassigned_track_cost", unassigned_track_cost
),
unassigned_measurement_cost=kwargs.get(
"unassigned_measurement_cost", unassigned_measurement_cost
),
)

return associator
Expand Down Expand Up @@ -510,24 +563,36 @@ def _track_filter_state(track):
return track.filter_state
if hasattr(track, "single_target_filter"):
return track.single_target_filter.filter_state
raise TypeError("track must expose filter_state or single_target_filter.filter_state")
raise TypeError(
"track must expose filter_state or single_target_filter.filter_state"
)


def _coerce_measurements(measurements, *, measurement_axis: MeasurementAxis = "auto") -> list[np.ndarray]:
def _coerce_measurements(
measurements, *, measurement_axis: MeasurementAxis = "auto"
) -> list[np.ndarray]:
if isinstance(measurements, np.ndarray):
return _coerce_measurement_array(measurements, measurement_axis=measurement_axis)
return _coerce_measurement_array(
measurements, measurement_axis=measurement_axis
)
try:
from pyrecest.backend import to_numpy # pylint: disable=import-outside-toplevel

maybe_array = to_numpy(measurements)
if isinstance(maybe_array, np.ndarray):
return _coerce_measurement_array(maybe_array, measurement_axis=measurement_axis)
return _coerce_measurement_array(
maybe_array, measurement_axis=measurement_axis
)
except (ImportError, AttributeError, TypeError):
pass
return [np.asarray(measurement, dtype=float).reshape(-1) for measurement in measurements]
return [
np.asarray(measurement, dtype=float).reshape(-1) for measurement in measurements
]


def _coerce_measurement_array(array, *, measurement_axis: MeasurementAxis) -> list[np.ndarray]:
def _coerce_measurement_array(
array, *, measurement_axis: MeasurementAxis
) -> list[np.ndarray]:
array = np.asarray(array, dtype=float)
if array.ndim == 1:
return [array.reshape(-1)]
Expand All @@ -540,13 +605,17 @@ def _coerce_measurement_array(array, *, measurement_axis: MeasurementAxis) -> li
if measurement_axis == "sequence":
return [array[index].reshape(-1) for index in range(array.shape[0])]
if measurement_axis != "auto":
raise ValueError("measurement_axis must be 'auto', 'columns', 'rows', or 'sequence'")
raise ValueError(
"measurement_axis must be 'auto', 'columns', 'rows', or 'sequence'"
)
if array.shape[0] <= array.shape[1]:
return [array[:, index].reshape(-1) for index in range(array.shape[1])]
return [array[index, :].reshape(-1) for index in range(array.shape[0])]


def _coerce_measurement_covariances(meas_noise, num_measurements: int) -> list[np.ndarray]:
def _coerce_measurement_covariances(
meas_noise, num_measurements: int
) -> list[np.ndarray]:
covariance = np.asarray(meas_noise, dtype=float)
if covariance.ndim == 2:
return [covariance for _ in range(num_measurements)]
Expand All @@ -558,7 +627,9 @@ def _coerce_measurement_covariances(meas_noise, num_measurements: int) -> list[n
raise ValueError("meas_noise must have shape (m, m), (m, m, n), or (n, m, m)")


def _linear_gaussian_log_likelihood(innovation, innovation_covariance, measurement_dim: int) -> float:
def _linear_gaussian_log_likelihood(
innovation, innovation_covariance, measurement_dim: int
) -> float:
innovation = np.asarray(innovation, dtype=float).reshape(-1)
innovation_covariance = np.asarray(innovation_covariance, dtype=float)
sign, logdet = np.linalg.slogdet(innovation_covariance)
Expand Down
Loading
Loading