Initial commit
fbshipit-source-id: da6be2f26e3a1202f4bffde8cb980e2dcb851294
This commit is contained in:
77
sam3/model/utils/misc.py
Normal file
77
sam3/model/utils/misc.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import fields, is_dataclass
|
||||
from typing import Any, Mapping, Protocol, runtime_checkable
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def _is_named_tuple(x) -> bool:
|
||||
return isinstance(x, tuple) and hasattr(x, "_asdict") and hasattr(x, "_fields")
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class _CopyableData(Protocol):
|
||||
def to(self, device: torch.device, *args: Any, **kwargs: Any):
|
||||
"""Copy data to the specified device"""
|
||||
...
|
||||
|
||||
|
||||
def copy_data_to_device(data, device: torch.device, *args: Any, **kwargs: Any):
|
||||
"""Function that recursively copies data to a torch.device.
|
||||
|
||||
Args:
|
||||
data: The data to copy to device
|
||||
device: The device to which the data should be copied
|
||||
args: positional arguments that will be passed to the `to` call
|
||||
kwargs: keyword arguments that will be passed to the `to` call
|
||||
|
||||
Returns:
|
||||
The data on the correct device
|
||||
"""
|
||||
|
||||
if _is_named_tuple(data):
|
||||
return type(data)(
|
||||
**copy_data_to_device(data._asdict(), device, *args, **kwargs)
|
||||
)
|
||||
elif isinstance(data, (list, tuple)):
|
||||
return type(data)(copy_data_to_device(e, device, *args, **kwargs) for e in data)
|
||||
elif isinstance(data, defaultdict):
|
||||
return type(data)(
|
||||
data.default_factory,
|
||||
{
|
||||
k: copy_data_to_device(v, device, *args, **kwargs)
|
||||
for k, v in data.items()
|
||||
},
|
||||
)
|
||||
elif isinstance(data, Mapping):
|
||||
return type(data)(
|
||||
{
|
||||
k: copy_data_to_device(v, device, *args, **kwargs)
|
||||
for k, v in data.items()
|
||||
}
|
||||
)
|
||||
elif is_dataclass(data) and not isinstance(data, type):
|
||||
new_data_class = type(data)(
|
||||
**{
|
||||
field.name: copy_data_to_device(
|
||||
getattr(data, field.name), device, *args, **kwargs
|
||||
)
|
||||
for field in fields(data)
|
||||
if field.init
|
||||
}
|
||||
)
|
||||
for field in fields(data):
|
||||
if not field.init:
|
||||
setattr(
|
||||
new_data_class,
|
||||
field.name,
|
||||
copy_data_to_device(
|
||||
getattr(data, field.name), device, *args, **kwargs
|
||||
),
|
||||
)
|
||||
return new_data_class
|
||||
elif isinstance(data, _CopyableData):
|
||||
return data.to(device, *args, **kwargs)
|
||||
return data
|
||||
Reference in New Issue
Block a user