LRP for review #900
Conversation
This is just some refactoring to pull out the visualization components for reuse with other saliency methods; I don't know how we may want to formalize or otherwise evolve visualization going forward.
```python
if mode == "binary":
    # Binary logits have shape [B] or [B, 1].
    # target_class_idx=1 (positive) → use logit as-is.
    # target_class_idx=0 (negative) → negate so LRP attributes for the
    # "not-positive" direction (relevance conservation still holds for -logit).
    if target_class_idx is not None:
        idx = (
            target_class_idx
            if isinstance(target_class_idx, int)
            else int(target_class_idx.item())
        )
        if idx == 0:
            return -logits
    return logits
if mode in ("multiclass", "multilabel"):
    if target_class_idx is None:
        target_class_idx = torch.argmax(logits, dim=-1)
    elif not isinstance(target_class_idx, torch.Tensor):
        target_class_idx = torch.tensor(
            target_class_idx, device=logits.device
        )
    batch_size = logits.size(0)
    output_relevance = torch.zeros_like(logits)
    output_relevance[range(batch_size), target_class_idx] = logits[
        range(batch_size), target_class_idx
    ]
    return output_relevance
```
The target_class_index handling has been unified in #926, with the following semantics:
- binary always has target_class_index 0
- attribution is always towards the positive class; it is up to the evaluator to flip the sign for the negative-class contribution
- if not specified, both multi-class and multi-label use argmax
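A minimal sketch of those semantics (the function name `init_output_relevance` and its exact signature are assumptions for illustration, not the actual #926 code):

```python
import torch

def init_output_relevance(logits, mode, target_class_idx=None):
    # Hypothetical sketch of the unified #926 semantics described above.
    if mode == "binary":
        # Always attribute towards the positive class; the evaluator flips
        # the sign if it wants the negative-class contribution.
        return logits
    if mode in ("multiclass", "multilabel"):
        if target_class_idx is None:
            # Default: explain the argmax (predicted) class.
            target_class_idx = torch.argmax(logits, dim=-1)
        idx = torch.as_tensor(target_class_idx, device=logits.device)
        rows = torch.arange(logits.size(0), device=logits.device)
        out = torch.zeros_like(logits)
        out[rows, idx] = logits[rows, idx]
        return out
    raise ValueError(f"unsupported mode: {mode}")
```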
```python
ReLULRPHandler,
Conv2dLRPHandler,
MaxPool2dLRPHandler,
AvgPool2dLRPHandler,
AdaptiveAvgPool2dLRPHandler,
BatchNorm2dLRPHandler,
FlattenLRPHandler,
DropoutLRPHandler,
RNNLRPHandler,
```
Do we need special handling for Conv1d, Sigmoid, Tanh, Softmax in StageNet?
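For what it's worth, Softmax usually needs no handler if relevance is initialized from the pre-softmax logits, and a common LRP convention treats elementwise monotonic activations (Sigmoid, Tanh) as identity for relevance. A sketch of what the missing handlers could look like; the `propagate` interface here is an assumption about this repo's handler protocol, not its actual API:

```python
import torch
import torch.nn.functional as F

class TanhLRPHandler:
    # Hypothetical handler: under a common LRP convention, elementwise
    # monotonic activations (Tanh, Sigmoid) pass relevance through unchanged.
    def propagate(self, module, x, r_out):
        return r_out

class Conv1dLRPHandler:
    # Hypothetical epsilon-rule handler for Conv1d, mirroring the 2d case.
    def propagate(self, module, x, r_out, eps=1e-6):
        z = F.conv1d(x, module.weight, None, module.stride,
                     module.padding, module.dilation, module.groups)
        z = z + eps * ((z >= 0).to(z.dtype) * 2 - 1)  # signed stabilizer
        s = r_out / z
        c = F.conv_transpose1d(s, module.weight, None, module.stride,
                               module.padding, groups=module.groups,
                               dilation=module.dilation)
        # With stride > 1 an output_padding term may be needed to recover
        # x's exact length; cropped here for the common stride-1 case.
        return x * c[..., : x.shape[-1]]
```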
```python
def _propagate_through_residual_block(
    self,
    block_name: str,
    block_layer_names_rev: list,
    r_in: torch.Tensor,
    ctx: _LRPHookContext,
) -> torch.Tensor:
```
this seems to ignore the rule args for LRP and hard-codes the epsilon rule, regardless of whether epsilon or alpha-beta was configured?
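One way to honor the configured rule at the residual split; the `rule` parameter and helper name below are assumptions for illustration, not this repo's API:

```python
import torch

def _split_residual_relevance(x_main, x_skip, r_out, rule="epsilon",
                              eps=1e-6, alpha=2.0, beta=1.0):
    # Hypothetical helper: split relevance at a residual addition
    # y = x_main + x_skip proportionally to each branch's contribution,
    # dispatching on the configured rule instead of hard-coding epsilon.
    if rule == "epsilon":
        z = x_main + x_skip
        z = z + eps * ((z >= 0).to(z.dtype) * 2 - 1)  # signed stabilizer
        s = r_out / z
        return x_main * s, x_skip * s
    if rule == "alphabeta":
        z_pos = x_main.clamp(min=0) + x_skip.clamp(min=0)
        z_neg = x_main.clamp(max=0) + x_skip.clamp(max=0)
        s_pos = alpha * r_out / z_pos.clamp(min=1e-9)
        s_neg = -beta * r_out / z_neg.clamp(max=-1e-9)
        return (x_main.clamp(min=0) * s_pos + x_main.clamp(max=0) * s_neg,
                x_skip.clamp(min=0) * s_pos + x_skip.clamp(max=0) * s_neg)
    raise ValueError(f"unknown LRP rule: {rule}")
```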
```python
c_neg = crop_spatial(F.conv_transpose2d(s, W_neg, None, **trans_kw), x.shape)
return x * (alpha * c_pos + beta * c_neg)
```
The linear handler uses x_p * (alpha * c_pos - beta * c_neg). Is it expected for the conv and linear forms to be different?
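If I'm reading it right, the two forms can coincide when beta carries opposite sign conventions (beta ≤ 0 in `alpha * c_pos + beta * c_neg` versus beta ≥ 0 in `alpha * c_pos - beta * c_neg`), but the `x` versus `x_p` factor is a real difference worth confirming. A standalone conservation check (not this repo's code) that either handler should pass once the convention is fixed:

```python
import torch

torch.manual_seed(0)
# Alpha-beta rule for a linear layer:
# R_j = sum_i (alpha * z_ij+/Z_i+ - beta * z_ij-/Z_i-) * R_i,
# which conserves total relevance when alpha - beta = 1.
alpha, beta = 2.0, 1.0
x = torch.randn(8)
W = torch.randn(4, 8)
R_out = W @ x                        # relevance initialized from the output

zij = W * x                          # per-connection contributions z_ij
z_pos, z_neg = zij.clamp(min=0), zij.clamp(max=0)
frac = (alpha * z_pos / z_pos.sum(1, keepdim=True).clamp(min=1e-9)
        - beta * z_neg / z_neg.sum(1, keepdim=True).clamp(max=-1e-9))
R_in = frac.t() @ R_out
# Sums should match up to the stabilizer terms (conservation).
print(R_in.sum().item(), R_out.sum().item())
```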
Looking to add alpha-beta and epsilon LRP saliency options, run across a few different models (both the MIMIC-IV StageNet and the chest x-ray CNN).
The original implementation is based off of https://arxiv.org/abs/1604.00825 and is fairly loose, so I'm expecting some comments.
I may prune down the example options a bit, as this ended up somewhat bloated, and I'll need to add some more documentation.
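For reviewers, the paper's epsilon rule for a linear layer as a self-contained sketch (standalone names, not the handler API in this PR):

```python
import torch

def lrp_epsilon_linear(x, W, b, R_out, eps=1e-6):
    # Epsilon rule (arXiv:1604.00825 style) for z = x @ W.T + b:
    # relevance is redistributed in proportion to each input's
    # contribution x_j * W_ij, with a small signed stabilizer on z.
    z = x @ W.t() + b
    z = z + eps * ((z >= 0).to(z.dtype) * 2 - 1)
    s = R_out / z            # relevance per unit of activation
    c = s @ W                # project back onto the inputs
    return x * c
```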