feat: add activation checkpointing to unet #8554
base: dev
Conversation
Walkthrough
Adds a private _ActivationCheckpointWrapper(nn.Module) that applies torch.utils.checkpoint.checkpoint around a wrapped submodule during forward. Introduces a public CheckpointUNet(UNet) subclass that overrides _get_connection_block to return connection subblocks wrapped with _ActivationCheckpointWrapper. Adds imports for typing.cast and torch.utils.checkpoint.checkpoint. No other UNet logic or init changes in this diff.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 1
🧹 Nitpick comments (4)
monai/networks/nets/unet.py (4)
29-33: Add a brief class docstring to the wrapper.
Improves discoverability and meets docstring guidelines. Apply this diff:
```diff
 class _ActivationCheckpointWrapper(nn.Module):
-    def __init__(self, module: nn.Module) -> None:
+    """Apply activation checkpointing to the wrapped module during training."""
+    def __init__(self, module: nn.Module) -> None:
         super().__init__()
         self.module = module
```
134-135: Document the new `use_checkpointing` arg in the class docstring and user docs.
State trade-offs (memory vs compute), that it's training-only, incompatible with `torch.no_grad`, and preserves RNG by default. Proposed docstring snippet to add under "Args":
```
use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce
    memory at the cost of extra compute. Checkpointing is bypassed in eval and when gradients are disabled.
    Defaults to False.
```
I can open a docs patch and add a short example enabling the flag.
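A minimal usage sketch of the flag as proposed in this revision of the PR (the flag is not part of released MONAI; the remaining constructor arguments follow the existing `UNet` signature):

```python
import torch

from monai.networks.nets import UNet

# Hypothetical example for the proposed flag; trades extra compute during backward
# for lower activation memory during training.
net = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128),
    strides=(2, 2, 2),
    use_checkpointing=True,  # proposed flag, off by default
)
net.train()

x = torch.randn(1, 1, 64, 64, 64)
loss = net(x).sum()
loss.backward()  # wrapped sub-block activations are recomputed here
```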
163-164: Static wrapping caveat: runtime flips won't take effect.
After init, changing `self.use_checkpointing` won't rewrap existing blocks. Either document this or add a small helper to (re)build the model if you expect runtime toggling. Do you expect users to toggle this at runtime? If yes, I can sketch a safe rewrap helper.
210-212: Checkpointing scope is subblock-only; consider an optional broader policy.
Current placement is a good default. If more memory is needed, offer a policy to also wrap `down_path`/`up_path` (with a warning about extra compute). I can add a `checkpoint_policy: Literal["subblock", "all"] = "subblock"` in `__init__` and wire it here on request.
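A rough sketch of that wiring inside `_get_connection_block` (the `checkpoint_policy` attribute is hypothetical and would be set in `__init__`; the return line mirrors the existing UNet implementation):

```python
def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
    # assumes: self.checkpoint_policy: Literal["subblock", "all"] = "subblock" (hypothetical)
    if self.use_checkpointing:
        subblock = _ActivationCheckpointWrapper(subblock)
        if self.checkpoint_policy == "all":
            # also wrap the per-level encoder/decoder stages: more memory saved, more recompute
            down_path = _ActivationCheckpointWrapper(down_path)
            up_path = _ActivationCheckpointWrapper(up_path)
    return nn.Sequential(down_path, SkipConnection(subblock), up_path)
```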
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/networks/nets/unet.py (5 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/networks/nets/unet.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: build-docs
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: packaging
🔇 Additional comments (1)
monai/networks/nets/unet.py (1)
16-21: Imports for checkpointing look good.
`cast` and `checkpoint` are appropriate for the new wrapper.
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Fábio S. Ferreira <ferreira.fabio80@gmail.com>
for more information, see https://pre-commit.ci
Actionable comments posted: 1
♻️ Duplicate comments (1)
monai/networks/nets/unet.py (1)
29-43: Good guard + compatibility fallback.
Training/grad-enabled checks and `use_reentrant=False` with `TypeError` fallback are the right call. This addresses the prior review note.
🧹 Nitpick comments (5)
monai/networks/nets/unet.py (5)
29-43: Avoid per-iteration TypeError cost: detect `use_reentrant` support once.
Resolve support at import/init time to prevent raising an exception every forward on older torch. Apply:
```diff
@@
-class _ActivationCheckpointWrapper(nn.Module):
+_SUPPORTS_USE_REENTRANT: bool | None = None
+
+class _ActivationCheckpointWrapper(nn.Module):
@@
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if self.training and torch.is_grad_enabled() and x.requires_grad:
-            try:
-                return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
-            except TypeError:
-                # Fallback for older PyTorch without `use_reentrant`
-                return cast(torch.Tensor, checkpoint(self.module, x))
-        return cast(torch.Tensor, self.module(x))
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.training and torch.is_grad_enabled() and x.requires_grad:
+            global _SUPPORTS_USE_REENTRANT
+            if _SUPPORTS_USE_REENTRANT is None:
+                try:
+                    # probe once
+                    checkpoint(self.module, x, use_reentrant=False)  # type: ignore[arg-type]
+                    _SUPPORTS_USE_REENTRANT = True
+                except TypeError:
+                    _SUPPORTS_USE_REENTRANT = False
+                except Exception:
+                    # do not change behavior on unexpected errors; fall back below
+                    _SUPPORTS_USE_REENTRANT = False
+            if _SUPPORTS_USE_REENTRANT:
+                return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+            return cast(torch.Tensor, checkpoint(self.module, x))
+        return cast(torch.Tensor, self.module(x))
```
Add outside the hunk (file header):
```python
import inspect  # if you switch to signature probing instead of try/except
```
Note: PyTorch recommends passing `use_reentrant` explicitly going forward. (docs.pytorch.org)
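If signature probing is preferred over the runtime try/except, a one-time module-level check could look like the following sketch (not part of the diff above; `_run_checkpoint` is an illustrative helper name):

```python
import inspect

from torch.utils.checkpoint import checkpoint

# Resolve once at import time whether this torch version accepts `use_reentrant`.
_SUPPORTS_USE_REENTRANT: bool = "use_reentrant" in inspect.signature(checkpoint).parameters


def _run_checkpoint(module, x):
    """Call checkpoint with `use_reentrant=False` when supported, else fall back."""
    if _SUPPORTS_USE_REENTRANT:
        return checkpoint(module, x, use_reentrant=False)
    return checkpoint(module, x)
```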
29-43: TorchScript: make wrapper script-safe.
`try/except` and dynamic checkpoint calls won't script. Short-circuit under scripting. Apply:
```diff
@@
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if torch.jit.is_scripting():
+            # Avoid checkpoint in scripted graphs
+            return cast(torch.Tensor, self.module(x))
```
29-43: Docstring completeness.
Add Google-style docstrings for the wrapper's class/init/forward (inputs, returns, raises). Example:
```diff
-class _ActivationCheckpointWrapper(nn.Module):
-    """Apply activation checkpointing to the wrapped module during training."""
+class _ActivationCheckpointWrapper(nn.Module):
+    """Wrap a module and apply activation checkpointing during training.
+
+    Args:
+        module: The submodule to checkpoint.
+
+    Returns:
+        torch.Tensor: Output tensor from the wrapped submodule.
+
+    Raises:
+        RuntimeError: If checkpoint fails at runtime.
+    """
```
90-92: Tighten the `use_checkpointing` docstring and add a BN caveat.
Keep it on one Args entry and note the BatchNorm limitation.
```diff
-        use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce memory
-            at the cost of extra compute. Checkpointing is bypassed in eval and when gradients are disabled. Defaults to False.
+        use_checkpointing: If True, apply activation checkpointing to internal sub-blocks during training to reduce
+            memory at the cost of extra compute. Bypassed in eval and when gradients are disabled. Note: avoid with
+            BatchNorm layers due to running-stat updates during recomputation. Defaults to False.
```
Reference on RNG determinism (dropout is handled by default). (docs.pytorch.org)
217-219: Wrap site is fine; consider BN detection here instead of inside the wrapper.
Keeps policy close to where wrapping occurs and avoids per-instance warnings.
```diff
-        if self.use_checkpointing:
-            subblock = _ActivationCheckpointWrapper(subblock)
+        if self.use_checkpointing:
+            has_bn = any(isinstance(m, nn.modules.batchnorm._BatchNorm) for m in subblock.modules())
+            if has_bn:
+                warnings.warn(
+                    "Skipping activation checkpointing for this subblock (contains BatchNorm).",
+                    RuntimeWarning,
+                )
+            else:
+                subblock = _ActivationCheckpointWrapper(subblock)
         return nn.Sequential(down_path, SkipConnection(subblock), up_path)
```
Rationale: avoids double-updating BN stats during recomputation. (github.com)
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/networks/nets/unet.py (6 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/networks/nets/unet.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: build-docs
- GitHub Check: packaging
- GitHub Check: quick-py3 (ubuntu-latest)
🔇 Additional comments (3)
monai/networks/nets/unet.py (3)
16-16: LGTM: imports are correct and scoped.
`cast` and `checkpoint` are needed by the wrapper; no issues. Also applies to: 20-20.
141-141: Public API addition: ensure tests and docs cover new flag.
Add unit tests for parity (on/off), eval bypass, and `no_grad()` bypass; document in release notes/configs. Proposed minimal tests:
- Forward/backward equivalence within tolerance for `use_checkpointing={False,True}` on a tiny UNet.
- `model.eval()` and `torch.no_grad()` paths skip checkpoint (no `checkpoint` mocks invoked).
- Autocast path under CUDA runs without dtype mismatches.

Do you want a patch with pytest tests added under `tests/test_unet_checkpointing.py`?
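A minimal parity test along those lines might look like this sketch (the file name, the `use_checkpointing` flag, and the tolerances are placeholders tied to this PR's proposal):

```python
# tests/test_unet_checkpointing.py (hypothetical)
import torch

from monai.networks.nets import UNet


def test_checkpointing_forward_backward_parity():
    torch.manual_seed(0)
    kwargs = dict(spatial_dims=2, in_channels=1, out_channels=2, channels=(4, 8, 16), strides=(2, 2))
    base = UNet(**kwargs)
    ckpt = UNet(use_checkpointing=True, **kwargs)  # proposed flag
    ckpt.load_state_dict(base.state_dict())  # identical weights for a fair comparison
    base.train()
    ckpt.train()

    x = torch.randn(1, 1, 32, 32)
    out_base, out_ckpt = base(x), ckpt(x)
    assert torch.allclose(out_base, out_ckpt, atol=1e-6)

    out_base.sum().backward()
    out_ckpt.sum().backward()
    for p_base, p_ckpt in zip(base.parameters(), ckpt.parameters()):
        assert torch.allclose(p_base.grad, p_ckpt.grad, atol=1e-5)
```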
170-170: LGTM: stores flag on the instance.
No concerns.
Actionable comments posted: 0
🧹 Nitpick comments (4)
monai/networks/nets/unet.py (4)
29-51: Warn once for BN, and add param/return docstring per guidelines.
Avoid repeated RuntimeWarnings and document the wrapper's contract.
```diff
 class _ActivationCheckpointWrapper(nn.Module):
-    """Apply activation checkpointing to the wrapped module during training."""
+    """Apply activation checkpointing to the wrapped module during training.
+
+    Args:
+        module: submodule to wrap.
+    Returns:
+        torch.Tensor: output of the wrapped module.
+    Warnings:
+        Skips checkpointing and emits a RuntimeWarning if the submodule contains
+        BatchNorm to avoid double-updating running stats.
+    """
     def __init__(self, module: nn.Module) -> None:
         super().__init__()
         # Pre-detect BatchNorm presence for fast path
         self._has_bn = any(isinstance(m, nn.modules.batchnorm._BatchNorm) for m in module.modules())
         self.module = module
+        self._bn_warned = False

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.training and torch.is_grad_enabled() and x.requires_grad:
             if self._has_bn:
-                warnings.warn(
-                    "Activation checkpointing skipped for a subblock containing BatchNorm to avoid double-updating "
-                    "running statistics during recomputation.",
-                    RuntimeWarning,
-                )
+                if not self._bn_warned:
+                    warnings.warn(
+                        "Activation checkpointing skipped for a subblock containing BatchNorm to avoid double-updating "
+                        "running statistics during recomputation.",
+                        RuntimeWarning,
+                    )
+                    self._bn_warned = True
                 return cast(torch.Tensor, self.module(x))
             try:
                 return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
             except TypeError:
                 # Fallback for older PyTorch without `use_reentrant`
                 return cast(torch.Tensor, checkpoint(self.module, x))
         return cast(torch.Tensor, self.module(x))
```
Minimal tests to add (a sketch of the BatchNorm case follows this list):
- Training vs eval parity (values match with/without checkpointing).
- BN subblock emits RuntimeWarning and bypasses checkpointing.
- Guard under `torch.no_grad()` and when `requires_grad=False`.
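A sketch of the BatchNorm test, assuming the warn-and-bypass behavior proposed above and the `use_checkpointing` flag from this PR:

```python
import warnings

import torch

from monai.networks.layers import Norm
from monai.networks.nets import UNet


def test_batchnorm_subblock_warns_and_bypasses_checkpoint():
    net = UNet(
        spatial_dims=2, in_channels=1, out_channels=1,
        channels=(4, 8), strides=(2,), norm=Norm.BATCH,
        use_checkpointing=True,  # proposed flag
    )
    net.train()
    x = torch.randn(2, 1, 16, 16)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        net(x).sum().backward()
    # expects the proposed warn-on-BN behavior; checkpointing is bypassed for BN sub-blocks
    assert any(issubclass(w.category, RuntimeWarning) for w in caught)
```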
99-101: Clarify `use_checkpointing` behavior in docs (BN, grad, training-only, build-time).
Make expectations explicit for users.
```diff
-        use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce memory
-            at the cost of extra compute. Checkpointing is bypassed in eval and when gradients are disabled. Defaults to False.
+        use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce
+            memory at the cost of extra compute. Checkpointing is active only when `self.training` is True, gradients
+            are enabled, and inputs require gradients; it is bypassed in eval and when grads are disabled.
+            Sub-blocks containing BatchNorm are not checkpointed to avoid double-updating running stats.
+            Note: this is a build-time option; changing it after initialization will not retroactively wrap existing
+            sub-blocks. Defaults to False.
```
179-180: Flag is build-time only; consider asserting or documenting.
Changing `self.use_checkpointing` post-init has no effect since wrappers are created during construction. The doc update above covers this; alternatively, convert to a read-only attribute to avoid confusion.
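A read-only variant could be as small as this sketch (assumes `__init__` stores the value as `self._use_checkpointing`):

```python
# Inside UNet, with __init__ assigning self._use_checkpointing = use_checkpointing:
@property
def use_checkpointing(self) -> bool:
    """Build-time flag; sub-blocks are wrapped during construction, so this cannot be toggled later."""
    return self._use_checkpointing
```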
226-228: Wrapping only the subblock is OK; consider optional coverage toggle.
If desired later, expose an opt-in to also wrap `down`/`up` blocks for additional savings.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/networks/nets/unet.py (6 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/networks/nets/unet.py
🔇 Additional comments (2)
monai/networks/nets/unet.py (2)
16-21: Imports look good.
Scoped import of `checkpoint` plus `cast` is appropriate.
150-151: API change is sensible and non-breaking.
Parameter added at the end; default preserves behavior.
Actionable comments posted: 0
♻️ Duplicate comments (1)
monai/networks/nets/unet.py (1)
29-43: Skip checkpointing for subblocks with BatchNorm to avoid double-updating running stats.
Checkpoint recompute updates BN running stats twice under training. Detect BN in the wrapped module and bypass checkpointing with a warning. Apply this diff:
```diff
 class _ActivationCheckpointWrapper(nn.Module):
-    """Apply activation checkpointing to the wrapped module during training."""
+    """Apply activation checkpointing to the wrapped module during training.
+
+    Skips checkpointing for submodules containing BatchNorm to avoid double-updating
+    running statistics during recomputation.
+    """
     def __init__(self, module: nn.Module) -> None:
         super().__init__()
         self.module = module
+        self._has_bn = any(isinstance(m, nn.modules.batchnorm._BatchNorm) for m in module.modules())

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.training and torch.is_grad_enabled() and x.requires_grad:
+            if self._has_bn:
+                warnings.warn(
+                    "Activation checkpointing skipped for a subblock containing BatchNorm "
+                    "to avoid double-updating running statistics during recomputation.",
+                    RuntimeWarning,
+                )
+                return cast(torch.Tensor, self.module(x))
             try:
                 return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
             except TypeError:
                 # Fallback for older PyTorch without `use_reentrant`
                 return cast(torch.Tensor, checkpoint(self.module, x))
         return cast(torch.Tensor, self.module(x))
```
🧹 Nitpick comments (3)
monai/networks/nets/unet.py (3)
90-92: Clarify arg docs and surface BN caveat.
Tighten wording and document BN behavior for transparency. Apply this diff:
```diff
-        use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce memory
-            at the cost of extra compute. Checkpointing is bypassed in eval and when gradients are disabled. Defaults to False.
+        use_checkpointing: If True, applies activation checkpointing to internal sub-blocks during training to reduce
+            memory at the cost of extra compute. Bypassed in eval mode and when gradients are disabled.
+            Note: sub-blocks containing BatchNorm are executed without checkpointing to avoid double-updating
+            running statistics. Defaults to False.
```
217-219: Placement of wrapper is sensible; consider optional breadth control.
Future enhancement: expose a knob to checkpoint down/up paths too for deeper memory savings on very deep nets.
141-142: Add tests to lock behavior.
- Parity: forward/backward equivalence (outputs/grad norms) with vs. without checkpointing.
- Modes: train vs. eval; torch.no_grad().
- Norms: with InstanceNorm and with BatchNorm (assert BN path skips with warning).

I can draft unit tests targeting UNet's smallest config to keep runtime minimal; want me to open a follow-up?
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/networks/nets/unet.py (6 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/networks/nets/unet.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: build-docs
- GitHub Check: packaging
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-py3 (3.10)
🔇 Additional comments (3)
monai/networks/nets/unet.py (3)
16-21: LGTM: imports for cast/checkpoint are correct.
Direct import of `checkpoint` and use of `typing.cast` are appropriate.
35-42: Validate AMP behavior under fallback (reentrant) checkpointing.
Older Torch (fallback path) may not replay autocast exactly; please verify mixed-precision parity. Minimal check: run a forward/backward with torch.autocast and compare loss/grad norms with and without checkpointing on a small UNet to ensure deltas are within numerical noise.
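A sketch of that check (assumes CUDA is available, a small 2D UNet, and the `use_checkpointing` flag proposed in this PR; tolerances are illustrative):

```python
import torch

from monai.networks.nets import UNet


def loss_and_grad_norm(use_checkpointing: bool) -> tuple[torch.Tensor, torch.Tensor]:
    torch.manual_seed(0)  # same seed -> same initial weights across both runs
    net = UNet(
        spatial_dims=2, in_channels=1, out_channels=2, channels=(8, 16), strides=(2,),
        use_checkpointing=use_checkpointing,  # proposed flag
    ).cuda().train()
    x = torch.randn(1, 1, 32, 32, device="cuda")
    with torch.autocast("cuda"):
        loss = net(x).sum()
    loss.backward()
    grads = torch.stack([p.grad.norm() for p in net.parameters() if p.grad is not None])
    return loss.detach(), grads.norm()


loss_a, gn_a = loss_and_grad_norm(False)
loss_b, gn_b = loss_and_grad_norm(True)
# deltas should be within numerical noise
assert torch.allclose(loss_a, loss_b, rtol=1e-3) and torch.allclose(gn_a, gn_b, rtol=1e-3)
```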
141-142: API addition looks good.
Name and default match MONAI conventions.
monai/networks/nets/unet.py (outdated)
```python
            return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
        except TypeError:
            # Fallback for older PyTorch without `use_reentrant`
            return cast(torch.Tensor, checkpoint(self.module, x))
```
We don't support versions of PyTorch so old that `use_reentrant` is missing, so we can omit this along with the try block.
Makes sense, thanks.
Hi @ferreirafabio80, thanks for the contribution, but I would suggest this isn't necessarily the way to go with adapting this class. Perhaps instead you can create a subclass of `UNet`:
```python
class CheckpointUNet(UNet):
    def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
        subblock = _ActivationCheckpointWrapper(subblock)
        return super()._get_connection_block(down_path, up_path, subblock)
```
This would suffice for your own use if you just wanted such a definition. I think the […] I see also that […]
Hi @ericspod, thank you for your comments. Yes, that also works. I've defined a subclass and overridden the method as you suggested. Regarding the […], I was probably being extremely careful with the checks in […].
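For reference, a minimal usage sketch of the subclass as defined on this branch (constructor arguments are the standard `UNet` ones):

```python
import torch

from monai.networks.nets.unet import CheckpointUNet

# Drop-in replacement for UNet on this branch; connection sub-blocks are
# recomputed during backward to reduce activation memory.
net = CheckpointUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128),
    strides=(2, 2, 2),
).train()

x = torch.randn(1, 1, 64, 64, 64)
net(x).sum().backward()
```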
Actionable comments posted: 5
♻️ Duplicate comments (2)
monai/networks/nets/unet.py (2)
35-36: Missing training and gradient guards cause eval overhead and no_grad crashes.
The forward unconditionally calls checkpoint. This will:
- Apply checkpointing during inference (eval mode) → unnecessary compute overhead.
- Fail under `torch.no_grad()` → runtime error.

Apply this diff:
```diff
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+        if self.training and torch.is_grad_enabled() and x.requires_grad:
+            return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+        return cast(torch.Tensor, self.module(x))
```
29-37: BatchNorm in checkpointed subblocks will double-update running stats.
Checkpoint recomputes the forward pass during backward, causing BatchNorm layers to update `running_mean`/`running_var` twice per training step, skewing statistics. Consider detecting BatchNorm in `__init__` and either warning or skipping checkpoint:
```diff
 class _ActivationCheckpointWrapper(nn.Module):
     """Apply activation checkpointing to the wrapped module during training."""
     def __init__(self, module: nn.Module) -> None:
         super().__init__()
         self.module = module
+        self._has_bn = any(isinstance(m, nn.modules.batchnorm._BatchNorm) for m in module.modules())

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+        if self.training and torch.is_grad_enabled() and x.requires_grad:
+            if self._has_bn:
+                warnings.warn(
+                    "Activation checkpointing skipped for subblock with BatchNorm to avoid double-update of running stats.",
+                    RuntimeWarning,
+                )
+                return cast(torch.Tensor, self.module(x))
+            return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+        return cast(torch.Tensor, self.module(x))
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge base: Disabled due to Reviews > Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/networks/nets/unet.py (4 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/networks/nets/unet.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: build-docs
- GitHub Check: packaging
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-os (windows-latest)
🔇 Additional comments (2)
monai/networks/nets/unet.py (2)
16-16: LGTM on imports.
Both `cast` and `checkpoint` are used in the new wrapper and are correctly imported. Also applies to: 20-20.
316-316: Clarify checkpointing scope: only `subblock` vs. entire connection block.
Only `subblock` (the recursive nested structure) is wrapped, while `down_path` and `up_path` (encoder/decoder convolutions at each level) are not checkpointed. Is this intentional? Typical UNet checkpointing strategies checkpoint entire encoder/decoder blocks for maximum memory savings. Consider whether `down_path` and `up_path` should also be wrapped, or document the rationale for checkpointing only the recursive substructure.
```python
# Alternative: checkpoint all three components
def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
    down_path = _ActivationCheckpointWrapper(down_path)
    up_path = _ActivationCheckpointWrapper(up_path)
    subblock = _ActivationCheckpointWrapper(subblock)
    return super()._get_connection_block(down_path, up_path, subblock)
```
```python
from monai.networks.layers.factories import Act, Norm
from monai.networks.layers.simplelayers import SkipConnection

__all__ = ["UNet", "Unet"]
```
Add `CheckpointUNet` to `__all__` exports.
`CheckpointUNet` is a public class but not exported in `__all__`.
Apply this diff:
```diff
-__all__ = ["UNet", "Unet"]
+__all__ = ["UNet", "Unet", "CheckpointUNet"]
```
📝 Committable suggestion
‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
__all__ = ["UNet", "Unet", "CheckpointUNet"]
```
🤖 Prompt for AI Agents
In monai/networks/nets/unet.py around line 26, update the module export list to
include the public class CheckpointUNet by adding "CheckpointUNet" to the
__all__ array; modify the existing __all__ = ["UNet", "Unet"] to include
"CheckpointUNet" so the final export list contains "UNet", "Unet", and
"CheckpointUNet".
```python
class _ActivationCheckpointWrapper(nn.Module):
    """Apply activation checkpointing to the wrapped module during training."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
```
🛠️ Refactor suggestion | 🟠 Major
Add comprehensive Google-style docstring.
Per coding guidelines, document all parameters, return values, and behavior.
As per coding guidelines.
Apply this diff:
```diff
 class _ActivationCheckpointWrapper(nn.Module):
-    """Apply activation checkpointing to the wrapped module during training."""
+    """
+    Wrapper applying activation checkpointing to a module during training.
+
+    During the forward pass in training mode, intermediate activations are not stored;
+    they are recomputed during the backward pass to reduce memory usage.
+
+    Args:
+        module: The module to wrap with activation checkpointing.
+    """
     def __init__(self, module: nn.Module) -> None:
         super().__init__()
         self.module = module

     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass with optional activation checkpointing.
+
+        Args:
+            x: Input tensor.
+
+        Returns:
+            Output tensor from the wrapped module.
+        """
         return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
```
📝 Committable suggestion
‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
class _ActivationCheckpointWrapper(nn.Module):
    """
    Wrapper applying activation checkpointing to a module during training.

    During the forward pass in training mode, intermediate activations are not stored;
    they are recomputed during the backward pass to reduce memory usage.

    Args:
        module: The module to wrap with activation checkpointing.
    """

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with optional activation checkpointing.

        Args:
            x: Input tensor.

        Returns:
            Output tensor from the wrapped module.
        """
        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
```
🤖 Prompt for AI Agents
In monai/networks/nets/unet.py around lines 29 to 37, the
_ActivationCheckpointWrapper class and its forward method lack a comprehensive
Google-style docstring; add a docstring for the class describing its purpose
(applies activation checkpointing to a wrapped nn.Module during training),
document the constructor parameter module (type and role) and any stored
attributes, and add a Google-style docstring on forward documenting Args (x:
torch.Tensor — shape/expected dtype/context), Returns (torch.Tensor — shape and
that it is the wrapped module's output), and Behavior/Raises (explain that
checkpoint(...) is used with use_reentrant=False, when it will be active—only
during forward in training—and any exceptions propagated from the wrapped
module); keep wording concise and follow Google docstring sections (Args,
Returns, Raises, and a short Example/Notes if useful).
Actionable comments posted: 3
♻️ Duplicate comments (2)
monai/networks/nets/unet.py (2)
26-26: Export `CheckpointUNet` via `__all__`.
Public class isn't exported; importing it from `monai.networks.nets` will fail. Add it to `__all__`.
```diff
-__all__ = ["UNet", "Unet"]
+__all__ = ["UNet", "Unet", "CheckpointUNet"]
```
36-37: Guard checkpoint under grad-enabled training only.
Forward unconditionally checkpointing will crash under `torch.no_grad()` (common in eval) because checkpoint requires gradients. Gate on `self.training`, `torch.is_grad_enabled()`, and a grad-requiring input before calling checkpoint.
```diff
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+        if self.training and torch.is_grad_enabled() and x.requires_grad:
+            return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+        return cast(torch.Tensor, self.module(x))
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge base: Disabled due to Reviews > Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/networks/nets/unet.py (3 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/networks/nets/unet.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: packaging
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: build-docs
```python
class _ActivationCheckpointWrapper(nn.Module):
    """Apply activation checkpointing to the wrapped module during training."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
```
🛠️ Refactor suggestion | 🟠 Major
Add Google-style docstrings.
Class and forward docstrings need Args/Returns sections per guidelines. Document the wrapped module, checkpoint guard details, and returned tensor.
As per coding guidelines.
🤖 Prompt for AI Agents
In monai/networks/nets/unet.py around lines 29 to 37, the
_ActivationCheckpointWrapper class and its forward method lack Google-style
docstrings; add a class-level docstring that briefly describes purpose, document
the module parameter as "module: nn.Module — module to wrap for activation
checkpointing", mention that checkpointing is applied during training to save
memory and that use_reentrant=False is used as the checkpoint guard, and add a
forward method docstring with Args: x (torch.Tensor): input tensor to the
wrapped module and Returns: torch.Tensor: output tensor from the wrapped module
(with activations checkpointed); keep wording concise and follow Google-style
"Args/Returns" formatting.
```python
class CheckpointUNet(UNet):
    def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
        subblock = _ActivationCheckpointWrapper(subblock)
        return super()._get_connection_block(down_path, up_path, subblock)
```
Add targeted tests for checkpointing behavior.
New subclass lacks coverage. Add unit tests ensuring training mode uses checkpoint (grad-enabled) and eval/no-grad bypasses it, plus parity with base `UNet`.
As per coding guidelines.
🤖 Prompt for AI Agents
In monai/networks/nets/unet.py around lines 314 to 317, the new CheckpointUNet
subclass wraps subblocks with _ActivationCheckpointWrapper but has no unit
tests; add targeted tests that (1) construct a small UNet and CheckpointUNet
with identical weights and assert forward outputs match in eval mode (no_grad),
(2) in training mode run a forward+backward pass for both models and assert that
gradients are non-zero and that gradient patterns between UNet and
CheckpointUNet are equivalent, and (3) verify that when checkpointing is active
(training) the wrapped subblocks execute under autograd (e.g., by asserting
requires_grad on intermediary tensors or by checking that backward succeeds) and
when in eval/no_grad the wrapper bypasses torch.autograd.checkpoint (i.e., no
grad graph created). Use small random inputs, deterministic seeds, and parameter
copying to ensure parity; place tests under the networks/nets test module and
follow existing test conventions.
🛠️ Refactor suggestion | 🟠 Major
Document `CheckpointUNet`.
Provide a Google-style class docstring noting the checkpointing behavior, inherited args, and trade-offs.
As per coding guidelines.
🤖 Prompt for AI Agents
In monai/networks/nets/unet.py around lines 314 to 317, the CheckpointUNet class
lacks a Google-style docstring; add a concise class docstring immediately under
the class declaration that (1) states this class enables activation
checkpointing by wrapping subblocks with _ActivationCheckpointWrapper, (2)
indicates it inherits the same constructor arguments and behavior as UNet
(mention "See UNet for args"), and (3) documents trade-offs: reduced memory at
the cost of extra compute/recomputation and potential differences in backward
behavior; keep wording brief and follow Google docstring format (one-line
summary, extended description, Args inherited note, and a short
“Raises/Warnings” or “Notes” line about the trade-offs).
Description
Introduces an optional `use_checkpointing` flag in the `UNet` implementation. When enabled, intermediate activations in the encoder–decoder blocks are recomputed during the backward pass instead of being stored in memory, using a `_ActivationCheckpointWrapper` wrapper around sub-blocks.

Types of changes
- `./runtests.sh -f -u --net --coverage`
- `./runtests.sh --quick --unittests --disttests`
- `make html` command in the `docs/` folder