akaSonny
Pytorch 기초부터 시작! 본문
5월부터 pytorch를 기초부터 배우게 되었다.
강의를 듣고 후에 복습하고 정리해야 나중에 남게 될 것 같아서 블로그에 정리하기로!
1. torch.ones(), torch.zeros() : 기본으로 실수형으로 나온다
import torch
import numpy as np
x = torch.ones(3, 2)
print(x)
# tensor([[1., 1.],
# [1., 1.],
# [1., 1.]])
x = torch.zeros(3, 2)
print(x)
# tensor([[0, 0.],
# [0., 0.],
# [0., 0.]])
2. seed 고정 : random 변수를 사용할 때 seed를 고정하면 항상 똑같은 변수가 나오게 된다. 딥러닝 모델 훈련 후 reconstruct할 때 중요!
torch.manual_seed(2)
x = torch.rand(3, 2) # 0 ~ 1
y = troch.randn(3,3) # -1 ~ 1
3. torch.view(shape) : 텐서의 shape 변경하는 함수.
x = torch.tensor([[1, 2],
[3, 4],
[5, 6]])
y = x.view(2, 3)
print(y)
# tensor([[1, 2, 3],
# [4, 5, 6]])
y = x.view(6, -1)
print(y)
# tensor([[1],
# [2],
# [3],
# [4],
# [5],
# [6]])
y = x.view(2, -1)
print(y)
# tensor([[1, 2, 3],
# [4, 5, 6]])
* 이 때 shape에 (6, -1), (2, -1) 과 같이 -1이 입력된 경우, 그 부분은 알아서 계산돼서 나오게 된다
위 예제에서는 (6, -1) --> (6, 1), (2, -1) --> (2, 3) 으로 계산되어 나오는 것을 확인할 수 있다.
주의할 점은 계산이 되지 않을 경우에는 에러가 발생한다. ex) (5, -1) 같은 경우
4. numpy array to tensor : 모델 훈련에는 tensor만 계산이 되므로 tensor로 바꿔주는 것이 필요.
a = np.random.randn(5)
a_pt = torch.from_numpy(a)
print(type(a))
# <class 'numpy.ndarray'>
print(type(a_pt))
# <class 'torch.Tensor'>
5. gpu 확인 및 사용
print(torch.cuda.device_count()) # the number of gpu
# 1
print(torch.cuda.get_device_name(0)) # name of gpu
# Tesla K80
device = torch.device('cuda:0')
a = torch.ones(3,2, device=device)
* 이 때 주의할 점! a는 cuda위에 올라왔기 때문에 numpy로 변환될 수 없다. numpy는 cpu를 이용해서 계산하기 때문.
따라서, numpy array로 바꾸고 싶으면 cpu로 내리고 변환해야 함.
a = a.numpy()
# TypeError: can't convert cuda:0 device type tensor to numpy.
# Use Tensor.cpu() to copy the tensor to host memory first.
6. grad 계산하기
x = torch.ones([3,2], requires_grad=True) # save all of computation
y = x + 5
z = y * y + 1
t = torch.sum(z)
# back propagation
t.backward()
print(x.grad)
# tensor([[12., 12.],
# [12., 12.],
# [12., 12.]])
이 부분이 이해가 좀 안 갔는데..
일단 마지막 x.grad는 dz/dx를 구하는 것이고, chain rule에 의해 dz/dx = dz/dy * dy/dx 로 구할 수 있다.
dz/dy = 2y , dy/dx = 1 이므로 dz/dx= 2y 이고 y=6 이었으므로 최종적으로 12가 나오게 된다.
이렇게 torch tensor 기본은 끝!
'Study (Programming) > Pytorch' 카테고리의 다른 글
[Pytorch] checkpoint 불러와서 모델 이어서 학습하기 (0) | 2024.02.19 |
---|---|
완전 간단한 Linear model 만들기 (0) | 2022.05.20 |