18 lines
414 B
Python
18 lines
414 B
Python
import torch
|
|
import torch.nn as nn
|
|
from functools import cached_property
|
|
|
|
|
|
class DeviceManager:
    """Select the torch device this process should run on.

    Auto-detection preference order: the first CUDA GPU (``cuda:0``),
    then Apple Metal (``mps``), then ``cpu``.

    Attributes:
        device: the ``torch.device`` chosen when the manager was created.
    """

    def __init__(self, device=None):
        """Create the manager and resolve its device.

        Args:
            device: optional explicit device, anything accepted by
                ``torch.device`` (e.g. ``"cpu"``, ``"cuda:1"``, or a
                ``torch.device``). When ``None`` (the default) the best
                available device is auto-detected.
        """
        if device is not None:
            self.device = torch.device(device)
        else:
            self.device = self.__get_device()

    @staticmethod
    def __get_device():
        """Return ``cuda:0`` if CUDA is available, else ``mps`` if available, else ``cpu``."""
        if torch.cuda.is_available():
            return torch.device('cuda:0')
        # torch.backends.mps only exists in PyTorch >= 1.12; on older builds
        # accessing it raises AttributeError, so probe for it defensively.
        mps_backend = getattr(torch.backends, 'mps', None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device('mps')
        return torch.device('cpu')
|