Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 73 additions & 13 deletions src/spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,18 @@ def get_unit_spike_train_in_seconds(

# Use the native spiking times if available
# Some instances might implement a method themselves to access spike times directly without having to convert
# (e.g. NWB extractors)
# (e.g. NWB extractors). The native times already include the extractor's `_native_t_start`,
# so we apply only the shift (`_t_start - _native_t_start`) on top.
if hasattr(segment, "get_unit_spike_train_in_seconds"):
return segment.get_unit_spike_train_in_seconds(unit_id=unit_id, start_time=start_time, end_time=end_time)
spike_times = segment.get_unit_spike_train_in_seconds(
unit_id=unit_id, start_time=start_time, end_time=end_time
)
t_start = segment._t_start if segment._t_start is not None else 0
native_t_start = segment._native_t_start if segment._native_t_start is not None else 0
shift = t_start - native_t_start
if shift != 0:
spike_times = spike_times + shift
return spike_times

# If no recording attached and all back to frame-based conversion
# Get spike train in frames and convert to times using traditional method
Expand Down Expand Up @@ -330,8 +339,12 @@ def register_recording(self, recording, check_spike_frames: bool = True):
# Copy the recording's start times into the sorting segments. This way,
# the sorting preserves the start time even if the recording is later
# detached (e.g. analyzer saved and reloaded without the recording).
# Also update `_native_t_start` so any subsequent `shift_times` call measures
# its delta from the recording's start time (not the extractor's original value).
for segment_index, segment in enumerate(self.segments):
segment._t_start = recording.get_start_time(segment_index=segment_index)
start_time = recording.get_start_time(segment_index=segment_index)
segment._t_start = start_time
segment._native_t_start = start_time

@property
def sorting_info(self):
Expand Down Expand Up @@ -374,11 +387,38 @@ def get_start_time(self, segment_index: int | None = None) -> float:
segment = self.segments[segment_index]
return segment._t_start if segment._t_start is not None else 0.0

def shift_times(self, shift: int | float, segment_index: int | None = None) -> None:
"""
Shift all times by a scalar value.

This modifies the sorting's own time offset without touching the registered
recording. When a recording is registered, the shift is applied on top of
the recording's time basis when resolving timestamps.

Parameters
----------
shift : int | float
The shift to apply. If positive, times will be increased by `shift`.
If negative, times will be decreased.
segment_index : int | None
The segment on which to shift the times.
If `None`, all segments will be shifted.
"""
if segment_index is None:
segments_to_shift = range(self.get_num_segments())
else:
segments_to_shift = (segment_index,)

for segment_index in segments_to_shift:
segment = self.segments[segment_index]
segment._t_start = (segment._t_start if segment._t_start is not None else 0) + shift

def get_end_time(self, segment_index: int | None = None) -> float:
"""Get the end time of the sorting segment.

If a recording is registered, returns the recording's end time.
Otherwise returns the time of the last spike in the segment.
If a recording is registered, returns the recording's end time (plus any
shift applied via `shift_times`). Otherwise returns the time of the last
spike in the segment.

Parameters
----------
Expand All @@ -392,7 +432,10 @@ def get_end_time(self, segment_index: int | None = None) -> float:
"""
segment_index = self._check_segment_index(segment_index)
if self.has_recording():
return self._recording.get_end_time(segment_index=segment_index)
segment = self.segments[segment_index]
t_start = segment._t_start if segment._t_start is not None else 0
shift = t_start - self._recording.get_start_time(segment_index=segment_index)
return self._recording.get_end_time(segment_index=segment_index) + shift
else:
last_spike_frame = self.get_last_spike_frame(segment_index=segment_index)
return self.sample_index_to_time(last_spike_frame, segment_index=segment_index)
Expand Down Expand Up @@ -425,11 +468,19 @@ def get_times(self, segment_index=None):
* if the segment has a time_vector, then it is returned
* if not, a time_vector is constructed on the fly with sampling frequency

Any shift applied via `shift_times` is added to the returned times.

If there is no registered recording it returns None
"""
segment_index = self._check_segment_index(segment_index)
if self.has_recording():
return self._recording.get_times(segment_index=segment_index)
times = self._recording.get_times(segment_index=segment_index)
segment = self.segments[segment_index]
t_start = segment._t_start if segment._t_start is not None else 0
shift = t_start - self._recording.get_start_time(segment_index=segment_index)
if shift != 0:
times = times + shift
return times
else:
return None

Expand Down Expand Up @@ -771,11 +822,13 @@ def time_to_sample_index(self, time, segment_index=0):
"""
Transform time in seconds into sample index
"""
segment = self.segments[segment_index]
t_start = segment._t_start if segment._t_start is not None else 0
if self.has_recording():
sample_index = self._recording.time_to_sample_index(time, segment_index=segment_index)
# Subtract the sorting's shift (relative to the recording's start) before delegating
shift = t_start - self._recording.get_start_time(segment_index=segment_index)
sample_index = self._recording.time_to_sample_index(time - shift, segment_index=segment_index)
else:
segment = self.segments[segment_index]
t_start = segment._t_start if segment._t_start is not None else 0
sample_index = round((time - t_start) * self.get_sampling_frequency())

return sample_index
Expand All @@ -787,11 +840,13 @@ def sample_index_to_time(
Transform sample index into time in seconds
"""
segment_index = self._check_segment_index(segment_index)
segment = self.segments[segment_index]
t_start = segment._t_start if segment._t_start is not None else 0
if self.has_recording():
return self._recording.sample_index_to_time(sample_index, segment_index=segment_index)
# Add the sorting's shift (relative to the recording's start) after delegating
shift = t_start - self._recording.get_start_time(segment_index=segment_index)
return self._recording.sample_index_to_time(sample_index, segment_index=segment_index) + shift
else:
segment = self.segments[segment_index]
t_start = segment._t_start if segment._t_start is not None else 0
return (sample_index / self.get_sampling_frequency()) + t_start

def precompute_spike_trains(self):
Expand Down Expand Up @@ -1149,6 +1204,11 @@ class BaseSortingSegment(BaseSegment):

def __init__(self, t_start=None):
self._t_start = t_start
# Immutable reference to the start time as set by the extractor at init.
# Used to compute the user-applied shift as `_t_start - _native_t_start`,
# so `shift_times` can correctly propagate through extractors that return
# native absolute times (e.g. NWB) without double-counting the extractor's offset.
self._native_t_start = t_start
BaseSegment.__init__(self)

def get_unit_spike_train(
Expand Down
111 changes: 111 additions & 0 deletions src/spikeinterface/core/tests/test_time_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,52 @@ def test_get_start_time_with_t_start(self):
sorting.segments[0]._t_start = 100.0
assert sorting.get_start_time(segment_index=0) == 100.0

def test_shift_times(self):
sorting = generate_sorting(num_units=5, durations=[10])
unit_id = sorting.unit_ids[0]

spike_times_before = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True)

sorting.shift_times(shift=5.0)

assert sorting.get_start_time(segment_index=0) == 5.0
spike_times_after = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True)
assert np.allclose(spike_times_after, spike_times_before + 5.0)

