torch.contiguous
26 Jan 2022
Pytorch에서
narrow()
,view()
,expand()
,transpose()
와 같은 연산을 했을때, tensor에 대한 메모리 할당이 어떤식으로 이뤄지는지에 대한 포스트입니다.
Pytorch에서는 tensor의 content를 바꾸지 않은 채로, data의 organize 되는 방식만을 변경하는 다음과 같은 operation이 있다. narrow()
, view()
, expand()
, transpose()
예를들어, transpose()
를 사용하면 Pytorch는 새로운 layout으로 새로운 tensor를 만들지 않는다. 해당 tensor object의 meta data인 offset과 stride을 변경하여 new shape으로 만든다. transposed tensor는 original tensor와 같은 메모리 영역을 공유한다.
x = torch.randn(2,3)
y = torch.transpose(x,0,1)
x[0,0] =42
print(y[0,0])
# tensor(42.)
메모리 영역 할당할때, row-wise하게 할당하게 되는데, x는 contiguous하지만 y는 contiguous하지 않다. 아래 예시를 보면 알 수 있다.
x = torch.Tensor( [[1,2,3],[4,5,6]] )
print(x.stride())
# (3, 1)
print(x.is_contiguous())
# True
y= torch.transpose(x,0,1)
print(y.stride())
#( 1, 3)
print(y.is_contiguous())
# False
왜냐하면, y의 memory layout이 그냥 생성한 shape[3,2]의 tensor와 다르기 때문이다. bytes들이 하나의 memory block에 할당되어 있지만, element의 순서가 다르다. y와 같은 shape을 갖는 z를 만들어서 이를 확인해보면 더 쉽게 이해할 수 있다.
z = torch.tensor([[1,2],[3,4],[5,6]])
print(z.stride())
# (2, 1)
print(z.is_contiguous())
# True
contiguous()
를 호출하면, tensor의 copy를 만드는데, 해당 shape에 알맞는 element 순서를 갖게 memory에 할당받을 수 있다.
y= torch.transpose(x,0,1).contiguous()
print(y.stride())
# (2, 1)
print(y.is_contiguous())
# True
보통의 경우,이것에 대해 걱정하지 않아도 된다. PyTorch가 contiguous tensor를 필요로 할 경우, RuntimeError: input is not contiguous
가 뜨는 것을 제외하곤, 다른 모든 경우에 대해 정상 작동할것이라 가정해도 괜찮다.