From de2b6bd65ae84c76781ca387f834fc5abb5fb08a Mon Sep 17 00:00:00 2001
From: Fabio Ferreira
Date: Wed, 3 Sep 2025 14:40:13 +0100
Subject: [PATCH 01/11] feat: add optional gradient checkpointing to unet

---
 monai/networks/nets/unet.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py
index eac0ddab39..c9758e4cdf 100644
--- a/monai/networks/nets/unet.py
+++ b/monai/networks/nets/unet.py
@@ -13,9 +13,11 @@
 
 import warnings
 from collections.abc import Sequence
+from typing import cast
 
 import torch
 import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
 
 from monai.networks.blocks.convolutions import Convolution, ResidualUnit
 from monai.networks.layers.factories import Act, Norm
@@ -23,6 +25,15 @@
 __all__ = ["UNet", "Unet"]
 
+class _ActivationCheckpointWrapper(nn.Module):
+    def __init__(self, module: nn.Module) -> None:
+        super().__init__()
+        self.module = module
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.training:
+            return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+        return cast(torch.Tensor, self.module(x))
 
 class UNet(nn.Module):
     """
     Enhanced version of UNet which has residual units implemented with the ResidualUnit class.
@@ -118,6 +129,7 @@ def __init__(
         dropout: float = 0.0,
         bias: bool = True,
         adn_ordering: str = "NDA",
+        use_checkpointing: bool = False,
     ) -> None:
         super().__init__()
 
@@ -146,6 +158,7 @@ def __init__(
         self.dropout = dropout
         self.bias = bias
         self.adn_ordering = adn_ordering
+        self.use_checkpointing = use_checkpointing
 
         def _create_block(
             inc: int, outc: int, channels: Sequence[int], strides: Sequence[int], is_top: bool
@@ -192,6 +205,8 @@ def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblo
             subblock: block defining the next layer in the network.
         Returns: block for this layer: `nn.Sequential(down_path, SkipConnection(subblock), up_path)`
         """
+        if self.use_checkpointing:
+            subblock = _ActivationCheckpointWrapper(subblock)
         return nn.Sequential(down_path, SkipConnection(subblock), up_path)
 
     def _get_down_layer(self, in_channels: int, out_channels: int, strides: int, is_top: bool) -> nn.Module:

From 66edcb508243f53c4f10af93d6ebfca9a32fe4ef Mon Sep 17 00:00:00 2001
From: Fabio Ferreira
Date: Wed, 3 Sep 2025 14:44:27 +0100
Subject: [PATCH 02/11] fix: small ruff issue

---
 monai/networks/nets/unet.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py
index c9758e4cdf..3fe20dc12f 100644
--- a/monai/networks/nets/unet.py
+++ b/monai/networks/nets/unet.py
@@ -25,6 +25,7 @@
 
 __all__ = ["UNet", "Unet"]
 
+
 class _ActivationCheckpointWrapper(nn.Module):
     def __init__(self, module: nn.Module) -> None:
         super().__init__()
@@ -35,6 +36,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
     return cast(torch.Tensor, self.module(x))
 
+
 class UNet(nn.Module):
     """
     Enhanced version of UNet which has residual units implemented with the ResidualUnit class.

From e66e3578b48630703a0bbfc7aadfe0f68c550f95 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?F=C3=A1bio=20S=2E=20Ferreira?=
Date: Thu, 4 Sep 2025 15:36:15 +0100
Subject: [PATCH 03/11] Update monai/networks/nets/unet.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: Fábio S. Ferreira
---
 monai/networks/nets/unet.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py
index 3fe20dc12f..cced0f950b 100644
--- a/monai/networks/nets/unet.py
+++ b/monai/networks/nets/unet.py
@@ -32,8 +32,12 @@ def __init__(self, module: nn.Module) -> None:
         self.module = module
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if self.training:
-            return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+        if self.training and torch.is_grad_enabled() and x.requires_grad:
+            try:
+                return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+            except TypeError:
+                # Fallback for older PyTorch without `use_reentrant`
+                return cast(torch.Tensor, checkpoint(self.module, x))
         return cast(torch.Tensor, self.module(x))

From feefcaa3944f56fba163475cc5ef4d0da28ceddf Mon Sep 17 00:00:00 2001
From: Fabio Ferreira
Date: Thu, 4 Sep 2025 16:01:24 +0100
Subject: [PATCH 04/11] docs: update docstrings

---
 monai/networks/nets/unet.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py
index cced0f950b..8ad48a1d12 100644
--- a/monai/networks/nets/unet.py
+++ b/monai/networks/nets/unet.py
@@ -27,6 +27,7 @@
 
 class _ActivationCheckpointWrapper(nn.Module):
+    """Apply activation checkpointing to the wrapped module during training."""
     def __init__(self, module: nn.Module) -> None:
         super().__init__()
         self.module = module
@@ -86,6 +87,8 @@ class UNet(nn.Module):
             if a conv layer is directly followed by a batch norm layer, bias should be False.
         adn_ordering: a string representing the ordering of activation (A), normalization (N), and dropout (D).
             Defaults to "NDA". See also: :py:class:`monai.networks.blocks.ADN`.
+        use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce memory
+            at the cost of extra compute. Checkpointing is bypassed in eval and when gradients are disabled. Defaults to False.
 
     Examples::

From e11245797957206e7c8ed25637b059cdc318b4f5 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 4 Sep 2025 15:01:53 +0000
Subject: [PATCH 05/11] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 monai/networks/nets/unet.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py
index 8ad48a1d12..5f4c2222f9 100644
--- a/monai/networks/nets/unet.py
+++ b/monai/networks/nets/unet.py
@@ -87,7 +87,7 @@ class UNet(nn.Module):
             if a conv layer is directly followed by a batch norm layer, bias should be False.
         adn_ordering: a string representing the ordering of activation (A), normalization (N), and dropout (D).
             Defaults to "NDA". See also: :py:class:`monai.networks.blocks.ADN`.
-        use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce memory 
+        use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce memory
            at the cost of extra compute. Checkpointing is bypassed in eval and when gradients are disabled. Defaults to False.
 
     Examples::

From f673ca1453020bba8d9690c3745bb2dc917a806a Mon Sep 17 00:00:00 2001
From: Fabio Ferreira
Date: Thu, 4 Sep 2025 16:17:02 +0100
Subject: [PATCH 06/11] fix: avoid BatchNorm subblocks

---
 monai/networks/nets/unet.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py
index 5f4c2222f9..f010fd4a86 100644
--- a/monai/networks/nets/unet.py
+++ b/monai/networks/nets/unet.py
@@ -30,10 +30,19 @@ class _ActivationCheckpointWrapper(nn.Module):
     """Apply activation checkpointing to the wrapped module during training."""
     def __init__(self, module: nn.Module) -> None:
         super().__init__()
+        # Pre-detect BatchNorm presence for fast path
+        self._has_bn = any(isinstance(m, nn.modules.batchnorm._BatchNorm) for m in module.modules())
         self.module = module
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.training and torch.is_grad_enabled() and x.requires_grad:
+            if self._has_bn:
+                warnings.warn(
+                    "Activation checkpointing skipped for a subblock containing BatchNorm to avoid double-updating "
+                    "running statistics during recomputation.",
+                    RuntimeWarning,
+                )
+                return cast(torch.Tensor, self.module(x))
             try:
                 return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
             except TypeError:

From 69540ffe7d16fa81bb30cd0c1c09186c0b59d9da Mon Sep 17 00:00:00 2001
From: Fabio Ferreira
Date: Thu, 4 Sep 2025 17:05:03 +0100
Subject: [PATCH 07/11] fix: revert batch norm changes

---
 monai/networks/nets/unet.py | 9 ---------
 1 file changed, 9 deletions(-)

diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py
index f010fd4a86..5f4c2222f9 100644
--- a/monai/networks/nets/unet.py
+++ b/monai/networks/nets/unet.py
@@ -30,19 +30,10 @@ class _ActivationCheckpointWrapper(nn.Module):
     """Apply activation checkpointing to the wrapped module during training."""
     def __init__(self, module: nn.Module) -> None:
         super().__init__()
-        # Pre-detect BatchNorm presence for fast path
-        self._has_bn = any(isinstance(m, nn.modules.batchnorm._BatchNorm) for m in module.modules())
         self.module = module
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.training and torch.is_grad_enabled() and x.requires_grad:
-            if self._has_bn:
-                warnings.warn(
-                    "Activation checkpointing skipped for a subblock containing BatchNorm to avoid double-updating "
-                    "running statistics during recomputation.",
-                    RuntimeWarning,
-                )
-                return cast(torch.Tensor, self.module(x))
             try:
                 return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
             except TypeError:

From 42ec757a76dc9c476b7dd302fb9352eee168b9b3 Mon Sep 17 00:00:00 2001
From: Fabio Ferreira
Date: Wed, 1 Oct 2025 16:56:41 +0100
Subject: [PATCH 08/11] refactor: creates a subclass of UNet and overrides the get connection block method

---
 monai/networks/nets/unet.py | 16 +++++-----------
 1 file changed, 5 insertions(+), 11 deletions(-)

diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py
index 5f4c2222f9..4a67a4180f 100644
--- a/monai/networks/nets/unet.py
+++ b/monai/networks/nets/unet.py
@@ -33,13 +33,7 @@ def __init__(self, module: nn.Module) -> None:
         self.module = module
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if self.training and torch.is_grad_enabled() and x.requires_grad:
-            try:
-                return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
-            except TypeError:
-                # Fallback for older PyTorch without `use_reentrant`
-                return cast(torch.Tensor, checkpoint(self.module, x))
-        return cast(torch.Tensor, self.module(x))
+        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
 
 
 class UNet(nn.Module):
@@ -138,7 +132,6 @@ def __init__(
         dropout: float = 0.0,
         bias: bool = True,
         adn_ordering: str = "NDA",
-        use_checkpointing: bool = False,
     ) -> None:
         super().__init__()
 
@@ -167,7 +160,6 @@ def __init__(
         self.dropout = dropout
         self.bias = bias
         self.adn_ordering = adn_ordering
-        self.use_checkpointing = use_checkpointing
 
         def _create_block(
             inc: int, outc: int, channels: Sequence[int], strides: Sequence[int], is_top: bool
@@ -214,8 +206,6 @@ def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblo
             subblock: block defining the next layer in the network.
         Returns: block for this layer: `nn.Sequential(down_path, SkipConnection(subblock), up_path)`
         """
-        if self.use_checkpointing:
-            subblock = _ActivationCheckpointWrapper(subblock)
         return nn.Sequential(down_path, SkipConnection(subblock), up_path)
 
     def _get_down_layer(self, in_channels: int, out_channels: int, strides: int, is_top: bool) -> nn.Module:
@@ -321,5 +311,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.model(x)
         return x
 
+class CheckpointUNet(UNet):
+    def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
+        subblock = _ActivationCheckpointWrapper(subblock)
+        return super()._get_connection_block(down_path, up_path, subblock)
 
 Unet = UNet

From a2e8474abf79552cb4c041c69583261ad16c7049 Mon Sep 17 00:00:00 2001
From: Fabio Ferreira
Date: Wed, 1 Oct 2025 17:13:04 +0100
Subject: [PATCH 09/11] chore: remove use checkpointing from doc string

---
 monai/networks/nets/unet.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py
index 4a67a4180f..24e56c96a4 100644
--- a/monai/networks/nets/unet.py
+++ b/monai/networks/nets/unet.py
@@ -81,8 +81,6 @@ class UNet(nn.Module):
             if a conv layer is directly followed by a batch norm layer, bias should be False.
         adn_ordering: a string representing the ordering of activation (A), normalization (N), and dropout (D).
             Defaults to "NDA". See also: :py:class:`monai.networks.blocks.ADN`.
-        use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce memory
-            at the cost of extra compute. Checkpointing is bypassed in eval and when gradients are disabled. Defaults to False.
 
     Examples::

From 4c4782e6a4d9156f3eeebf90543b2f1699ab3d72 Mon Sep 17 00:00:00 2001
From: Fabio Ferreira
Date: Thu, 2 Oct 2025 13:50:55 +0100
Subject: [PATCH 10/11] fix: linting issues

---
 monai/networks/nets/unet.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py
index 24e56c96a4..0f380a1be7 100644
--- a/monai/networks/nets/unet.py
+++ b/monai/networks/nets/unet.py
@@ -28,6 +28,7 @@
 
 class _ActivationCheckpointWrapper(nn.Module):
     """Apply activation checkpointing to the wrapped module during training."""
+
     def __init__(self, module: nn.Module) -> None:
         super().__init__()
         self.module = module
@@ -309,9 +310,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.model(x)
         return x
 
+
 class CheckpointUNet(UNet):
     def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
         subblock = _ActivationCheckpointWrapper(subblock)
         return super()._get_connection_block(down_path, up_path, subblock)
 
+
 Unet = UNet

From 515c659ee6f0587d25935ae728195266cf340422 Mon Sep 17 00:00:00 2001
From: Fabio Ferreira
Date: Wed, 8 Oct 2025 09:53:08 +0100
Subject: [PATCH 11/11] feat: add activation checkpointing to down and up paths to be more efficient

---
 monai/networks/nets/unet.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py
index 0f380a1be7..226f4630bf 100644
--- a/monai/networks/nets/unet.py
+++ b/monai/networks/nets/unet.py
@@ -314,6 +314,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 class CheckpointUNet(UNet):
     def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
         subblock = _ActivationCheckpointWrapper(subblock)
+        down_path = _ActivationCheckpointWrapper(down_path)
+        up_path = _ActivationCheckpointWrapper(up_path)
         return super()._get_connection_block(down_path, up_path, subblock)
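
A minimal usage sketch of the CheckpointUNet variant introduced by this series; it is an illustration only and not part of the patches above. It assumes the series is merged so that CheckpointUNet is importable from monai.networks.nets.unet (the file patched here), and the constructor arguments below are arbitrary example values; the class keeps the UNet signature and trades extra recomputation in backward for lower activation memory::

    import torch

    from monai.networks.nets.unet import CheckpointUNet

    # construct exactly like UNet; down paths, up paths and connection subblocks
    # are wrapped for activation checkpointing
    net = CheckpointUNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128),
        strides=(2, 2, 2),
        num_res_units=2,
    ).train()

    x = torch.rand(1, 1, 64, 64, 64)
    y = net(x)
    # activations inside the wrapped blocks are recomputed here instead of being kept from the forward pass
    y.sum().backward()

Because wrapping happens per connection block, the memory saving grows with network depth, while each backward pass pays roughly one extra forward computation through the wrapped blocks.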