
Lrp for review #900

Open

Nimanui wants to merge 14 commits into sunlabuiuc:master from Nimanui:lrp-for-review

Conversation

Nimanui (Contributor) commented on Mar 20, 2026

Looking to add some alpha, beta, and epsilon LRP saliency options, run across a few different models (both the MIMIC-IV StageNet and the chest X-ray CNN).
The original implementation is based on https://arxiv.org/abs/1604.00825 and is very vibey, so I'm expecting some comments.
I may prune down the example options a bit since this ended up somewhat bloated, and I'll need to add some more documentation.
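
For readers unfamiliar with the rules being added, here is a minimal epsilon-rule sketch for a single linear layer (illustrative only, not this PR's implementation):

```python
import torch

def lrp_linear_epsilon(x, weight, r_out, eps=1e-6):
    # Forward contribution z = x @ W^T, pushed away from zero by
    # eps * sign(z) so the division below cannot blow up.
    z = x @ weight.t()
    z = z + eps * torch.where(z >= 0, torch.ones_like(z), -torch.ones_like(z))
    s = r_out / z        # relevance per output unit, normalized by z
    c = s @ weight       # redistribute back through the weights
    return x * c         # input relevance; sums are conserved up to eps
```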

Nimanui (Contributor, Author) commented:

This is just some refactoring to pull out the visualization components for use with other saliency methods. I don't know how we may want to formalize visualization going forward or otherwise evolve that.

Comment on lines +405 to +431
if mode == "binary":
# Binary logits have shape [B] or [B, 1].
# target_class_idx=1 (positive) → use logit as-is.
# target_class_idx=0 (negative) → negate so LRP attributes for the
# "not-positive" direction (relevance conservation still holds for -logit).
if target_class_idx is not None:
idx = (
target_class_idx
if isinstance(target_class_idx, int)
else int(target_class_idx.item())
)
if idx == 0:
return -logits
return logits
if mode in ("multiclass", "multilabel"):
if target_class_idx is None:
target_class_idx = torch.argmax(logits, dim=-1)
elif not isinstance(target_class_idx, torch.Tensor):
target_class_idx = torch.tensor(
target_class_idx, device=logits.device
)
batch_size = logits.size(0)
output_relevance = torch.zeros_like(logits)
output_relevance[range(batch_size), target_class_idx] = logits[
range(batch_size), target_class_idx
]
return output_relevance
Collaborator commented:

The target_class_index handling has been unified in #926, with the following semantics (sketched below):

  1. binary will always have target_class_index 0
  2. the attribution is always towards the positive class, and it is up to the evaluator to flip the sign for the negative-class contribution
  3. if not specified, both multi-class and multi-label use argmax
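
A minimal sketch of those semantics (function and argument names are hypothetical, not the actual #926 code):

```python
import torch

def select_output_relevance(logits, mode, target_class_idx=None):
    # Binary: target_class_index is always 0 and attribution is always
    # towards the positive class; the evaluator flips signs if needed.
    if mode == "binary":
        return logits
    # Multi-class / multi-label: default to the per-sample argmax class.
    if target_class_idx is None:
        target_class_idx = torch.argmax(logits, dim=-1)
    rows = torch.arange(logits.size(0), device=logits.device)
    out = torch.zeros_like(logits)
    out[rows, target_class_idx] = logits[rows, target_class_idx]
    return out
```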

Comment on lines +26 to +34
ReLULRPHandler,
Conv2dLRPHandler,
MaxPool2dLRPHandler,
AvgPool2dLRPHandler,
AdaptiveAvgPool2dLRPHandler,
BatchNorm2dLRPHandler,
FlattenLRPHandler,
DropoutLRPHandler,
RNNLRPHandler,
Collaborator commented:

Do we need special handling for Conv1d, Sigmoid, Tanh, Softmax in StageNet?
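
For reference, a common convention (an assumption here, not this PR's code) treats elementwise activations like Sigmoid and Tanh as identity for relevance, skips Softmax by attributing from the logits, and gives Conv1d the same epsilon rule as Conv2d via a 1-D transposed convolution:

```python
import torch
import torch.nn.functional as F

def lrp_elementwise_identity(r_out):
    # Sigmoid/Tanh/ReLU: relevance passes through unchanged (R_in = R_out).
    return r_out

def lrp_conv1d_epsilon(x, weight, r_out, stride=1, padding=0, eps=1e-6):
    # Same epsilon rule as Conv2d, using conv_transpose1d to redistribute.
    # (For stride > 1 a crop may be needed, as the 2-D handler's
    # crop_spatial does; with stride=1 the shapes already match.)
    z = F.conv1d(x, weight, None, stride=stride, padding=padding)
    z = z + eps * torch.where(z >= 0, torch.ones_like(z), -torch.ones_like(z))
    s = r_out / z
    c = F.conv_transpose1d(s, weight, None, stride=stride, padding=padding)
    return x * c
```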

Comment on lines +446 to +452
def _propagate_through_residual_block(
self,
block_name: str,
block_layer_names_rev: list,
r_in: torch.Tensor,
ctx: _LRPHookContext,
) -> torch.Tensor:
Collaborator commented:

This seems to ignore the rule args for LRP and hard-codes the epsilon rule, regardless of whether the epsilon or alpha-beta rule was requested?
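
If so, one possible shape for the fix (a sketch only; `self.rule`, `_epsilon_rule`, and `_alpha_beta_rule` are hypothetical names) is to dispatch on the configured rule inside the block:

```python
def _apply_rule(self, layer, x, r_in):
    # Dispatch on the configured LRP rule instead of hard-coding epsilon.
    if self.rule == "epsilon":
        return self._epsilon_rule(layer, x, r_in)
    if self.rule == "alpha_beta":
        return self._alpha_beta_rule(layer, x, r_in, self.alpha, self.beta)
    raise ValueError(f"unknown LRP rule: {self.rule!r}")
```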

Comment thread on pyhealth/interpret/methods/lrp_base.py (Outdated)
Comment on lines +395 to +396
c_neg = crop_spatial(F.conv_transpose2d(s, W_neg, None, **trans_kw), x.shape)
return x * (alpha * c_pos + beta * c_neg)
Collaborator commented:

The linear uses x_p * (alpha * c_pos - beta * c_neg). Is it expected for them to be different?
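
For reference, there are two common sign conventions for the α-β rule. With β ≥ 0 and α − β = 1 (as in Montavon et al.'s overview), the two terms combine with a minus sign:

```latex
R_j = \sum_k \left(
  \alpha \frac{(a_j w_{jk})^{+}}{\sum_{j'} (a_{j'} w_{j'k})^{+}}
  - \beta \frac{(a_j w_{jk})^{-}}{\sum_{j'} (a_{j'} w_{j'k})^{-}}
\right) R_k
```

Other write-ups absorb the sign into β (α + β = 1 with β ≤ 0), giving a plus sign instead, so whether `alpha * c_pos + beta * c_neg` is correct depends on which convention `beta` follows; either way, the conv and linear handlers should agree.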