def test_shift_times_all_segments(self):
sorting = generate_sorting(num_units=5, durations=[10, 15])
sorting.segments[0]._t_start = 1.0
sorting.segments[1]._t_start = 2.0

sorting.shift_times(shift=3.0)

assert sorting.get_start_time(segment_index=0) == 4.0
assert sorting.get_start_time(segment_index=1) == 5.0

def test_shift_times_single_segment(self):
sorting = generate_sorting(num_units=5, durations=[10, 15])
sorting.segments[0]._t_start = 1.0
sorting.segments[1]._t_start = 2.0

sorting.shift_times(shift=3.0, segment_index=1)

assert sorting.get_start_time(segment_index=0) == 1.0
assert sorting.get_start_time(segment_index=1) == 5.0

def test_shift_times_with_native_spike_times(self):
"""Shift must apply even when the segment provides native spike times (e.g. NWB extractors)."""
sorting = generate_sorting(num_units=5, durations=[10])
unit_id = sorting.unit_ids[0]
segment = sorting.segments[0]

# Simulate a segment that provides native spike times directly
original_times = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True).copy()
segment.get_unit_spike_train_in_seconds = lambda unit_id, start_time, end_time: original_times

sorting.shift_times(shift=5.0)
spike_times = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True)
assert np.allclose(spike_times, original_times + 5.0)


