Skip to content
21 changes: 21 additions & 0 deletions monai/networks/nets/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@

import warnings
from collections.abc import Sequence
from typing import cast

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from monai.networks.blocks.convolutions import Convolution, ResidualUnit
from monai.networks.layers.factories import Act, Norm
Expand All @@ -24,6 +26,17 @@
__all__ = ["UNet", "Unet"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Add CheckpointUNet to __all__ exports.

CheckpointUNet is a public class but not exported in __all__.

Apply this diff:

-__all__ = ["UNet", "Unet"]
+__all__ = ["UNet", "Unet", "CheckpointUNet"]
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
__all__ = ["UNet", "Unet"]
__all__ = ["UNet", "Unet", "CheckpointUNet"]
🤖 Prompt for AI Agents
In monai/networks/nets/unet.py around line 26, update the module export list to
include the public class CheckpointUNet by adding "CheckpointUNet" to the
__all__ array; modify the existing __all__ = ["UNet", "Unet"] to include
"CheckpointUNet" so the final export list contains "UNet", "Unet", and
"CheckpointUNet".



class _ActivationCheckpointWrapper(nn.Module):
"""Apply activation checkpointing to the wrapped module during training."""

def __init__(self, module: nn.Module) -> None:
super().__init__()
self.module = module

def forward(self, x: torch.Tensor) -> torch.Tensor:
return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
Comment on lines +29 to +37
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Add Google-style docstrings.

Class and forward docstrings need Args/Returns sections per guidelines. Document the wrapped module, checkpoint guard details, and returned tensor.

As per coding guidelines.

🤖 Prompt for AI Agents
In monai/networks/nets/unet.py around lines 29 to 37, the
_ActivationCheckpointWrapper class and its forward method lack Google-style
docstrings; add a class-level docstring that briefly describes purpose, document
the module parameter as "module: nn.Module — module to wrap for activation
checkpointing", mention that checkpointing is applied during training to save
memory and that use_reentrant=False is used as the checkpoint guard, and add a
forward method docstring with Args: x (torch.Tensor): input tensor to the
wrapped module and Returns: torch.Tensor: output tensor from the wrapped module
(with activations checkpointed); keep wording concise and follow Google-style
"Args/Returns" formatting.



class UNet(nn.Module):
"""
Enhanced version of UNet which has residual units implemented with the ResidualUnit class.
Expand Down Expand Up @@ -298,4 +311,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


class CheckpointUNet(UNet):
def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
subblock = _ActivationCheckpointWrapper(subblock)
down_path = _ActivationCheckpointWrapper(down_path)
up_path = _ActivationCheckpointWrapper(up_path)
return super()._get_connection_block(down_path, up_path, subblock)


Unet = UNet
Loading