pearsonlab/trajectory-decoding

Trajectory Decoding

Implements the sequence autoencoder from Dai & Le (2015) — Semi-supervised Sequence Learning — applied to continuous trajectory/motion data.
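Dai & Le train an LSTM to read a token sequence into its final state and then predict the same tokens again with a softmax. A per-token softmax has no direct counterpart for continuous trajectories. Assuming the repo swaps it for a squared-error reconstruction loss (the README does not say which loss it uses), the objective for one trajectory with T timesteps and C channels (T = 100, C = 6 for BasicMotions) would be:

```latex
\begin{aligned}
z &= \mathrm{Enc}_\phi(x_1, \dots, x_T), \qquad x_t \in \mathbb{R}^{C},\\
\hat{x}_t &= \mathrm{Dec}_\theta\bigl(z,\; x_0, \dots, x_{t-1}\bigr), \qquad x_0 = \mathbf{0},\\
\mathcal{L}(\phi, \theta) &= \frac{1}{T\,C} \sum_{t=1}^{T} \bigl\lVert x_t - \hat{x}_t \bigr\rVert_2^2 .
\end{aligned}
```

Conditioning the decoder on the earlier true frames x_0, ..., x_{t-1} is teacher forcing, as in Dai & Le's decoder. The repo's decoder may be conditioned differently.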

Architecture: Mamba SSM encoder (pure PyTorch, no custom CUDA ops) compresses variable-length trajectories to a fixed latent vector; a symmetric Mamba decoder reconstructs the sequence. The encoder is trained unsupervised, then frozen while a linear or SVM classifier is trained on the embeddings.
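"Pure PyTorch, no custom CUDA ops" means Mamba's core selective state-space recurrence runs as ordinary tensor operations rather than through a fused CUDA scan kernel. The sketch below shows only that recurrence. It is written in NumPy for brevity; each operation has a direct torch equivalent. It leaves out Mamba's input/output projections, causal convolution and gating. All parameter names (`W_delta`, `W_B`, `W_C`, ...) are hypothetical. Using the final state h_T as the fixed-size latent is one plausible choice, not necessarily what the repo does.

```python
import numpy as np


def softplus(v):
    # Numerically stable log(1 + exp(v)); always > 0.
    return np.logaddexp(0.0, v)


def selective_scan(x, A, W_delta, b_delta, W_B, W_C):
    """Simplified Mamba-style selective scan, run as a sequential loop.

    x:        (T, D) input sequence
    A:        (D, N) state matrix with negative entries (diagonal per channel)
    W_delta:  (D, D), b_delta: (D,)  -> input-dependent step size delta_t
    W_B, W_C: (D, N)                 -> input-dependent B_t and C_t
    Returns outputs y of shape (T, D) and the final state h_T of shape (D, N).
    """
    T, D = x.shape
    h = np.zeros((D, A.shape[1]))
    ys = np.empty((T, D))
    for t in range(T):
        xt = x[t]
        delta = softplus(xt @ W_delta + b_delta)            # (D,) positive step sizes
        B_t = xt @ W_B                                      # (N,) selective input matrix
        C_t = xt @ W_C                                      # (N,) selective readout
        A_bar = np.exp(delta[:, None] * A)                  # (D, N) zero-order-hold discretization of A
        Bx = (delta[:, None] * B_t[None, :]) * xt[:, None]  # simplified B discretization: delta * B * x
        h = A_bar * h + Bx                                  # state update
        ys[t] = h @ C_t                                     # (D,) output at step t
    return ys, h


rng = np.random.default_rng(0)
T, D, N = 100, 6, 16                        # 100 timesteps, 6 IMU channels, 16 states per channel
x = rng.standard_normal((T, D))
A = -np.exp(rng.standard_normal((D, N)))    # negative A keeps 0 < A_bar < 1 (stable)
params = [0.1 * rng.standard_normal(s) for s in [(D, D), (D,), (D, N), (D, N)]]
y, h_T = selective_scan(x, A, *params)
latent = h_T.reshape(-1)                    # fixed size (D * N = 96) whatever T is
print(y.shape, latent.shape)                # (100, 6) (96,)
```

The latent's size depends only on D and N, so trajectories of different lengths map to vectors of the same size. A sequential loop like this one gives up the speed of Mamba's parallel scan kernel in exchange for running unchanged on CPU, CUDA and MPS.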

Dataset: UEA BasicMotions — 4 action classes (standing, walking, running, badminton), 6-channel IMU (accel + gyro), 100 timesteps, 40 train + 40 test samples.

Device selection: CUDA is preferred when available, then MPS (Apple Silicon), then CPU. Pass --device cpu/cuda/mps to override.
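Below is a minimal sketch of that fallback order with a `--device` override. `choose_device`, `detect_device` and `build_parser` are hypothetical names, not the repo's API. Availability comes from `torch.cuda.is_available()` and `torch.backends.mps.is_available()`. Raising an error when an unavailable device is requested is a design choice of this sketch; the repo may handle that case differently.

```python
import argparse

PREFERENCE = ("cuda", "mps", "cpu")  # order of preference when no override is given


def choose_device(override, cuda_ok, mps_ok):
    """Return a device string: the override if given, else CUDA > MPS > CPU."""
    available = {"cuda": cuda_ok, "mps": mps_ok, "cpu": True}
    if override is not None:
        if not available[override]:
            raise ValueError(f"--device {override} requested but not available")
        return override
    return next(d for d in PREFERENCE if available[d])


def detect_device(override=None):
    import torch  # imported lazily so choose_device can be used without torch installed

    return choose_device(
        override,
        cuda_ok=torch.cuda.is_available(),
        mps_ok=torch.backends.mps.is_available(),
    )


def build_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", choices=PREFERENCE, default=None)
    return parser
```

A script would then call `detect_device(build_parser().parse_args().device)` and pass the result to `torch.device(...)`.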

Setup

uv sync

Usage

End-to-end experiment (downloads data, trains, evaluates, plots):

uv run python scripts/run_experiment.py

Train only the autoencoder:

uv run python -m src.train_ae --epochs 100 --lr 1e-3
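`--epochs` and `--lr` presumably control an optimizer loop around a reconstruction loss. If the decoder is autoregressive and trained with teacher forcing, as in Dai & Le, each batch is split into decoder inputs (the previous true frame, with a zero frame at the start) and targets (the current frame). The helpers below (`teacher_forcing_pairs`, `reconstruction_mse`) are hypothetical and written in NumPy for clarity. If the repo's decoder instead unrolls from the latent alone, the shifting step does not apply.

```python
import numpy as np


def teacher_forcing_pairs(x):
    """Build decoder inputs and targets for a sequence autoencoder.

    x: (batch, T, C) trajectories.
    Decoder input at step t is the true frame x_{t-1} (a zero frame at t = 1);
    the target is x_t.
    """
    start = np.zeros_like(x[:, :1, :])                      # (batch, 1, C) zero "start" frame
    dec_in = np.concatenate([start, x[:, :-1, :]], axis=1)  # [0, x_1, ..., x_{T-1}]
    target = x                                              # [x_1, ..., x_T]
    return dec_in, target


def reconstruction_mse(x_hat, target):
    # Mean squared error over batch, time and channels.
    return float(np.mean((x_hat - target) ** 2))


x = np.arange(2 * 4 * 3, dtype=float).reshape(2, 4, 3)  # 2 sequences, T = 4, C = 3
dec_in, target = teacher_forcing_pairs(x)
print(dec_in[0, :, 0])                                  # [0. 0. 3. 6.]
```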

Train classifiers (after the autoencoder has been trained):

uv run python -m src.train_clf --mode svm
uv run python -m src.train_clf --mode linear --epochs 50
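The encoder is frozen, so each trajectory's embedding can be computed once and the classifier fit on a fixed matrix. `--mode linear` presumably fits something like multinomial logistic regression on that matrix. The sketch below does this in NumPy with full-batch gradient descent. `train_linear_probe` and `predict` are hypothetical names, and the repo likely uses PyTorch with its own optimizer settings. Synthetic 4-class clusters (10 samples per class, matching the 40-sample train split) stand in for real embeddings.

```python
import numpy as np


def train_linear_probe(Z, y, n_classes, lr=0.5, epochs=300):
    """Multinomial logistic regression on frozen embeddings Z of shape (n, d)."""
    n, d = Z.shape
    W = np.zeros((d, n_classes))
    b = np.zeros(n_classes)
    Y = np.eye(n_classes)[y]                         # one-hot targets (n, k)
    for _ in range(epochs):
        logits = Z @ W + b
        logits -= logits.max(axis=1, keepdims=True)  # subtract row max for numerical stability
        P = np.exp(logits)
        P /= P.sum(axis=1, keepdims=True)            # softmax probabilities
        G = (P - Y) / n                              # gradient of mean cross-entropy w.r.t. logits
        W -= lr * (Z.T @ G)
        b -= lr * G.sum(axis=0)
    return W, b


def predict(Z, W, b):
    return np.argmax(Z @ W + b, axis=1)


rng = np.random.default_rng(0)
k, d = 4, 8                                   # 4 classes, 8-dim stand-in embeddings
centers = 3.0 * np.eye(k, d)                  # one well-separated center per class
y_tr = np.repeat(np.arange(k), 10)            # 40 training labels, 10 per class
Z_tr = centers[y_tr] + 0.3 * rng.standard_normal((40, d))
W, b = train_linear_probe(Z_tr, y_tr, k)
acc = np.mean(predict(Z_tr, W, b) == y_tr)
print(f"train accuracy: {acc:.2f}")
```

`--mode svm` would fit an SVM on the same embedding matrix instead. In both modes the encoder weights stay fixed.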

Results

Expected SVM test accuracy on BasicMotions: ~82% after 100 epochs of autoencoder pretraining.
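With only 40 test samples, accuracy moves in steps of 2.5 percentage points. A single run reported as "~82%" therefore most likely means 33 of 40 correct. One extra or one fewer misclassification shifts the figure to 85% or 80%, so differences of a few points between runs are expected. The arithmetic:

```python
n_test = 40                       # BasicMotions test split size
points_per_sample = 100 / n_test  # percentage points per test sample
for correct in (32, 33, 34):
    print(f"{correct}/{n_test} correct -> {100 * correct / n_test:.1f}%")
# 32/40 correct -> 80.0%
# 33/40 correct -> 82.5%
# 34/40 correct -> 85.0%
```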

Updating the PyTorch version

The pyproject.toml pins torch>=2.2,<2.3 because that is the only version range with pre-built wheels for Intel macOS + Python 3.11. On other platforms you can relax or remove this constraint:

Linux / CUDA machines — remove the upper bound and optionally install the CUDA variant directly:

# edit pyproject.toml: change "torch>=2.2,<2.3" to "torch>=2.2"
uv sync
# or install a specific CUDA build:
uv pip install torch --index-url https://download.pytorch.org/whl/cu121

Apple Silicon (arm64) — newer torch versions have native arm64 wheels; remove the upper bound:

# edit pyproject.toml: change "torch>=2.2,<2.3" to "torch>=2.2"
uv sync

After editing pyproject.toml, run uv sync to resolve and install the updated dependency. You may also want to relax requires-python to >=3.11 to allow Python 3.12/3.13 on non-Intel platforms.
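To see which torch constraint the guidance above implies for the current interpreter, the rule can be written as a small check. `recommended_torch_spec` is a hypothetical helper, not part of the repo. It uses only the standard `platform` module. `platform.machine()` reports the interpreter's architecture, so an x86_64 Python running under Rosetta on Apple Silicon is treated as Intel macOS, which matches the wheels it can install.

```python
import platform


def recommended_torch_spec(system=None, machine=None):
    """Return the torch requirement string suggested by this README for a platform."""
    system = system or platform.system()     # e.g. "Darwin", "Linux"
    machine = machine or platform.machine()  # e.g. "x86_64", "arm64"
    if system == "Darwin" and machine == "x86_64":
        # Intel macOS: keep the pinned range that has pre-built wheels here
        return "torch>=2.2,<2.3"
    # Linux/CUDA, Apple Silicon and other platforms: the upper bound can be dropped
    return "torch>=2.2"


print(recommended_torch_spec())
```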

About

Sequence autoencoder (Dai & Le 2015) with Mamba SSM for trajectory/motion classification
