diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py
index eac0ddab39..226f4630bf 100644
--- a/monai/networks/nets/unet.py
+++ b/monai/networks/nets/unet.py
@@ -13,9 +13,11 @@
 
 import warnings
 from collections.abc import Sequence
+from typing import cast
 
 import torch
 import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
 
 from monai.networks.blocks.convolutions import Convolution, ResidualUnit
 from monai.networks.layers.factories import Act, Norm
@@ -24,6 +26,17 @@
 __all__ = ["UNet", "Unet"]
 
 
+class _ActivationCheckpointWrapper(nn.Module):
+    """Apply activation checkpointing to the wrapped module, trading extra backward-pass compute for lower memory."""
+
+    def __init__(self, module: nn.Module) -> None:
+        super().__init__()
+        self.module = module
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+
+
 class UNet(nn.Module):
     """
     Enhanced version of UNet which has residual units implemented with the ResidualUnit class.
@@ -298,4 +311,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
+class CheckpointUNet(UNet):
+    """UNet variant that applies activation checkpointing to its connection blocks to reduce peak memory."""
+
+    def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
+        subblock = _ActivationCheckpointWrapper(subblock)
+        down_path = _ActivationCheckpointWrapper(down_path)
+        up_path = _ActivationCheckpointWrapper(up_path)
+        return super()._get_connection_block(down_path, up_path, subblock)
+
+
 Unet = UNet
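
A minimal usage sketch (not part of the patch) showing how the new CheckpointUNet could be exercised. It assumes the patch is applied and imports the class from monai.networks.nets.unet directly, since the diff does not add it to __all__; the constructor arguments follow UNet's existing signature, and the 2D sizes below are arbitrary example values.

import torch

from monai.networks.nets.unet import CheckpointUNet

# Same constructor as UNet: CheckpointUNet only overrides _get_connection_block.
model = CheckpointUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128),
    strides=(2, 2, 2),
    num_res_units=2,
)
model.train()

x = torch.randn(1, 1, 96, 96)
y = model(x)        # activations inside the wrapped connection blocks are discarded after the forward pass
y.sum().backward()  # and recomputed here, lowering peak memory relative to a plain UNet
print(y.shape)      # torch.Size([1, 2, 96, 96])

Passing use_reentrant=False in the wrapper selects PyTorch's non-reentrant checkpoint implementation, which recomputes the wrapped block's activations during the backward pass instead of storing them, so the memory saving only materialises when gradients are being computed.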