import torch
import torch.nn as nn
from functools import cached_property


class DeviceManager:
    """Pick the preferred torch device once, when the instance is created.

    Preference order is CUDA (the first GPU, ``cuda:0``), then the MPS
    backend (``mps``), then the CPU.

    Attributes:
        device: the ``torch.device`` selected at construction time.
    """

    def __init__(self):
        self.device = self.__get_device()

    @staticmethod
    def __get_device():
        """Return the preferred ``torch.device`` for the current machine.

        Returns:
            ``torch.device('cuda:0')`` if CUDA is available, otherwise
            ``torch.device('mps')`` if the MPS backend is available,
            otherwise ``torch.device('cpu')``.
        """
        if torch.cuda.is_available():
            return torch.device('cuda:0')
        # ``torch.backends.mps`` is absent in older PyTorch releases (added in
        # 1.12); treat a missing backend as "not available" rather than
        # letting construction fail with AttributeError.
        mps_backend = getattr(torch.backends, 'mps', None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device('mps')
        return torch.device('cpu')