네이버 부스트캠프 🔗/🎶추가 학습 정리

[Pytorch] tensor의 연산과 인덱싱

Dobby98 2023. 3. 15. 19:23

✅목차

  1. Tensor 연산
  2. Tensor 인덱싱

✅ Pytorch의 Tensor

 

파이토치를 다루다 보면 마주치는 자료형이 있다

바로 Tensor이다 

 

파이토치의 텐서는 Numpy의 배열 또는 행렬과 비슷한 모양을 갖고 있다 - 사실 같은 모양

그러나 유사하지만 파이토치의 특수한 자료형이다

 

물론 Numpy와 파이썬의 리스트를 이용해서 텐서를 만들 수 있다

바로 torch.tensor()함수를 이용하는 것이다

 

이것에 대한 내용 주간학습정리에서 간단하게 다루어 보았다

 

[네이버 부스트 캠프 AI Tech] Week 2 - Day 1 수업

본 글은 네이버 부스트 캠프 AI Tech 기간동안 개인적으로 배운 내용들을 주단위로 정리한 글입니다 본 글의 내용은 새롭게 알게 된 내용을 중심으로 정리하였고 복습 중요도를 선정해서 정리하

eumgill98.tistory.com

 


⭐tensor 연산

# + : torch.add(a, b) - a와 b를 더한다 
# -  : torch.sub(a, b) - a에서 b를 뺀다
# *  : torch.mul(a, b) - a와 b를 곱한다 
# /  : torch.div(a, b) - a를 b로 나눈다

⭐tensor 인덱싱

 

numpy와 비슷하게 인덱싱할 수 있다

#인덱싱
import torch
basic_tensor = torch.Tensor([[1,2],
                            [3,4]])

print(basic_tensor[1])
>>> tensor([3., 4.])

print(basic_tensor[0])
>>> tensor([1., 2.])

print(basic_tensor[1][1])
>>> tensor(4.)

print(basic_tensor[0][0])
>>> tensor(1.)

 

 

 torch.index_select를 활용한 인덱싱

x = torch.randn(3, 4)
>>> tensor([[ 0.1427,  0.0231, -0.5414, -1.0009], 
          [-0.4664,  0.2647, -0.1228, -1.1068],   
          [-1.1734, -0.6571,  0.7230, -0.6004]])  

indices = torch.tensor([0, 2]) # 인덱싱할 인덱스 지정

torch.index_select(x, 0, indices) #0번째 차원기준 0, 2 인덱싱
>>> tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
           [-1.1734, -0.6571,  0.7230, -0.6004]])

torch.index_select(x, 1, indices) #1번째 차원 기준 0, 2 인덱싱
>>> tensor([[ 0.1427, -0.5414],
           [-0.4664, -0.1228],
           [-1.1734,  0.7230]])
           
#차원이 헷갈린다면
x.shape
>>> torch.Size([3, 4]) #([0, 1])

#차원인 경우
A = torch.ones(2,3,4)
>>> tensor([[[1., 1., 1., 1.],
         	[1., 1., 1., 1.],
         	[1., 1., 1., 1.]],

        	[[1., 1., 1., 1.],
         	[1., 1., 1., 1.],
         	[1., 1., 1., 1.]]])
A.shape 
>>> torch.Size([2, 3, 4]) #([0, 1, 2])

 

torch.gather()를 활용한 대각선 인덱싱

 

작동 원리 

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

예시 

import torch

#2D Tensor - using gather
A = torch.Tensor([[1, 2],
                  [3, 4]])
                  
print(torch.gather(A, 0, torch.tensor([[0,1]])))

>>> tensor([[1., 4.]])

조금더 복잡한 사용법 - in 3D tensor

import torch #3차원 에서 인덱싱 
A = torch.Tensor([[[1, 2],
                   [3, 4]],
                  [[5, 6],
                   [7, 8]]])

# print(A.shape)
# 2 2 2
# target = [[1,4],[5,8]]


output = torch.gather(A, 1, torch.tensor([[[0, 1]],[[0, 1]]])).squeeze(1)
#or
output = torch.gather(A, 1, torch.tensor([[[0, 1]],[[0, 1]]])).view(2,2)

>>> tensor([[1,4], [5,8]])