class TestSortingTimeWithRecording:
"""
Expand Down Expand Up @@ -503,3 +549,68 @@ def test_with_recording_shifted_start(self):
sorting.register_recording(recording)

assert sorting.get_start_time(segment_index=0) == 50.0

def test_shift_times(self):
recording = generate_recording(num_channels=4, durations=[10])
sorting = generate_sorting(num_units=5, durations=[10])
sorting.register_recording(recording)
unit_id = sorting.unit_ids[0]

rec_start_before = recording.get_start_time(segment_index=0)
rec_end_before = recording.get_end_time(segment_index=0)
spike_times_before = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True)

sorting.shift_times(shift=5.0)

# The recording should be untouched
assert recording.get_start_time(segment_index=0) == rec_start_before
assert recording.get_end_time(segment_index=0) == rec_end_before

# The sorting's times should be shifted
assert sorting.get_start_time(segment_index=0) == rec_start_before + 5.0
assert sorting.get_end_time(segment_index=0) == rec_end_before + 5.0
spike_times_after = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True)
assert np.allclose(spike_times_after, spike_times_before + 5.0)

def test_time_conversion_roundtrip_after_shift(self):
"""sample_index_to_time and time_to_sample_index must remain inverses after a shift."""
recording = generate_recording(num_channels=4, durations=[10])
sorting = generate_sorting(num_units=5, durations=[10])
sorting.register_recording(recording)

sorting.shift_times(shift=5.0)

# Frame 30000 is 1.0s in the recording. After a 5.0s shift, the sorting should report 6.0s.
time = sorting.sample_index_to_time(30000, segment_index=0)
assert time == recording.sample_index_to_time(30000, segment_index=0) + 5.0

# The inverse: 6.0s in the sorting should map back to frame 30000.
frame = sorting.time_to_sample_index(time, segment_index=0)
assert frame == 30000

def test_shift_times_with_time_vector(self):
"""Shift on sorting composes with a recording that has an explicit time vector,
preserving the irregular spacing."""
recording = generate_recording(num_channels=4, durations=[1.0])
num_samples = recording.get_num_samples(segment_index=0)
# Irregular timestamps starting at 100.0
times = (
100.0
+ np.cumsum(np.random.RandomState(0).uniform(0.5, 1.5, num_samples)) / recording.get_sampling_frequency()
)
recording.set_times(times, segment_index=0, with_warning=False)

sorting = generate_sorting(num_units=5, durations=[1.0])
sorting.register_recording(recording)
unit_id = sorting.unit_ids[0]

spike_times_before = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True)

sorting.shift_times(shift=5.0)

spike_times_after = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True)
# Irregular spacing preserved, everything shifted by 5.0
assert np.allclose(spike_times_after, spike_times_before + 5.0)

# Recording is untouched
assert np.allclose(recording.get_times(segment_index=0), times)
Original file line number Diff line number Diff line change
Expand Up @@ -618,11 +618,10 @@ def __init__(
sampling_frequency,
neo_returns_frames,
):
BaseSortingSegment.__init__(self)
BaseSortingSegment.__init__(self, t_start=t_start)
self.neo_reader = neo_reader
self.segment_index = segment_index
self.block_index = block_index
self._t_start = t_start
self._sampling_frequency = sampling_frequency
self.neo_returns_frames = neo_returns_frames

Expand Down
3 changes: 1 addition & 2 deletions src/spikeinterface/extractors/nwbextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1313,13 +1313,12 @@ def _fetch_properties(self, columns):

class NwbSortingSegment(BaseSortingSegment):
def __init__(self, spike_times_data, spike_times_index_data, sampling_frequency: float, t_start: float):
BaseSortingSegment.__init__(self)
BaseSortingSegment.__init__(self, t_start=t_start)
self.spike_times_data = spike_times_data
self.spike_times_index_data = spike_times_index_data
self.spike_times_data = spike_times_data
self.spike_times_index_data = spike_times_index_data
self._sampling_frequency = sampling_frequency
self._t_start = t_start

def get_unit_spike_train(
self,
Expand Down
Loading