Differential Revision: D90237984 fbshipit-source-id: 526fd760f303bf31be4f743bdcd77760496de0de
80 lines
2.5 KiB
Python
80 lines
2.5 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
|
|
|
# pyre-unsafe
|
|
|
|
from collections import defaultdict
from dataclasses import FrozenInstanceError, fields, is_dataclass
from typing import Any, Mapping, Protocol, runtime_checkable

import torch
|
|
|
|
|
|
def _is_named_tuple(x) -> bool:
    """Report whether ``x`` looks like an instance of a ``namedtuple`` class.

    The test is structural: ``x`` must be a ``tuple`` that also exposes both
    ``_asdict`` and ``_fields`` attributes.
    """
    if not isinstance(x, tuple):
        return False
    return all(hasattr(x, attr_name) for attr_name in ("_asdict", "_fields"))
|
|
|
|
|
|
@runtime_checkable
class _CopyableData(Protocol):
    """Structural type for objects that can be relocated via ``.to(device, ...)``.

    Because the protocol is ``runtime_checkable``, ``isinstance`` only checks
    that a ``to`` attribute is present; the signature itself is not validated.
    """

    def to(
        self,
        device: torch.device,
        *args: Any,
        **kwargs: Any,
    ):
        """Move or copy this object's data onto ``device``."""
        ...
|
|
|
|
|
|
def copy_data_to_device(data, device: torch.device, *args: Any, **kwargs: Any):
    """Function that recursively copies data to a torch.device.

    Containers are traversed and rebuilt with their original type: named
    tuples, lists and tuples, ``defaultdict`` (keeping its
    ``default_factory``), other ``Mapping`` types, and dataclass instances
    (including fields declared with ``init=False``, also on frozen
    dataclasses). Any remaining object that exposes a ``to`` attribute has
    ``to(device, *args, **kwargs)`` called on it; everything else is returned
    unchanged.

    Args:
        data: The data to copy to device
        device: The device to which the data should be copied
        args: positional arguments that will be passed to the `to` call
        kwargs: keyword arguments that will be passed to the `to` call

    Returns:
        The data on the correct device
    """
    if _is_named_tuple(data):
        # Must precede the generic tuple branch: named tuples are rebuilt from
        # keyword arguments rather than from a single iterable.
        return type(data)(
            **copy_data_to_device(data._asdict(), device, *args, **kwargs)
        )
    elif isinstance(data, (list, tuple)):
        return type(data)(copy_data_to_device(e, device, *args, **kwargs) for e in data)
    elif isinstance(data, defaultdict):
        # Must precede the ``Mapping`` branch so the ``default_factory`` survives.
        return type(data)(
            data.default_factory,
            {
                k: copy_data_to_device(v, device, *args, **kwargs)
                for k, v in data.items()
            },
        )
    elif isinstance(data, Mapping):
        return type(data)(
            {
                k: copy_data_to_device(v, device, *args, **kwargs)
                for k, v in data.items()
            }
        )
    elif is_dataclass(data) and not isinstance(data, type):
        # ``is_dataclass`` is also true for dataclass *classes*; only instances
        # are copied here.
        data_fields = fields(data)
        new_data_class = type(data)(
            **{
                field.name: copy_data_to_device(
                    getattr(data, field.name), device, *args, **kwargs
                )
                for field in data_fields
                if field.init
            }
        )
        # Fields declared with ``init=False`` cannot be passed to the
        # constructor, so they are assigned after construction.
        for field in data_fields:
            if field.init:
                continue
            value = copy_data_to_device(
                getattr(data, field.name), device, *args, **kwargs
            )
            try:
                setattr(new_data_class, field.name, value)
            except FrozenInstanceError:
                # Frozen dataclasses reject normal assignment; bypass their
                # ``__setattr__`` the same way the generated ``__init__`` does.
                object.__setattr__(new_data_class, field.name, value)
        return new_data_class
    elif isinstance(data, _CopyableData):
        return data.to(device, *args, **kwargs)
    return data
|