Initial commit
fbshipit-source-id: da6be2f26e3a1202f4bffde8cb980e2dcb851294
This commit is contained in:
114
sam3/model/act_ckpt_utils.py
Normal file
114
sam3/model/act_ckpt_utils.py
Normal file
@@ -0,0 +1,114 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import inspect
|
||||
from functools import wraps
|
||||
from typing import Callable, TypeVar, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
from torch.utils._pytree import tree_map_only
|
||||
|
||||
# Type variables for better type hinting
|
||||
T = TypeVar("T")
|
||||
Module = TypeVar("Module", bound=nn.Module)
|
||||
|
||||
|
||||
def activation_ckpt_wrapper(module: Union[nn.Module, Callable]) -> Callable:
    """
    Wrap a module or function so each call can enable or disable activation checkpointing.

    Activation checkpointing (gradient checkpointing) trades compute for memory by
    recomputing intermediate activations during the backward pass instead of storing
    them in memory during the forward pass.

    When activation checkpointing is enabled, the wrapper expects only keyword
    arguments and maps them onto the target's signature:

    * positional parameters are passed positionally, in signature order, falling
      back to the declared default when the caller did not supply a value;
    * keyword-only parameters that the caller supplied are forwarded as keywords
      (omitted ones are left to the target's own defaults);
    * ``*args`` / ``**kwargs`` catch-all parameters are not required;
    * any *other* leftover keyword whose value is a ``torch.Tensor`` is replaced
      by the placeholder string ``"_REMOVED_BY_ACT_CKPT_WRAPPER_"`` before the
      call, on the assumption that the target does not need it.

    Args:
        module: The module or function to wrap with activation checkpointing.
            For an ``nn.Module`` the signature of its ``forward`` is inspected.

    Returns:
        A wrapped callable that accepts the same arguments as ``module`` plus two
        keyword-only controls: ``act_ckpt_enable`` (default ``True``) and
        ``use_reentrant`` (default ``False``), the latter being forwarded to
        ``torch.utils.checkpoint.checkpoint``.

    Raises:
        ValueError: (from the wrapped callable, only when ``act_ckpt_enable=True``)
            if positional arguments are passed, or if a parameter without a
            default value was not supplied.

    Example:
        ```python
        wrapped_module = activation_ckpt_wrapper(my_module)
        output = wrapped_module(x=input_tensor, y=another_tensor, act_ckpt_enable=True)
        ```
    """

    @wraps(module)
    def act_ckpt_wrapper(
        *args, act_ckpt_enable: bool = True, use_reentrant: bool = False, **kwargs
    ):
        # Checkpointing disabled: behave exactly like the wrapped callable.
        if not act_ckpt_enable:
            return module(*args, **kwargs)

        if len(args) > 0:
            raise ValueError(
                "This wrapper expects keyword arguments only when `act_ckpt_enable=True`"
            )

        # Get the signature of the target function/module
        callable_fn = module.forward if isinstance(module, nn.Module) else module
        sig = inspect.signature(callable_fn)

        empty = inspect.Parameter.empty
        variadic_kinds = (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        )

        positional_args = []
        keyword_only_args = {}
        for p_name, param in sig.parameters.items():
            if param.kind is inspect.Parameter.KEYWORD_ONLY:
                # Keyword-only parameters cannot be passed positionally. Set the
                # supplied ones aside so the tensor scrub below leaves them alone;
                # omitted ones fall back to the target's own default.
                # NOTE: forwarding keywords relies on non-reentrant checkpointing
                # (the default here) — confirm before combining keyword-only
                # parameters with `use_reentrant=True`.
                if p_name in kwargs:
                    keyword_only_args[p_name] = kwargs.pop(p_name)
                elif param.default is empty:
                    raise ValueError(f"Missing keyword-only argument: {p_name}")
            elif p_name in kwargs:
                positional_args.append(kwargs.pop(p_name))
            elif param.default is not empty:
                # Set arg to default value if it's not in kwargs. Useful for
                # primitive types or args that default to None
                positional_args.append(param.default)
            elif param.kind not in variadic_kinds:
                # *args / **kwargs catch-alls are optional by nature; everything
                # else without a default must have been supplied.
                raise ValueError(f"Missing positional argument: {p_name}")

        # Scan remaining kwargs for torch.Tensor
        for key, value in kwargs.items():
            if isinstance(value, torch.Tensor):
                # Replace the tensor with a placeholder, assuming it's not required
                # by the module. If it is required, the module's signature should
                # be modified to accept it as a positional or keyword argument.
                kwargs[key] = "_REMOVED_BY_ACT_CKPT_WRAPPER_"

        # Keyword-only values were popped above, so there is no key collision.
        kwargs.update(keyword_only_args)

        return checkpoint.checkpoint(
            module, *positional_args, use_reentrant=use_reentrant, **kwargs
        )

    return act_ckpt_wrapper
|
||||
|
||||
|
||||
def clone_output_wrapper(f: Callable[..., T]) -> Callable[..., T]:
    """
    Return a version of ``f`` whose CUDA tensor outputs are cloned.

    Every ``torch.Tensor`` found in the (possibly nested) return value of ``f``
    is visited; tensors living on a CUDA device are replaced by a clone, while
    all other tensors and all non-tensor values are passed through untouched.

    Intended for use with torch.compile, to prevent errors related to in-place
    operations on the returned tensors.

    Args:
        f: The function whose CUDA tensor outputs should be cloned

    Returns:
        A wrapped function that clones any CUDA tensor outputs
    """

    def _clone_if_cuda(tensor):
        # Non-CUDA tensors are handed back as the very same object.
        if tensor.is_cuda:
            return tensor.clone()
        return tensor

    @wraps(f)
    def wrapped(*call_args, **call_kwargs):
        result = f(*call_args, **call_kwargs)
        return tree_map_only(torch.Tensor, _clone_if_cuda, result)

    return wrapped
|
||||
Reference in New Issue
Block a user