PyTorch on m2 mac
· 5 min read
Apple Silicon(M1, M2, M3 등)을 사용하는 macOS 환경에서는 PyTorch의 cuda 대신 Metal Performance Shaders(MPS) 백엔드를 사용해야 합니다. 하지만 이때 주의해야 할 점 중 하나가 바로 torch_dtype 설정입니다.
✅ MPS란?
- MPS는 Apple에서 제공하는 GPU 가속 프레임워크입니다.
- PyTorch는 1.12 버전부터 MPS 백엔드를 지원합니다.
- Apple Silicon에서는
torch.cuda대신torch.mps를 사 용해야 합니다.
import torch
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
Note 실행 환경의 조건에 따라 다음과 같이 torch device를 검사 및 설정할 수 있습니다.
def get_device():
device = torch.device("cpu")
if torch.cuda.is_available():
print("cuda is available")
device = torch.device("cuda"))
else:
if torch.backends.mps.is_available():
print("mps is available")
device = torch.device("mps")
else:
print("cuda and mps are not available, so cpu will be used.")
return device
✅ MPS에서 권장되는 torch_dtype
| dtype | 지원 여부 | 설명 |
|---|---|---|
torch.float32 | ✅ 지원 | 가장 안정적이며 기본적으로 사용됨 |
torch.float16 | ⚠️ 제한적 | 일부 연산에서 오류 발생 가능 |
torch.bfloat16 | ❌ 미지원 | 현재 MPS에서는 사용 불가 |
torch.int32/int64 | ⚠️ 제한적 | 연산 종류에 따라 오류 발생 가능 |
