Differential Revision: D90237984 fbshipit-source-id: 526fd760f303bf31be4f743bdcd77760496de0de
102 lines
3.2 KiB
Python
102 lines
3.2 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
|
|
|
# pyre-unsafe
|
|
|
|
import torch
|
|
|
|
|
|
def recursive_fn_factory(fn, *, trivial_types=(bool, int)):
    """
    Build a function that applies ``fn`` to every tensor in a nested structure.

    The returned function walks dicts, lists and tuples recursively, applies
    ``fn`` to each ``torch.Tensor`` leaf, and rebuilds the container as a
    plain ``dict`` / ``list`` / ``tuple``. ``None`` and instances of
    ``trivial_types`` are returned unchanged. Any other leaf type raises
    ``TypeError``.

    Args:
        fn: Callable applied to each tensor leaf.
        trivial_types: A type or iterable of types whose instances are passed
            through unchanged. Defaults to ``(bool, int)``. Widen it
            explicitly per call site (e.g. ``(bool, int, float)``) instead of
            relaxing the check for everyone.

    Returns:
        A single-argument function performing the recursive mapping.

    Raises:
        TypeError: (from the returned function) on a leaf that is not a
            tensor, ``None``, or an instance of ``trivial_types``.
    """
    # Normalize once, outside the per-leaf path, so isinstance() receives a
    # tuple of types.
    if isinstance(trivial_types, type):
        trivial_types = (trivial_types,)
    else:
        trivial_types = tuple(trivial_types)

    def recursive_fn(b):
        if isinstance(b, dict):
            return {k: recursive_fn(b[k]) for k in b}
        if isinstance(b, list):
            return [recursive_fn(t) for t in b]
        if isinstance(b, tuple):
            return tuple(recursive_fn(t) for t in b)
        if isinstance(b, torch.Tensor):
            return fn(b)
        # Yes, writing out an explicit white list of trivial types is
        # tedious, but so are bugs that come from not applying fn when it
        # was expected to have been applied.
        if b is None:
            return b
        if isinstance(b, trivial_types):
            return b
        raise TypeError(f"Unexpected type {type(b)}")

    return recursive_fn
|
|
|
|
|
|
# Make every tensor inside a nested dict/list/tuple structure contiguous.
recursive_contiguous = recursive_fn_factory(lambda tensor: tensor.contiguous())
# Replace every tensor inside a nested dict/list/tuple structure with a copy.
recursive_clone = recursive_fn_factory(torch.clone)
|
|
|
|
|
|
def compile_wrapper(
    fn, *, mode="max-autotune", fullgraph=True, dynamic=False, name=None
):
    """
    Compile ``fn`` with ``torch.compile`` and return a wrapper around it.

    Each call of the returned function:
      1. runs inside a ``torch.autograd.profiler.record_function`` range
         labelled ``name``, or ``"compiled {fn}"`` when ``name`` is None;
      2. passes every positional and keyword argument through
         ``recursive_contiguous``;
      3. invokes the compiled function;
      4. returns the result passed through ``recursive_clone``.

    Args:
        fn: Callable to compile.
        mode: Forwarded to ``torch.compile``.
        fullgraph: Forwarded to ``torch.compile``.
        dynamic: Forwarded to ``torch.compile``.
        name: Optional profiler label for each call.

    Returns:
        The wrapping callable.
    """
    compiled = torch.compile(fn, mode=mode, fullgraph=fullgraph, dynamic=dynamic)

    def compiled_fn_wrapper(*args, **kwargs):
        if name is None:
            label = f"compiled {fn}"
        else:
            label = name
        with torch.autograd.profiler.record_function(label):
            # Positional args are made contiguous before keyword args.
            output = compiled(
                *recursive_contiguous(args), **recursive_contiguous(kwargs)
            )
            # NOTE(review): presumably cloned so callers never alias memory
            # the compiled graph may reuse on later calls -- confirm.
            return recursive_clone(output)

    return compiled_fn_wrapper
|
|
|
|
|
|
def shape_logging_wrapper(fn, keep_kwargs, enable_logging=False):
    """
    Wraps a function and prints the shapes of all tensor inputs.

    Only prints when a new combination of shapes is seen. Seen combinations
    are remembered for the lifetime of the returned wrapper. The membership
    check and the insert are separate, unlocked steps, so concurrent first
    calls with the same shapes may each print once.

    Args:
        fn: Function to wrap. Its ``__qualname__`` (or ``repr(fn)`` if it has
            none) is used in the log line.
        keep_kwargs: Names of keyword arguments whose shapes are included in
            the shape signature. Only tensor- or list-valued keyword arguments
            are considered. An empty collection (or None) includes none.
        enable_logging: Initial value of the logging flag.

    Returns:
        A wrapper that forwards all arguments to ``fn`` and returns its
        result. Logging can be toggled at runtime either with
        ``wrapper.set_logging(enabled)`` or by assigning
        ``wrapper.enable_logging`` directly.
    """
    seen_shapes = set()

    def get_shape(obj):
        # Tensors contribute their shape, containers are summarized
        # recursively, and anything else is summarized by its type name.
        if isinstance(obj, torch.Tensor):
            return obj.shape
        elif isinstance(obj, (list, tuple)):
            if not obj:
                # Empty containers previously crashed on obj[0].
                return ()
            if len(obj) > 1:
                return tuple(get_shape(x) for x in obj)
            # Single-element containers are unwrapped to their element.
            return get_shape(obj[0])
        elif isinstance(obj, dict):
            return tuple(sorted((k, get_shape(v)) for k, v in obj.items()))
        else:
            return type(obj).__name__

    def wrapper(*args, **kwargs):
        shapes = tuple(get_shape(arg) for arg in args) + tuple(
            (k, get_shape(v))
            for k, v in kwargs.items()
            if isinstance(v, (torch.Tensor, list))
            and keep_kwargs
            and k in keep_kwargs
        )
        if shapes not in seen_shapes:
            seen_shapes.add(shapes)
            # Read the flag from the function attribute at call time so both
            # set_logging() and direct attribute assignment take effect.
            if wrapper.enable_logging:
                fn_name = getattr(fn, "__qualname__", repr(fn))
                print(f"[ShapeLogger] New input shapes for {fn_name}: {shapes}")
        return fn(*args, **kwargs)

    # Allow toggling the flag at runtime.
    wrapper.enable_logging = enable_logging

    def set_logging(enabled=False):
        wrapper.enable_logging = enabled

    wrapper.set_logging = set_logging
    return wrapper
|