Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 43 additions & 12 deletions monai/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,31 @@
__all__ = ["Compose", "OneOf", "RandomOrder", "SomeOf", "execute_compose"]


def _inverse_one(
t: InvertibleTransform, data: Any, map_items: bool | int, unpack_items: bool, log_stats: bool | str
) -> Any:
"""Invert a single transform, delegating directly to nested ``Compose`` objects.

When ``t`` is a ``Compose`` instance its own ``inverse()`` is called so that
the child's ``map_items`` setting is respected. For all other invertible
transforms, ``apply_transform`` is used with ``lazy=False``.

Args:
t: The invertible transform to invert.
data: Data to be inverted.
map_items: Whether to map over list/tuple items (forwarded to
``apply_transform`` for non-``Compose`` transforms).
unpack_items: Whether to unpack data as parameters.
log_stats: Logger name or boolean for logging.

Returns:
The inverted data.
"""
if isinstance(t, Compose):
return t.inverse(data)
return apply_transform(t.inverse, data, map_items, unpack_items, lazy=False, log_stats=log_stats)


def execute_compose(
data: NdarrayOrTensor | Sequence[NdarrayOrTensor] | Mapping[Any, NdarrayOrTensor],
transforms: Sequence[Any],
Expand Down Expand Up @@ -315,20 +340,32 @@ def get_index_of_first(self, predicate):
return None

def flatten(self):
"""Return a Composition with a simple list of transforms, as opposed to any nested Compositions.
"""Return a Composition with a flattened list of transforms.

Nested ``Compose`` objects that share the same ``map_items`` setting as
the parent are inlined. Nested ``Compose`` objects with a *different*
``map_items`` value are kept as-is so their item-mapping behaviour is
preserved at runtime and during inversion.

e.g., `t1 = Compose([x, x, x, x, Compose([Compose([x, x]), x, x])]).flatten()`
will result in the equivalent of `t1 = Compose([x, x, x, x, x, x, x, x])`.

"""
new_transforms = []
for t in self.transforms:
if type(t) is Compose: # nopep8
if type(t) is Compose and t.map_items == self.map_items:
new_transforms += t.flatten().transforms
else:
new_transforms.append(t)

return Compose(new_transforms)
return Compose(
new_transforms,
map_items=self.map_items,
unpack_items=self.unpack_items,
log_stats=self.log_stats,
lazy=self._lazy,
overrides=self.overrides,
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

def __len__(self):
"""Return number of transformations."""
Expand Down Expand Up @@ -365,9 +402,7 @@ def inverse(self, data):
)
# loop backwards over transforms
for t in reversed(invertible_transforms):
data = apply_transform(
t.inverse, data, self.map_items, self.unpack_items, lazy=False, log_stats=self.log_stats
)
data = _inverse_one(t, data, self.map_items, self.unpack_items, self.log_stats)
return data

@staticmethod
Expand Down Expand Up @@ -622,9 +657,7 @@ def inverse(self, data):
# loop backwards over transforms
for o in reversed(applied_order):
if isinstance(self.transforms[o], InvertibleTransform):
data = apply_transform(
self.transforms[o].inverse, data, self.map_items, self.unpack_items, log_stats=self.log_stats
)
data = _inverse_one(self.transforms[o], data, self.map_items, self.unpack_items, self.log_stats)
return data


Expand Down Expand Up @@ -789,8 +822,6 @@ def inverse(self, data):
# loop backwards over transforms
for o in reversed(applied_order):
if isinstance(self.transforms[o], InvertibleTransform):
data = apply_transform(
self.transforms[o].inverse, data, self.map_items, self.unpack_items, log_stats=self.log_stats
)
data = _inverse_one(self.transforms[o], data, self.map_items, self.unpack_items, self.log_stats)

return data
11 changes: 7 additions & 4 deletions monai/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,13 @@ def apply_transform(
try:
map_items_ = int(map_items) if isinstance(map_items, bool) else map_items
if isinstance(data, (list, tuple)) and map_items_ > 0 and not isinstance(transform, ReduceTrait):
return [
apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides)
for item in data
]
# If the transform is a Compose with its own map_items, let it handle list/tuple
# expansion internally so that nested Compose map_items settings are respected.
if not isinstance(transform, transforms.compose.Compose):
return [
apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides)
for item in data
]
return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats)
except Exception as e:
# if in debug mode, don't swallow exception so that the breakpoint
Expand Down
159 changes: 159 additions & 0 deletions tests/transforms/compose/test_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,165 @@ def test_compose_execute_equivalence_with_flags(self, flags, data, pipeline):
self.assertEqual(expected, actual)


class TestNestedComposeMapItems(unittest.TestCase):
"""Tests for nested Compose respecting child map_items (issues #7932, #7565)."""

def test_child_map_items_false_receives_list(self):
"""Parent map_items=True, child map_items=False: child receives list as-is."""

def split(x):
return [x + 1, x + 2]

def sum_list(items):
return sum(items)

# The child Compose(map_items=False) should receive the list from split()
# and pass it as-is to sum_list, rather than the parent expanding the list.
pipeline = mt.Compose([split, mt.Compose([sum_list], map_items=False)])
result = pipeline(10)
self.assertEqual(result, 23) # (10+1) + (10+2) = 23

def test_inverse_respects_child_map_items(self):
"""Inverse path should delegate to child Compose.inverse directly."""
pipeline = mt.Compose([mt.Flip(0), mt.Compose([mt.Flip(1)], map_items=False)])
data = torch.randn(1, 4, 4)
result = pipeline(data)
restored = pipeline.inverse(result)
torch.testing.assert_close(data, restored)

def test_parent_no_map_child_map(self):
"""Parent map_items=False, child map_items=True: child maps over items."""

def double(x):
return x * 2

# Parent treats the list as a single value; child maps double() over each item.
pipeline = mt.Compose([mt.Compose([double], map_items=True)], map_items=False)
result = pipeline([1, 2, 3])
self.assertEqual(result, [2, 4, 6])

def test_flatten_preserves_different_map_items(self):
"""flatten() should not merge a child Compose with different map_items."""

def noop(x):
return x

parent = mt.Compose([noop, mt.Compose([noop, noop], map_items=False), noop])
flat = parent.flatten()
# The inner Compose(map_items=False) should NOT be flattened
self.assertEqual(len(flat.transforms), 3)
self.assertIsInstance(flat.transforms[1], mt.Compose)

def test_multiple_children_with_mixed_map_items(self):
"""Multiple internal Composes with different map_items should be handled correctly."""

def add_one(items):
if isinstance(items, list):
return [x + 1 for x in items]
return items + 1

def multiply_two(items):
if isinstance(items, list):
return [x * 2 for x in items]
return items * 2

# Parent with map_items=False processes the entire input as one unit
# Child 1 (map_items=True) will map over each item in what it receives
# Child 2 (map_items=False) will process the entire thing
pipeline = mt.Compose(
[mt.Compose([add_one], map_items=True), mt.Compose([multiply_two], map_items=False)], map_items=False
)

# Input [1, 2, 3]
# First child with map_items=True maps add_one over [1,2,3]: [2, 3, 4]
# Second child with map_items=False receives [2,3,4] and applies multiply_two: [4, 6, 8]
result = pipeline([1, 2, 3])
self.assertEqual(result, [4, 6, 8])

def test_flatten_with_multiple_children_preserves_both(self):
"""flatten() should preserve child with different map_items but flatten child with same."""

def noop(x):
return x

parent = mt.Compose(
[
noop,
mt.Compose([noop, noop], map_items=True), # Same as parent, will be flattened
mt.Compose([noop, noop], map_items=False), # Different, will be preserved
noop,
]
)
flat = parent.flatten()
# First nested Compose(map_items=True) will be flattened into parent
# Second nested Compose(map_items=False) will be preserved
# Result: noop + noop + noop + Compose([noop, noop]) + noop = 5 transforms
self.assertEqual(len(flat.transforms), 5)
# Check that the preserved one is at the correct position
self.assertIsInstance(flat.transforms[3], mt.Compose)
self.assertEqual(flat.transforms[3].map_items, False)

def test_three_level_nesting_respects_different_map_items(self):
"""Three-level nesting with different map_items at each level."""

def add_one(x):
return x + 1

# Level 1 (outermost): map_items=True (default)
# Level 2: map_items=False
# Level 3: map_items=True (same as level 2, so will be flattened into level 2)
innermost = mt.Compose([add_one], map_items=True)
middle = mt.Compose([add_one, innermost], map_items=False)
outer = mt.Compose([middle])

# Test with a simple value
# outer has map_items=True (default), middle has map_items=False
# So middle should be preserved and receive the input as-is
result = outer(5)
# outer(5) -> maps to middle -> middle(5) with map_items=False
# middle(5) -> add_one(5) = 6, then innermost(6) with map_items=True
# innermost(6) -> add_one(6) = 7
self.assertEqual(result, 7)

def test_inverse_with_multiple_children_different_map_items(self):
"""Inverse should work correctly with multiple children having different map_items."""
pipeline = mt.Compose(
[mt.Flip(0), mt.Compose([mt.Flip(1)], map_items=False), mt.Compose([mt.Flip(0)], map_items=True)]
)
data = torch.randn(2, 4, 4)
result = pipeline(data)
restored = pipeline.inverse(result)
torch.testing.assert_close(data, restored)

def test_flatten_with_mixed_same_and_different_map_items(self):
"""flatten() should merge children with same map_items but preserve those with different."""

def noop(x):
return x

# Parent has map_items=True (default)
# Child 1 has map_items=True (same as parent) -> should be flattened
# Child 2 has map_items=False (different from parent) -> should NOT be flattened
parent = mt.Compose(
[
noop,
mt.Compose([noop, noop], map_items=True), # Same as parent, will be flattened
mt.Compose([noop, noop], map_items=False), # Different from parent, will be preserved
noop,
]
)
flat = parent.flatten()
# After flatten:
# - noop (preserved)
# - 2 noops from first Compose (flattened because map_items=True matches parent)
# - Compose([noop, noop], map_items=False) (preserved because different)
# - noop (preserved)
# Total: 5 transforms
self.assertEqual(len(flat.transforms), 5)
self.assertIsInstance(flat.transforms[3], mt.Compose)
self.assertEqual(flat.transforms[3].map_items, False)


class TestComposeCallableInput(unittest.TestCase):

def test_value_error_when_not_sequence(self):
Expand Down
Loading