GA2M Model for In-Hospital Mortality Prediction (Hegselmann et al., MLHC 2020)#1124

Open
smritiSrinivasan109 wants to merge 6 commits into sunlabuiuc:master from smritiSrinivasan109:master

smritiSrinivasan109 commented Apr 23, 2026

Contributors:

Type of contribution: Model (Option 2)

Original paper: Hegselmann, Stefan et al. "An Evaluation of the Doctor-Interpretability of Generalized Additive Models with Interactions." Machine Learning in Health Care (2020). https://proceedings.mlr.press/v126/hegselmann20a/hegselmann20a.pdf

Description:
Implements a PyTorch-native GA2M (Generalized Additive Model with Interactions) for in-hospital mortality prediction on MIMIC-IV ICU data. The model learns a risk embedding for each bin of each feature. This is representationally equivalent to the gradient-boosted shape functions in the paper, which were fitted with the mltk toolkit, but here the parameters are optimised with SGD for compatibility with PyHealth's trainer. Training follows the paper's two-stage process: main effects first, then the top-K interaction pairs, selected by the variance of the learned risk scores.
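
The additive structure described above can be illustrated with a minimal, framework-free sketch. It is not the code in `pyhealth/models/ga2m.py`: the function name `ga2m_probability`, the table layout, and the toy numbers are all hypothetical, and plain lists stand in for the learned per-bin embeddings. It only shows how a prediction could be assembled from one lookup per feature plus one lookup per retained pair.

```python
import math


def ga2m_probability(bin_ids, bias, main_scores, pair_scores):
    """Return a predicted probability from additive per-bin risk scores.

    bin_ids:     one bin index per feature for a single patient
    main_scores: main_scores[i][b] is the score of bin b of feature i
    pair_scores: {(i, j): table}, where table[b_i][b_j] is the
                 interaction score for features i and j
    (All names and shapes are illustrative assumptions.)
    """
    logit = bias
    # Main effects: one table lookup per feature.
    for i, b in enumerate(bin_ids):
        logit += main_scores[i][b]
    # Pairwise interactions: one 2-D table lookup per retained pair.
    for (i, j), table in pair_scores.items():
        logit += table[bin_ids[i]][bin_ids[j]]
    # Sigmoid maps the summed logit to a probability.
    return 1.0 / (1.0 + math.exp(-logit))


# Toy model: 2 features with 3 bins each and a single interaction pair.
main_scores = [[-0.5, 0.0, 0.5], [-1.0, 0.0, 1.0]]
pair_scores = {(0, 1): [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.5]]}

p_low = ga2m_probability([0, 0], 0.0, main_scores, pair_scores)
p_high = ga2m_probability([2, 2], 0.0, main_scores, pair_scores)
```

Because every term is a table lookup, each feature's contribution to the logit can be read directly from its score table, which is the interpretability property the paper evaluates.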

File Guide:

| File | Description |
| --- | --- |
| `pyhealth/models/ga2m.py` | GA2M model implementation |
| `pyhealth/datasets/mimic4_icu_mortality.py` | MIMIC-IV dataset loader (34 features, 48h ICU window) |
| `examples/mimic4_mortality_ga2m.py` | Full pipeline plus ablations (full GA2M vs. main effects only vs. mean features only vs. logistic regression baseline) |
| `tests/test_ga2m.py` | 34 tests covering instantiation, bin fitting, forward pass, gradient computation, and interpretability helpers |
| `docs/api/models/pyhealth.models.GA2M.rst` | API documentation |

Environment:
Run `pip install -e .` from the repo root, then `pip install pytest scikit-learn`.

To run:

python examples/mimic4_mortality_ga2m.py \
    --data_root path/to/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2

| Argument | Description | Default |
| --- | --- | --- |
| `--n_bins` | Number of bins per feature (paper uses 256) | 32 |
| `--top_k` | Interaction pairs to retain (paper uses 34) | 10 |
| `--epochs` | Training epochs per stage | 20 |
| `--lr` | Learning rate | 0.01 |
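
To give a feel for what `--n_bins` and `--top_k` control, here is a small standard-library sketch. It rests on two assumptions that the PR description does not confirm: that bins are cut at equally spaced quantiles, and that each candidate pair is ranked by the variance of its own table of learned scores. The helper names (`quantile_edges`, `to_bin`, `top_k_pairs`) and the toy values are hypothetical and do not come from `ga2m.py`.

```python
from bisect import bisect_right
from statistics import pvariance


def quantile_edges(values, n_bins):
    """Return n_bins - 1 interior cut points at equally spaced quantiles."""
    ordered = sorted(values)
    return [ordered[(k * len(ordered)) // n_bins] for k in range(1, n_bins)]


def to_bin(value, edges):
    """Map a raw value to a bin index in the range 0..len(edges)."""
    return bisect_right(edges, value)


def top_k_pairs(pair_tables, k):
    """Keep the k pairs whose score tables have the largest variance."""
    def spread(table):
        # Population variance over every cell of the 2-D score table.
        return pvariance([score for row in table for score in row])

    ranked = sorted(pair_tables, key=lambda pair: spread(pair_tables[pair]),
                    reverse=True)
    return ranked[:k]


# Toy feature values (illustrative only), split into 4 bins.
toy_values = [55, 60, 72, 80, 95, 110, 130, 150]
edges = quantile_edges(toy_values, n_bins=4)
bins = [to_bin(v, edges) for v in toy_values]

# Three candidate pairs with made-up 2x2 score tables; keep the top 2.
candidates = {
    (0, 1): [[0.0, 0.1], [0.1, 0.0]],
    (0, 2): [[-1.0, 1.0], [1.0, -1.0]],
    (1, 2): [[0.0, 0.0], [0.0, 0.0]],
}
kept = top_k_pairs(candidates, k=2)
```

Under these assumptions, a larger `--n_bins` gives finer-grained shape functions at the cost of more parameters per feature, and a larger `--top_k` retains more interaction tables in the second training stage.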

To run tests:

pytest tests/test_ga2m.py -v
