diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index a5cbdff18e..e6bd6158e9 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -278,9 +278,18 @@ def get_unit_spike_train_in_seconds( # Use the native spiking times if available # Some instances might implement a method themselves to access spike times directly without having to convert - # (e.g. NWB extractors) + # (e.g. NWB extractors). The native times already include the extractor's `_native_t_start`, + # so we apply only the shift (`_t_start - _native_t_start`) on top. if hasattr(segment, "get_unit_spike_train_in_seconds"): - return segment.get_unit_spike_train_in_seconds(unit_id=unit_id, start_time=start_time, end_time=end_time) + spike_times = segment.get_unit_spike_train_in_seconds( + unit_id=unit_id, start_time=start_time, end_time=end_time + ) + t_start = segment._t_start if segment._t_start is not None else 0 + native_t_start = segment._native_t_start if segment._native_t_start is not None else 0 + shift = t_start - native_t_start + if shift != 0: + spike_times = spike_times + shift + return spike_times # If no recording attached and all back to frame-based conversion # Get spike train in frames and convert to times using traditional method @@ -330,8 +339,12 @@ def register_recording(self, recording, check_spike_frames: bool = True): # Copy the recording's start times into the sorting segments. This way, # the sorting preserves the start time even if the recording is later # detached (e.g. analyzer saved and reloaded without the recording). + # Also update `_native_t_start` so any subsequent `shift_times` call measures + # its delta from the recording's start time (not the extractor's original value). for segment_index, segment in enumerate(self.segments): - segment._t_start = recording.get_start_time(segment_index=segment_index) + start_time = recording.get_start_time(segment_index=segment_index) + segment._t_start = start_time + segment._native_t_start = start_time @property def sorting_info(self): @@ -374,11 +387,38 @@ def get_start_time(self, segment_index: int | None = None) -> float: segment = self.segments[segment_index] return segment._t_start if segment._t_start is not None else 0.0 + def shift_times(self, shift: int | float, segment_index: int | None = None) -> None: + """ + Shift all times by a scalar value. + + This modifies the sorting's own time offset without touching the registered + recording. When a recording is registered, the shift is applied on top of + the recording's time basis when resolving timestamps. + + Parameters + ---------- + shift : int | float + The shift to apply. If positive, times will be increased by `shift`. + If negative, times will be decreased. + segment_index : int | None + The segment on which to shift the times. + If `None`, all segments will be shifted. + """ + if segment_index is None: + segments_to_shift = range(self.get_num_segments()) + else: + segments_to_shift = (segment_index,) + + for segment_index in segments_to_shift: + segment = self.segments[segment_index] + segment._t_start = (segment._t_start if segment._t_start is not None else 0) + shift + def get_end_time(self, segment_index: int | None = None) -> float: """Get the end time of the sorting segment. - If a recording is registered, returns the recording's end time. - Otherwise returns the time of the last spike in the segment. + If a recording is registered, returns the recording's end time (plus any + shift applied via `shift_times`). Otherwise returns the time of the last + spike in the segment. Parameters ---------- @@ -392,7 +432,10 @@ def get_end_time(self, segment_index: int | None = None) -> float: """ segment_index = self._check_segment_index(segment_index) if self.has_recording(): - return self._recording.get_end_time(segment_index=segment_index) + segment = self.segments[segment_index] + t_start = segment._t_start if segment._t_start is not None else 0 + shift = t_start - self._recording.get_start_time(segment_index=segment_index) + return self._recording.get_end_time(segment_index=segment_index) + shift else: last_spike_frame = self.get_last_spike_frame(segment_index=segment_index) return self.sample_index_to_time(last_spike_frame, segment_index=segment_index) @@ -425,11 +468,19 @@ def get_times(self, segment_index=None): * if the segment has a time_vector, then it is returned * if not, a time_vector is constructed on the fly with sampling frequency + Any shift applied via `shift_times` is added to the returned times. + If there is no registered recording it returns None """ segment_index = self._check_segment_index(segment_index) if self.has_recording(): - return self._recording.get_times(segment_index=segment_index) + times = self._recording.get_times(segment_index=segment_index) + segment = self.segments[segment_index] + t_start = segment._t_start if segment._t_start is not None else 0 + shift = t_start - self._recording.get_start_time(segment_index=segment_index) + if shift != 0: + times = times + shift + return times else: return None @@ -771,11 +822,13 @@ def time_to_sample_index(self, time, segment_index=0): """ Transform time in seconds into sample index """ + segment = self.segments[segment_index] + t_start = segment._t_start if segment._t_start is not None else 0 if self.has_recording(): - sample_index = self._recording.time_to_sample_index(time, segment_index=segment_index) + # Subtract the sorting's shift (relative to the recording's start) before delegating + shift = t_start - self._recording.get_start_time(segment_index=segment_index) + sample_index = self._recording.time_to_sample_index(time - shift, segment_index=segment_index) else: - segment = self.segments[segment_index] - t_start = segment._t_start if segment._t_start is not None else 0 sample_index = round((time - t_start) * self.get_sampling_frequency()) return sample_index @@ -787,11 +840,13 @@ def sample_index_to_time( Transform sample index into time in seconds """ segment_index = self._check_segment_index(segment_index) + segment = self.segments[segment_index] + t_start = segment._t_start if segment._t_start is not None else 0 if self.has_recording(): - return self._recording.sample_index_to_time(sample_index, segment_index=segment_index) + # Add the sorting's shift (relative to the recording's start) after delegating + shift = t_start - self._recording.get_start_time(segment_index=segment_index) + return self._recording.sample_index_to_time(sample_index, segment_index=segment_index) + shift else: - segment = self.segments[segment_index] - t_start = segment._t_start if segment._t_start is not None else 0 return (sample_index / self.get_sampling_frequency()) + t_start def precompute_spike_trains(self): @@ -1149,6 +1204,11 @@ class BaseSortingSegment(BaseSegment): def __init__(self, t_start=None): self._t_start = t_start + # Immutable reference to the start time as set by the extractor at init. + # Used to compute the user-applied shift as `_t_start - _native_t_start`, + # so `shift_times` can correctly propagate through extractors that return + # native absolute times (e.g. NWB) without double-counting the extractor's offset. + self._native_t_start = t_start BaseSegment.__init__(self) def get_unit_spike_train( diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py index bd74ddfe02..263ed0b36e 100644 --- a/src/spikeinterface/core/tests/test_time_handling.py +++ b/src/spikeinterface/core/tests/test_time_handling.py @@ -465,6 +465,52 @@ def test_get_start_time_with_t_start(self): sorting.segments[0]._t_start = 100.0 assert sorting.get_start_time(segment_index=0) == 100.0 + def test_shift_times(self): + sorting = generate_sorting(num_units=5, durations=[10]) + unit_id = sorting.unit_ids[0] + + spike_times_before = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True) + + sorting.shift_times(shift=5.0) + + assert sorting.get_start_time(segment_index=0) == 5.0 + spike_times_after = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True) + assert np.allclose(spike_times_after, spike_times_before + 5.0) + + def test_shift_times_all_segments(self): + sorting = generate_sorting(num_units=5, durations=[10, 15]) + sorting.segments[0]._t_start = 1.0 + sorting.segments[1]._t_start = 2.0 + + sorting.shift_times(shift=3.0) + + assert sorting.get_start_time(segment_index=0) == 4.0 + assert sorting.get_start_time(segment_index=1) == 5.0 + + def test_shift_times_single_segment(self): + sorting = generate_sorting(num_units=5, durations=[10, 15]) + sorting.segments[0]._t_start = 1.0 + sorting.segments[1]._t_start = 2.0 + + sorting.shift_times(shift=3.0, segment_index=1) + + assert sorting.get_start_time(segment_index=0) == 1.0 + assert sorting.get_start_time(segment_index=1) == 5.0 + + def test_shift_times_with_native_spike_times(self): + """Shift must apply even when the segment provides native spike times (e.g. NWB extractors).""" + sorting = generate_sorting(num_units=5, durations=[10]) + unit_id = sorting.unit_ids[0] + segment = sorting.segments[0] + + # Simulate a segment that provides native spike times directly + original_times = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True).copy() + segment.get_unit_spike_train_in_seconds = lambda unit_id, start_time, end_time: original_times + + sorting.shift_times(shift=5.0) + spike_times = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True) + assert np.allclose(spike_times, original_times + 5.0) + class TestSortingTimeWithRecording: """ @@ -503,3 +549,68 @@ def test_with_recording_shifted_start(self): sorting.register_recording(recording) assert sorting.get_start_time(segment_index=0) == 50.0 + + def test_shift_times(self): + recording = generate_recording(num_channels=4, durations=[10]) + sorting = generate_sorting(num_units=5, durations=[10]) + sorting.register_recording(recording) + unit_id = sorting.unit_ids[0] + + rec_start_before = recording.get_start_time(segment_index=0) + rec_end_before = recording.get_end_time(segment_index=0) + spike_times_before = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True) + + sorting.shift_times(shift=5.0) + + # The recording should be untouched + assert recording.get_start_time(segment_index=0) == rec_start_before + assert recording.get_end_time(segment_index=0) == rec_end_before + + # The sorting's times should be shifted + assert sorting.get_start_time(segment_index=0) == rec_start_before + 5.0 + assert sorting.get_end_time(segment_index=0) == rec_end_before + 5.0 + spike_times_after = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True) + assert np.allclose(spike_times_after, spike_times_before + 5.0) + + def test_time_conversion_roundtrip_after_shift(self): + """sample_index_to_time and time_to_sample_index must remain inverses after a shift.""" + recording = generate_recording(num_channels=4, durations=[10]) + sorting = generate_sorting(num_units=5, durations=[10]) + sorting.register_recording(recording) + + sorting.shift_times(shift=5.0) + + # Frame 30000 is 1.0s in the recording. After a 5.0s shift, the sorting should report 6.0s. + time = sorting.sample_index_to_time(30000, segment_index=0) + assert time == recording.sample_index_to_time(30000, segment_index=0) + 5.0 + + # The inverse: 6.0s in the sorting should map back to frame 30000. + frame = sorting.time_to_sample_index(time, segment_index=0) + assert frame == 30000 + + def test_shift_times_with_time_vector(self): + """Shift on sorting composes with a recording that has an explicit time vector, + preserving the irregular spacing.""" + recording = generate_recording(num_channels=4, durations=[1.0]) + num_samples = recording.get_num_samples(segment_index=0) + # Irregular timestamps starting at 100.0 + times = ( + 100.0 + + np.cumsum(np.random.RandomState(0).uniform(0.5, 1.5, num_samples)) / recording.get_sampling_frequency() + ) + recording.set_times(times, segment_index=0, with_warning=False) + + sorting = generate_sorting(num_units=5, durations=[1.0]) + sorting.register_recording(recording) + unit_id = sorting.unit_ids[0] + + spike_times_before = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True) + + sorting.shift_times(shift=5.0) + + spike_times_after = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True) + # Irregular spacing preserved, everything shifted by 5.0 + assert np.allclose(spike_times_after, spike_times_before + 5.0) + + # Recording is untouched + assert np.allclose(recording.get_times(segment_index=0), times) diff --git a/src/spikeinterface/extractors/neoextractors/neobaseextractor.py b/src/spikeinterface/extractors/neoextractors/neobaseextractor.py index 3d787f9519..fa91f0f4c2 100644 --- a/src/spikeinterface/extractors/neoextractors/neobaseextractor.py +++ b/src/spikeinterface/extractors/neoextractors/neobaseextractor.py @@ -618,11 +618,10 @@ def __init__( sampling_frequency, neo_returns_frames, ): - BaseSortingSegment.__init__(self) + BaseSortingSegment.__init__(self, t_start=t_start) self.neo_reader = neo_reader self.segment_index = segment_index self.block_index = block_index - self._t_start = t_start self._sampling_frequency = sampling_frequency self.neo_returns_frames = neo_returns_frames diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index b89999d088..b223b97398 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -1313,13 +1313,12 @@ def _fetch_properties(self, columns): class NwbSortingSegment(BaseSortingSegment): def __init__(self, spike_times_data, spike_times_index_data, sampling_frequency: float, t_start: float): - BaseSortingSegment.__init__(self) + BaseSortingSegment.__init__(self, t_start=t_start) self.spike_times_data = spike_times_data self.spike_times_index_data = spike_times_index_data self.spike_times_data = spike_times_data self.spike_times_index_data = spike_times_index_data self._sampling_frequency = sampling_frequency - self._t_start = t_start def get_unit_spike_train( self,