Source code for graphlow.base.tensor_property
import torch
[docs]
class GraphlowTensorProperty:
DEFAULT_FLOAT_TYPE = torch.float32
def __init__(
self,
device: torch.device | int = -1,
dtype: torch.dtype | type | None = None,
):
self.device = device
self.dtype: torch.dtype = dtype
return
@property
def device(self) -> torch.device:
return self._device
@device.setter
def device(self, device: torch.device | int):
if isinstance(device, int):
if device < 0:
self._device = torch.device("cpu")
return
self._device = torch.device(device)
return
self._device = device
return
@property
def dtype(self) -> torch.dtype:
return self._dtype
@dtype.setter
def dtype(self, dtype: torch.dtype | type | None):
if dtype is None:
self._dtype = self.DEFAULT_FLOAT_TYPE
return
if isinstance(dtype, torch.dtype):
self._dtype = dtype
return
str_type = dtype.__name__
if str_type == "float":
self._dtype = torch.float
return
elif str_type == "int":
self._dtype = torch.int
return
elif str_type == "bool":
self._dtype = torch.bool
return
else:
raise ValueError(f"Unexpected dtype: {dtype}")