sam3_local/sam3/perflib/compile.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
import torch


def recursive_fn_factory(fn):
    """Return a function that applies ``fn`` to every ``torch.Tensor`` in a
    nested structure of dicts, lists, and tuples, preserving the structure.
    Non-tensor leaves must be None, bool, or int; anything else raises
    ``TypeError``."""

    def recursive_fn(b):
        if isinstance(b, dict):
            return {k: recursive_fn(b[k]) for k in b}
        if isinstance(b, list):
            return [recursive_fn(t) for t in b]
        if isinstance(b, tuple):
            return tuple(recursive_fn(t) for t in b)
        if isinstance(b, torch.Tensor):
            return fn(b)
        # Yes, writing out an explicit whitelist of trivial types is
        # tedious, but so are the bugs that come from not applying fn
        # where it was expected to be applied.
        if b is None:
            return b
        trivial_types = [bool, int]
        for t in trivial_types:
            if isinstance(b, t):
                return b
        raise TypeError(f"Unexpected type {type(b)}")

    return recursive_fn


# Common instantiations: force contiguous layout / deep-copy every tensor.
recursive_contiguous = recursive_fn_factory(lambda x: x.contiguous())
recursive_clone = recursive_fn_factory(torch.clone)
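

# A minimal usage sketch (illustrative addition, not part of the original
# module): `_demo_recursive_helpers` and the tensor shapes are assumptions.
def _demo_recursive_helpers():
    batch = {"img": torch.randn(4, 3).t(), "flags": [True, 2], "extra": None}
    cont = recursive_contiguous(batch)  # .t() made "img" non-contiguous
    assert cont["img"].is_contiguous()
    copy = recursive_clone(cont)  # fresh storage for every tensor
    assert copy["img"] is not cont["img"]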


def compile_wrapper(
    fn, *, mode="max-autotune", fullgraph=True, dynamic=False, name=None
):
    """Compile ``fn`` with ``torch.compile`` and return a wrapper that makes
    all tensor inputs contiguous, clones all tensor outputs, and records the
    call under ``name`` (or a default label) in the autograd profiler."""
    compiled_fn = torch.compile(fn, mode=mode, fullgraph=fullgraph, dynamic=dynamic)

    def compiled_fn_wrapper(*args, **kwargs):
        with torch.autograd.profiler.record_function(
            f"compiled {fn}" if name is None else name
        ):
            cont_args = recursive_contiguous(args)
            cont_kwargs = recursive_contiguous(kwargs)
            result = compiled_fn(*cont_args, **cont_kwargs)
            cloned_result = recursive_clone(result)
            return cloned_result

    return compiled_fn_wrapper
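

# A minimal usage sketch (illustrative addition; the toy function, the
# shapes, and mode=None — i.e. torch.compile's default mode, chosen here to
# keep compilation fast — are assumptions):
def _demo_compile_wrapper():
    def add_mul(a, b):
        return {"sum": a + b, "prod": a * b}

    fast_add_mul = compile_wrapper(add_mul, mode=None, name="demo_add_mul")
    out = fast_add_mul(torch.ones(8), torch.full((8,), 2.0))
    assert torch.equal(out["sum"], torch.full((8,), 3.0))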


def shape_logging_wrapper(fn, keep_kwargs, enable_logging=False):
    """
    Wraps a function and prints the shapes of all tensor inputs.
    Only prints when a new combination of shapes is seen.
    Thread-safe.

    Args:
        fn: Function to wrap
        keep_kwargs: Collection of keyword-argument names whose shapes
            should be included in the logged signature
        enable_logging: Boolean flag to enable/disable logging
    """
    seen_shapes = set()

    def get_shape(obj):
        if isinstance(obj, torch.Tensor):
            return obj.shape
        elif isinstance(obj, (list, tuple)):
            # Unwrap single-element containers; an empty container maps to
            # an empty tuple of shapes.
            if len(obj) == 1:
                return get_shape(obj[0])
            return tuple(get_shape(x) for x in obj)
        elif isinstance(obj, dict):
            return tuple(sorted((k, get_shape(v)) for k, v in obj.items()))
        else:
            return type(obj).__name__

    def wrapper(*args, **kwargs):
        shapes = tuple(get_shape(arg) for arg in args) + tuple(
            (k, get_shape(v))
            for k, v in kwargs.items()
            if isinstance(v, (torch.Tensor, list)) and k in keep_kwargs
        )
        if shapes not in seen_shapes:
            seen_shapes.add(shapes)
            if enable_logging:
                print(
                    f"[ShapeLogger] New input shapes for {fn.__qualname__}: {shapes}"
                )
        return fn(*args, **kwargs)

    # Allow toggling the flag at runtime
    wrapper.enable_logging = enable_logging

    def set_logging(enabled=False):
        nonlocal enable_logging
        enable_logging = enabled
        wrapper.enable_logging = enable_logging

    wrapper.set_logging = set_logging
    return wrapper
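

# A minimal usage sketch (illustrative addition; the toy `matmul` function
# and the shapes below are assumptions):
def _demo_shape_logging():
    def matmul(a, b):
        return a @ b

    logged = shape_logging_wrapper(matmul, keep_kwargs=(), enable_logging=True)
    logged(torch.randn(2, 3), torch.randn(3, 4))  # first call: shapes printed
    logged(torch.randn(2, 3), torch.randn(3, 4))  # same shapes: no print
    logged.set_logging(False)  # silence future prints at runtime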