CNN ์ฐ์ฐ์๋ ํ
์ ์ฐจ์์ ํ์ค ์์๋ฅผ ํ์ฉํ๊ณ ์๋ฏธ๋ก ์ ์๋ฏธ๋ฅผ ํ ๋นํฉ๋๋ค. ์ค๋๋ PyTorch์ 2D ์ฌ๋ก์ ๊ฒฝ์ฐ torch.nn.Conv2d์ ๋ํ ์
๋ ฅ์ NCHW ์์์ 4d ํ
์์ฌ์ผ ํฉ๋๋ค.
์ฑ๋ฅ์์ ์ด์ ๋ก ํน์ ์์
์์ ์ก์ธ์คํ๋ ๋ฉ๋ชจ๋ฆฌ๊ฐ ์ฐ์์ ์ผ๋ก ๋ฐฐ์น๋๊ณ ์ง์ญ์ฑ์ด ๋ ์ ํ์ฉ๋๋๋ก ์ฐจ์์ ๋ค๋ฅด๊ฒ ์ฌ์ ๋ ฌํ๋ ๊ฒ์ด ์ข
์ข
์ ๋ฆฌํฉ๋๋ค. ๊ฐ์ฅ ์ผ๋ฐ์ ์ธ ์ต์
์ ์น์๋ฅผ ๋์ผ๋ก ์ด๋ํ๋ ๊ฒ์
๋๋ค - NHWC. ํ ์ฐจ์์ ๋ธ๋ก์ผ๋ก ๋ฐ๋ํ์์ผ๋ก ๋ฐฐ์ดํ๋ ํจ์ฌ ๋ ๋ณต์กํ ๋ฉ๋ชจ๋ฆฌ ํ์์ด ์์ ์ ์์ต๋๋ค.
์ด๋ฅผ ํ์ฉํ๋ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ์๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
๋ฌธ์ ๋ ์ฐจ์ ์์๋ฅผ ๋ณํํ๋ ๊ฒ ์์ฒด๊ฐ ๋น์ฉ์ด ๋ง์ด ๋ค๊ธฐ ๋๋ฌธ์ ์ฌ๋ฌ CNN ์์
์ด ์ฐ์์ผ๋ก ์ํ๋๋ ๊ฒฝ์ฐ(์: conv(relu(conv)))
) ๋ค๋ฅธ ๋ฉ๋ชจ๋ฆฌ ํ์์ผ๋ก ํ ๋ฒ ๋ณํ ํ๊ณ ์์
์ ์ํํ๊ณ ์ฌ์ ๋ ฌํ๋ ๊ฒ์ด ์ข์ต๋๋ค. ๋ค.
๋ฐ๋ผ์ PyTorch๊ฐ ๋ค์ํ ์ฐจ์ ์์๋ฅผ ์ธ์ํ๊ณ Eager ๋ชจ๋์ JIT ๋ชจ๋ ๋ชจ๋์์ ์์ ๊ฐ์ ๋ค๋ฅธ ๋ฉ๋ชจ๋ฆฌ ํ์์ ๊ฐ์ง ํ ์๋ฅผ ์ ๋ฌํ ์ ์๋๋ก
์ฐ๋ฆฌ๋ ๋ค์์ ํํํ ์ ์๋ API๋ฅผ ๊ตฌ์ถํ๊ธฐ ์ํด ๋ ธ๋ ฅํฉ๋๋ค.
์ฉ์ด : ์์ ๋ฌธ์ ๋ ์ข ์ข "layout"(mxnet), "data_format"(tf), "image_format"(keras), "order"(caffe2)๋ผ๊ณ ํฉ๋๋ค. ์ฐ๋ฆฌ๋ PyTorch์์ "memory format" ๋๋ "memory_format"์ด๋ผ๋ ์ด๋ฆ์ ์ฌ์ฉํ ๊ฒ์ ์ ์ํฉ๋๋ค. "๋ ์ด์์"์ด๋ผ๋ ์ด๋ฆ์ ๋ถํํ๋ PyTorch์์ 'strided' ๋ 'sparse_coo' ๊ฐ์ ์ฌ์ฉํ๋ฏ๋ก ์ด๋ฆ ์ง์ ์ต์ ์ ์ฌ์ฉํ ์ ์์ต๋๋ค.
๋ค์ ์ฐ์ฐ์๋ ์ต์ํ ๋ฉ๋ชจ๋ฆฌ ํ์์ ์ธ์ํด์ผ ํฉ๋๋ค. ์ ํํ ๊ฒฐ๊ณผ๋ฅผ ์์ฑํ๋ ๊ฒ ์ธ์๋ ๊ธฐ๋ณธ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์์ ์ต์์ ์ฑ๋ฅ ์ ์ ๊ณตํ๊ณ ๋ช ์์ ์ผ๋ก ์ง์ ๋ ์ฌ์ฉ์ ์๋๋ฅผ ์ ํํ๊ธฐ ์ํด ์ถ๋ ฅ์ ๋ฉ๋ชจ๋ฆฌ ํ์ ์
PyTorch์์ ๋ฉ๋ชจ๋ฆฌ ํ์์ ๊ฐ๋ ์ ์:
torch.memory_format.channels_first
์ ๊ฐ์ ์์์
๋๋ค. ์ ํ์ด ์ง์ ๋์ง ์์์ผ๋ฉฐ ์์์ ๋น๊ต ๊ฐ๋ฅํ ๊ฐ์ฒด๊ฐ ๋ ์ ์์ต๋๋ค(enum์ผ๋ก ์์ํ ๊ฐ๋ฅ์ฑ์ด ๋์ง๋ง ๋์ค์๋ ๋ช
๋ช
๋ ํ
์์ ๊ฐ๋
๊ณผ ์ํธ ์ด์ฉ๋๋ ๋ค๋ฅธ ๊ฐ์ฒด๊ฐ ๋ ์ ์์)torch.channels_first
์ง์ ์ฌ์ฉchannels_first
๋ฐ channels_last
(๋ ์ ์ ์์ ์์ ํ์ฉ).Tensor์ ๋ค์ ๋ฉ์๋๋ฅผ ์ถ๊ฐํฉ๋๋ค.
x.is_contiguous(torch.memory_format.channels_first)
x.to(memory_format=torch.memory_format.channels_first)
์ฐธ๊ณ : ์ง๊ธ์ x.get_memory_format()
๊ธฐ๋ฅ์ด ์๊ณ ๋ช
์์ ๊ฒ์ฌ๋ง ๊ฐ๋ฅํฉ๋๋ค. ๊ฐ๋ฅํ ๊ตฌํ ๋ฒ์๊ฐ ๋ ๋์ต๋๋ค. ์ฐ๋ฆฌ๋ ๊ทธ๊ฒ์ ์ถ๊ฐํ๊ณ ์ถ์ ์๋ ์์ต๋๋ค.
ํ
์ ์๋ฏธ๋ก ์ ๋ ์ด์์์ ํญ์ ๋์ผํ๊ฒ ์ ์ง๋ฉ๋๋ค - NCHW! x.size()
ํญ์ (n,c,h,w)
๋ฐํํฉ๋๋ค.
์์ ์ ๋ฉ๋ชจ๋ฆฌ ํ์ ๋์์ ์ ์งํฉ๋๋ค.
๋ฉ๋ชจ๋ฆฌ ํ์์ ์ง๋ ฌํ/์ญ์ง๋ ฌํ๋ฅผ ํตํด ๋ณด์กด๋๋ ํ ์์ ์์ฑ์ ๋๋ค(ํ ์๊ฐ ๋งค๊ฐ๋ณ์์ธ ๊ฒฝ์ฐ).
์ค๋๋ PyTorch์ Tensor์๋ ๋
ผ๋ฆฌ์ ํ
์๊ฐ ๋ฉ๋ชจ๋ฆฌ์ ๋ฐฐ์น๋๋ ๋ฐฉ์์ ์ง์ ํ๋ strides ๊ฐ๋
์ด ์์ต๋๋ค. ํนํ ๊ฐ ํ
์๋ sizes
์ ๊ฐ์ ๊ธธ์ด์ strides
๋ฒกํฐ๋ฅผ ๊ฐ์ง๋๋ค. (i1, i2, .., ik)
๋
ผ๋ฆฌ์ ์ธ๋ฑ์ฑ์์ ์์๋ฅผ ์ธ๋ฑ์ฑํ๋ ค๋ฉด ๋ณดํญ์ผ๋ก ๋ด์ ์ ์ํํ๊ณ offset + i0*stride0 + i1*stride1 + ... * ik * stridek
์์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ฐพ์ต๋๋ค. ๋ฐ๋ผ์ ์ธ์ ํ ํ
์๋ ํฌ๊ธฐ์ ๋์ ๊ณฑ์ด ์ญ์ ๋๋ ๋ณดํญ์ ๊ฐ์ต๋๋ค. ์๋ฅผ ๋ค์ด (n,c,h,w)
ํฌ๊ธฐ์ 4D ํ
์์๋ (c*h*w, h*w, w, 1)
์์ต๋๋ค.
์คํธ๋ผ์ด๋๋ ๋ ผ๋ฆฌ์ ๊ธฐ๋ณธ NCHW ์์๋ฅผ ์ ์งํ๋ฉด์ ๋ฌผ๋ฆฌ์ ์ผ๋ก ๋ค๋ฅธ ๋ฉ๋ชจ๋ฆฌ ํ์(์ฐจ์ ์ฌ์ ๋ ฌ)์ ๋ํ๋ด๋ ๋ฐ ์ฌ์ฉํ ์ ์์ต๋๋ค. ๋ค์๊ณผ ๊ฐ์ด ๋ฉ๋ชจ๋ฆฌ ํ์ ๋ณํ์ ๋ํ ํจ๊ณผ์ ์ธ ์ ์๋ฅผ ์ ๊ณตํฉ๋๋ค.
# implementation of x.to(channels_last)
def to_mem_format_nhwc(x):
return x.permute(0,2,3,1).contiguous().permute(0,3,1,2)
# implementation of x.to(channels_first)
def to_mem_format_nchw(x):
return x.contiguous()
NHWC ํ์์์ ๋ณดํญ ๋ฒกํฐ๋ (c*h*w, 1, c*w, c)
์
๋๋ค. ๋ฐ๋ผ์ ๋ฉ๋ชจ๋ฆฌ ๋ฒํผ์์ ๊ฐ์ค์น๋ NHWC์ ๋ํด ์ฐ์์ ์ธ ์์์
๋๋ค.
Strides๋ ํ ์คํธ์ ์ฌ์ฉํ ์ ์์ต๋๋ค.
def is_nhwc_contiguous(x):
return x.permute(0,2,3,1).is_contiguous()
# or alteratively
def is_nhwc_contiguous(x):
n,c,h,w = x.size() # in any case the sizes remain in NCHW order
return x.stride() == (c*h*w, 1, c*w, c)
def is_nchw_contiguous(x):
return x.is_contiguous()
# operator implementations can just check contiguity and carry on directly on data pointer
def my_sample_op(x):
if x.is_contiguous(nhwc):
float* p = x.data();
# Do we need to go to c++ here?
# can we have an example in python?
n,c,h,w = x.size()
# operate on `p` as it's guaranteed to be (n,h,w,c) array
y=my_nhwc_op(p)
# Do we need to convert the layout of y?
else:
# Need to convert x to nhwc layout
x = x.permute(0,2,3,1).contiguous()
float *p = x.data();
# Is this needed?
y = my_nhwc_op(p)
return y.permute(0,3,1,2).contiguous()
์ด ์ ๊ทผ ๋ฐฉ์์ ์ฅ์ :
๋จ์ :
.contiguous()
ํธ์ถํ๋ ๊ฒ์ NCHW๋ก ์ ํํ๋ ๊ฒ๊ณผ ๋์ผํ๋ฉฐ ์ฌ์ฉ์ ๋๋ ์์
์ค ํ๋ ๋ด๋ถ์์ ์ฐ์ฐํ ๋ฐ์ํ ์ ์์ต๋๋ค.๊ฐ์ฅ ํฐ ์ ์ฌ์ ๋ฌธ์ ๋ ๋ถ๋ถ๋ช ํ ์ฌ์ฉ์ ์๋ ์ ๋๋ค. ์ฌ์ฉ์๊ฐ ์ ๋ง๋ก ๋ค๋ฅธ ๋ฉ๋ชจ๋ฆฌ ํ์์ ์ํ๋์ง ์๋๋ฉด ์ ๋ ฅ ํ ์๊ฐ ์ฐ์ฐํ ์ด๋ฐ ์์ผ๋ก ์คํธ๋ผ์ด๋(stride)๋์๋์ง ๊ตฌ๋ณํ ๋ฐฉ๋ฒ์ด ์์ต๋๋ค. ํนํ, ๊ธฐ์กด ์์ ์ ๋์ ๋ณ๊ฒฝ์ผ๋ก ์ด์ด์ง๋๋ค. ์ค๋๋ ์ปจ๋ณผ๋ฃจ์ ์ ์ ๋ ฅ์ด ์์์ ์คํธ๋ผ์ด๋์ธ ๊ฒฝ์ฐ์๋ NCHW ์ฐ์ ํ ์๋ฅผ ์์ฑํ ์ ์์ต๋๋ค. ์๋ก์ด ์ธ๊ณ์์๋ ์ ๋ ฅ์ NHWC๋ก ์ธ์ํ์ฌ NHWC๋ ๋ฐํํ ์ ์์ต๋๋ค. ์๋ฏธ ์ฒด๊ณ๋ ๋ณ๊ฒฝํ์ง ์์ง๋ง ๋๋ฒ๊ทธํ๊ธฐ ์ด๋ ค์ด ์ฑ๋ฅ ๋ฌธ์ ๋ก ์ด์ด์ง๋๋ค. ๊ฐ๋ฅํ ํด๊ฒฐ์ฑ ์ ์ฌ์ฉ์ ์ง์ memory_format ํ๋๊ทธ๋ฅผ ์ฌ์ฉํ์ฌ ํ ์์ ๋ช ์์ ์ผ๋ก ํ๊ทธ๋ฅผ ์ง์ ํ๊ณ ์ด ์ฃผ์(์คํธ๋ผ์ด๋ ์ธ์)๋ง ๋ฐ๋ฅด๋ ๊ฒ์ ๋๋ค.
์์ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ์ด๊ธฐ ์ ์์ ํ
์์์ ์ํ๋ ๋ง์ง๋ง to(memory_format)
ํธ์ถ์ ๊ธฐ๋กํ๋ "์ํํธ" ๋ฉ๋ชจ๋ฆฌ ํ์ ํ๊ทธ๋ฅผ ํ
์์ ๋์
ํ๋ ๊ฒ์
๋๋ค. ์ด์์๋ ์ด ์ฃผ์์ ์ถ๋ ฅ์ ์ ํํด์ผ ํฉ๋๋ค. ์ฃผ์์ "์ํํธ"์ด๋ฏ๋ก ๋ถ์ผ์น ์ฃผ์์ ๋ํ ํ๋ ์ค๋ฅ๊ฐ ๋ฐ์ํ์ง ์๊ณ ํ๋กํ์ผ๋ง ๋ชจ๋์์ ๊ฒฝ๊ณ ๊ฐ ์์ฑ๋ฉ๋๋ค.
๊ธฐ์กด ์ด์์์ ์๋ช ์ ๋ณ๊ฒฝ๋์ง ์์ต๋๋ค. ์ด์์๋ ์ด์์ ๋ด๋ถ์์ ํ๋ ์ฝ๋ฉ๋ ๋์คํจ์น๋ฅผ โโ์ํํ์ฌ ๋ ๋น ๋ฅธ ๊ตฌํ์ผ๋ก ๋ผ์ฐํ ํ ์ ์์ต๋๋ค. ๊ตฌํ์ด ๋ถ๊ฐ๋ฅํ ๊ฒฝ์ฐ ๋ค๋ฅธ ๋ฉ๋ชจ๋ฆฌ ํ์์ ํตํ ์๋ณต์ด ๊ฐ๋ฅํฉ๋๋ค. ๋์์ ์ค๋ฅ ๋ฉ์์ง๋ฅผ ๋ฐ์์ํค๋ ๊ฒ์ ๋๋ค.
def maxpool(x: Tensor):
if x.is_contiguous(torch.layout.NHWC):
return max_pool_impl_nhwc(x)
return max_pool_impl_default(x.contiguous())
'conv_nhwc'์ ๊ฐ์ ๋ณ๋์ ์ฐ์ฐ์๋ฅผ ๋ง๋๋ ๋์ 'conv'์ ๊ฐ์ ๋จ์ผ ๊ธฐํธ๋ฅผ ์ฌ์ฉํ์ฌ JIT IR์ ์ฐ์ฐ์๋ฅผ ์ฐธ์กฐํ๋ ๊ฒ์ด ์ข์ต๋๋ค. ๊ทธ ์ด์ ๋ ๋จ์์ฑ๊ณผ ์๋ฏธ๋ก ์ ํํ ์์ค์์ IR์ ์ ์งํ๊ธฐ ๋๋ฌธ์ ๋๋ค.
element-wise์ ๊ฐ์ ํต์ฌ ์์ ์ด ๋ฉ๋ชจ๋ฆฌ ํ์์ ์ ์งํ๊ณ ํจ์จ์ ์์ ๋ณด์ฅํด์ผ ํฉ๋๋ค.
๋จํญ ์ฐ์ฐ์ ์ผ๋ฐ์ ์ผ๋ก ๋ฉ๋ชจ๋ฆฌ ๋ธ๋ก์ด "๋ฐ๋"์ธ์ง ์ฌ๋ถ, ์ฆ ์์๊ฐ ๊ฐ๊ฒฉ์ด ์๋ ์์ญ์ ๊ฑธ์ณ ์๊ณ ๊ฐ ๋ฉ๋ชจ๋ฆฌ ์์น๊ฐ ์ ํํ ํ ๋ฒ ์ฌ์ฉ๋๋์ง ์ฌ๋ถ๋ฅผ ํ์ธํ์ฌ ์ฒ๋ฆฌํ ์ ์์ต๋๋ค. ๊ฐ๋จํ ์๊ณ ๋ฆฌ์ฆ์ผ๋ก ๊ฒ์ฆ ๊ฐ๋ฅ
def is_dense_format(x):
p = 1
for s, d in sorted(zip(x.stride(), x.size())):
if s != p:
return False
p *= d
return True
def my_unary(x):
if is_dense_format(x):
return contig_memory_impl(x.data(), x.numel())
return default_strided_impl(x)
# is_dense_format can be used in implementations of e.g. empty_like too
๋๋ฒ๊น ์ฑ๋ฅ์ ์ํด ํ๋กํ์ผ๋ฌ์ ๋ค์ ์ง์์ ์ถ๊ฐํด์ผ ํฉ๋๋ค.
์ด ๊ธฐ๋ฅ์ ์ฃผ๋ฌธํ ํ๋กํ์ผ๋ง ๋๊ตฌ์ ๊ตฌ์ถํ ์ ์์ต๋๋ค.
์ญ๋ฐฉํฅ ํจ์ค๊ฐ ์ ๋ฐฉํฅ๊ณผ ๋์ผํ ๋ฉ๋ชจ๋ฆฌ ํ์์ผ๋ก ์คํ๋์ด์ผ ํ๋ค๊ณ ์์ํ๋ ๊ฒ์ด ๋ ผ๋ฆฌ์ ์ ๋๋ค. ๋ค์ด์ค๋ ๊ทธ๋ผ๋์ธํธ๊ฐ ์์์ ์ผ๋ก ์คํธ๋ผ์ด๋๋ ์ ์์ผ๋ฏ๋ก ํญ์ ์๋์ผ๋ก ๋ฐ์ํ์ง๋ ์์ต๋๋ค. ๋ฐ๋ผ์ ์ ๋ฐฉํฅ ํจ์ค๋ ๋ฉ๋ชจ๋ฆฌ ํ์์ ๋ช ์์ ์ผ๋ก ์ธ์ํ๊ณ autograd ํด๋ก์ ์ ์ ์ฅํ๊ณ ์ญ๋ฐฉํฅ ๊ธฐ๋ฅ ์ ์ grad ํ ์์ ์ ์ฉํด์ผ ํฉ๋๋ค.
๊ฐ๋ฅํ ๊ตฌํ:
def conv_backward(input, weight, grad_output, grad_weight, grad_input):
if input.is_contiguous(torch.memory_format.channels_last):
grad_output = grad_output.to(torch.memory_format.channels_last)
return conv_backward_nhwc(...)
else:
grad_output = grad_output.contiguous()
return conv_backward_nchw(...)
ํ์ฌ ์ ์์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
to(memory_format)
ํธ์ถ์ ์ฝ์
ํด์ผ ํ๋ ์์น๋ฅผ ์ฐพ๋ ๋ฉ๋ชจ๋ฆฌ ํ์ ๋ณํ ํจ์ค(์๋ ๋๋ ์๋)์งํ ๋ชฉ์ ์ผ๋ก assert x.is_contiguous(channels_last)
์ ๊ฐ์ ๋ช
๋ น๋ฌธ์ ์ฌ์ฉํ ์๋ ์์ต๋๋ค.
์ฐธ๊ณ : ํน์ ์ฅ์น์ ์ ํธ๋๋ ๋ฉ๋ชจ๋ฆฌ ํ์ ์กฐํฉ์ด ์๋ค๋ ์ ๋ณด๋ฅผ ์ ์ฅํ ์์น์ ๋ํ ์ง๋ฌธ์ด ์์ต๋๋ค(์: x86์ qconv๋ NHWC๋ง ๊ตฌํํ๋ fbgemm ๊ฒฝ๋ก). ํ ๊ฐ์ง ์ต์ ์ op ๋ฑ๋ก ์์ค์ ๋๋ ๊ฒ์ด์ง๋ง ๋ฉ๋ชจ๋ฆฌ ํ์ ์ฃผ์์ ๋ ๋ง์ ๋ถ๊ฐ ์ ๋ณด์ฒ๋ผ ๋๊ปด์ง๋๋ค. ์ ํธํ๋ ๋ฉ๋ชจ๋ฆฌ ํ์ ๋ฐ ๊ด๋ จ ํด๋ฆฌ์คํฑ์ ๋ํ๋ด๋ ์ ์ญ ๋งต์ JIT ํจ์ค์ ์ด๋๊ฐ์ ์ ์ง ๊ด๋ฆฌํ๋ ๊ฒ์ผ๋ก ์์ํ ์ ์์ต๋๋ค. ์ด์์ ํ๋ฉด ๋ฑ๋ก ๊ธฐ๋ฐ ๋ฉ์ปค๋์ฆ์ผ๋ก ์ ํํ ์ ์์ต๋๋ค.
๋ ๋ณต์กํ ํ ์ ํจํน์ ์ถ๊ฐํ๊ธฐ๋ก ๊ฒฐ์ ํจ์ ๋ฐ๋ผ ๋์ ๊ตฌํ ๋น์ฉ๊ณผ ๋ณต์ก์ฑ์ผ๋ก ์ธํด 1๊ธ PyTorch ํ ์๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ด ํ๋นํ์ง ์์ ์ ์์ต๋๋ค. ๋ ๊ฐ์ง ๋์์ด ๊ฐ๋ฅํฉ๋๋ค.
๋ ๋ค๋ฅธ ๋์์ ํต์ฌ PyTorch Tensor ํด๋์ค์์ ์ฐจ๋จ/ํ์ผ๋ง์ ๋ํ ๊ธฐ๋ณธ ์ง์์ ๊ตฌํํ๋ ๊ฒ์ ๋๋ค.
NamedTensor ์ ๋ํ ๊ธฐ์กด ์ ์์ ํ ์์ ๋ํ ์ ํ ๊ฒ์ฌ ๋ฉ์ปค๋์ฆ์ผ๋ก ๊ตฌ์ฑ๋์ด ์์ต๋๋ค. ํ์ฌ ์ฐจ์ ์ด๋ฆ์ ์๋ฏธ๋ก ์ ์๋ฏธ๋ฅผ ํ ๋นํ์ง ์์ต๋๋ค. ๋ฐ๋ผ์ ํ์ฑํ ํ ์์ ์๋ฏธ๋ฅผ ์ถ๋ก ํ๋ ์ ์ผํ ๋ฐฉ๋ฒ์ ๋ฏธ๋ฆฌ ๊ฒฐ์ ๋ NCHW ํ์์ ๊ณ์ ์ฌ์ฉํ๋ ๊ฒ์ ๋๋ค. NamedTensor์ ํ์ฌ ์ ์์ ์ง๊ตํ๊ฒ ๋ง๋ญ๋๋ค.
์ผ๋ถ ์ด๋ฆ์ ์๋ฏธ(์: "์ฑ๋", "๋๋น")๋ฅผ ๊ธฐ๊บผ์ด ์ง์ ํ๋ ค๋ ๊ฒฝ์ฐ ์ด์์๋ ์ด ์ ๋ณด๋ฅผ ํ์ฉํ์ฌ ๋ ๋น ๋ฅธ ๊ตฌํ์ผ๋ก ๋ผ์ฐํ ํ ์ ์์ต๋๋ค. ์ ๋ ฅ ํ ์๊ฐ ๋ ผ๋ฆฌ์ ์ผ๋ก NHWC(์ค๋๋ ๊ณผ ๊ฐ์ NCHW๊ฐ ์๋) ๋ฉ๋ชจ๋ฆฌ ํ์์ ๊ฐ๊ธฐ ๋๋ฌธ์ ์๋ฏธ๋ก ์ ๋ณํ๊ฐ ๋ ๊ฒ์ ๋๋ค.
TensorFlow๋ data_format
๋งค๊ฐ๋ณ์๋ฅผ ํตํด ์ด์์ ์์ค์์ NHWC์ NCHW๋ฅผ ๋ชจ๋ ์ง์ํฉ๋๋ค. ํ์ฉ๋๋ ๊ฐ์ 4์ฐจ์ ์
๋ ฅ์ ๊ฒฝ์ฐ ("NHWC", "NCHW"), 5์ฐจ์ ์
๋ ฅ์ ๊ฒฝ์ฐ ("NDHWC", "NCDHW") ๋๋ ์
๋ ฅ๊ณผ ๋ฌด๊ดํ channels_first
/ channels_last
์
๋๋ค. ์ฐจ์. ๋งค๊ฐ๋ณ์ ์ค์ ์ ์ฌ๋ฐ๋ฅด๊ฒ ์ฒ๋ฆฌํ๋ ๊ฒ์ ์ฌ์ฉ์์๊ฒ ๋ฌ๋ ค ์์ต๋๋ค. ์ฆ, ํ
์์ ์ํด ์๋์ผ๋ก ์ถ์ ๋์ง ์์ต๋๋ค.
Caffe2์ด ๋งค๊ฐ ๋ณ์๊ฐ ํธ์ถ ํธ์ถ order
๋ณด๋ค๋ data_format
,ํ์ง๋ง ์ฌ์ ํ ๋ช
์ ์ ์ผ๋ก ๊ฐ๋ณ ์ด์์ ์์ค์์ ์ ์ฉํฉ๋๋ค.
๋ฆฌํธ๋จธ์ค ์ง๋ฌธ: ๋ค์ ์ฝ๋๋ ๋ฌด์์ ์ธ์ํฉ๋๊น: tensor_in_nhwc_layout.size(1)
- ์ฑ๋ ์(PyTorch์ ๊ธฐ๋ณธ๊ฐ์ NCHW์ด๊ธฐ ๋๋ฌธ์) ๋๋ ๋์ด(์์น 1์ NHWC ๋ ์ด์์์ ์๊ธฐ ๋๋ฌธ์).
์ด ๋ต๋ณ์ ๊ธฐ๋ฐ์ผ๋ก ๋ช ๊ฐ์ง ์ต์ ์ด ๊ฐ๋ฅํฉ๋๋ค.
empty_like
์๋ ํ ๊ฐ์ง ๋ฌธ์ ๊ฐ ์์ต๋๋ค. ํ์ฌ ์ ์๋ ์๋ฏธ๋ ๋ชจ๋ ๋ณดํญ ์ ๋ณด๋ฅผ ์ญ์ ํ๋ค๋ ๊ฒ์ด๋ฏ๋ก ๋ ์ด์์์ ์ ์งํ๊ณ BC๊ฐ ๋ ์ ์์ต๋๋ค.
@VitalyFedyunin ์ .contiguous()
๋ฐ torch.memory_layout
๋นํธ๋ฅผ ๊ตฌํํ๋๋ก ๋ฑ๋ก๋์์ต๋๋ค.
ํ ๊ฐ์ง ์ง๋ฌธ - (n, c, h, w)
ํฌ๊ธฐ์ 4D ํ
์ x
(n, c, h, w)
x = torch.randn(n,c,h,w)
# x.size(): (n, c, h, w)
# x.stride(): (c*h*w, h*w, w, 1)
์ฐ๋ฆฌ๋ ์ด์ํ ์์ด์ ๊ฐ์ง๊ณ ์์ต๋๋ค
y = x.permute(0, 3, 1, 2)
# y.size(): (n, w, c, h)
# y.stride(): (c*h*w, 1, h*w, w)
์ด์ NHWC ํ์์ ๋ํด ์ฐ์์ ์ธ์ง ํ์ธํฉ๋๋ค. ์๋์ ๊ฐ์ ๋ ผ๋ฆฌ์ ๋ฐ๋ผ
def is_nhwc_contiguous(x):
return x.permute(0,2,3,1).is_contiguous()
# or alternatively
def is_nhwc_contiguous(x):
n,c,h,w = x.size() # in any case the sizes remain in NCHW order
return x.stride() == (c*h*w, 1, c*w, c)
๋ ๊ฒฝ์ฐ ๋ชจ๋ is_nhwc_contiguous(y)
๋ True๋ฅผ ๋ฐํํฉ๋๊น?
์ด๊ฒ์ ๋ง์ต๋๋ค. ๊ทธ๋ฌ๋ ๋ณต์ฌ, to ๋ฐ ์ ์ฌํ ์์ ์ค์ ์๋ค๋ก ๋ณํ์ ํผํ๊ธฐ ์ํด ๋ณดํญ์๋ง ๋ฆด๋ ์ดํ ์ ์์ต๋๋ค.
strides์ ์์๊ฐ ๋ฉ๋ชจ๋ฆฌ ํ์๊ณผ ๊ฐ์ผ๋ฉด ์ด๋ป๊ฒ ๋ ๊น์? 4D ํ
์๋ฅผ ์๋ก ๋ค์ด๋ณด๊ฒ ์ต๋๋ค. ํ
์๋ฅผ ์ค๋ช
ํ๊ธฐ ์ํด sizes
, strides
๋ฐ stride_indexes
.
์ ํฌ๊ธฐ (N, C, H, w)
๋ฌผ๋ฆฌ์ ์์์ ๋ฐ๋ฅธ ๋ณดํญ , ์ฆ
stride_indexes ๋
nchw ํ์์ ๊ฒฝ์ฐ ์ด์ ๊ณผ ๋์ผํฉ๋๋ค. nhwc์ ๊ฒฝ์ฐ ๋น์ทํ ๊ฒ์ ๋๋ค.
def is_nhwc_contiguous(x):
n,c,h,w = x.size()
return x.stride() == (h*w*c, w*c, c, 1)
def is_nchw_contiguous(x):
n,c,h,w = x.size()
return x.stride() == (c*h*w, h*w, w, 1)
def is_nchw_format(x):
return x.stride_index() == (0, 1, 2, 3)
def is_nhwc_format(x):
return x.stride_index == (0, 2, 3, 1)
def is_contiguous(x):
if (is_nchw_format(x)):
return is_nchw_contiguous(x)
else if (is_nhwc_format(x)):
return is_nhwc_contiguous(x)
else:
warning_not_support()
# or, to use stride_index
def is_contiguous(x):
return x.stride() == (x.size[x.stride_index[1]]*x.size[x.stride_index[2]]*x.size[x.stride_index[3]], x.size[x.stride_index[2]] * x.size[x.stride_index[3]], x.size[x.stride_index[3]], 1)
์ด๊ฒ์ ๋ํ ์ฐจ๋จ๋ ํ์์ ์ง์ํ๋๋ก ํ์ฅ๋ ์ ์์ต๋๋ค. nChw16c๋ฅผ ์๋ก ์ฌ์ฉํฉ๋๋ค.
sizes: (n, c, h, w)
block_sizes: (n, c/16, h, w, 16)
strides: strides of (n, c/16, h, w, 16)
stride_indexes: (0, 1, 2, 3, 1) # assume blocked dimension is always in dense (i.e. on the right side of major dimension)
์์ธํ ๋ด์ฉ์ ๋์ค์ ์์ธํ ์์๋ณผ ์ ์์ต๋๋ค.
nchw ์ฐ์ ํ ์๋ง ํ์ฉํ๋ OP์ ๊ฒฝ์ฐ ์ฌ๊ธฐ์์ ์ฝ๊ฐ์ ์์ ์ด ํ์ํ ๊ฒ์ ๋๋ค.
๋๋ ํ๋กํ ํ์ ์ ์ฝ๊ฐ ๋ณ๊ฒฝํ ์๋ ์์ต๋๋ค.
def is_contiguous(format=nchw):
...
def contiguous(format=nchw)
...
๋ฐ๋ผ์ ๊ธฐ๋ณธ์ ์ผ๋ก nchw๋ง ์ฐ์์ ์ด๋ผ๊ณ ๊ฐ์ ํฉ๋๋ค. ์ด๋ฐ ์์ผ๋ก ํด๋น OP๋ฅผ ๋ค์ ์์ฑํ ํ์๊ฐ ์์ผ๋ฉฐ ์๋์ผ๋ก nchw๋ก ์ฌ์ ๋ ฌ๋ฉ๋๋ค.
์ฐ๋ฆฌ๋ ๋ค์์ ํํํ ์ ์๋ API๋ฅผ ๊ตฌ์ถํ๊ธฐ ์ํด ๋ ธ๋ ฅํฉ๋๋ค.
- Eager ๋ฐ JIT์ PyTorch์ ์๋ ๋ค๋ฅธ ๋ฉ๋ชจ๋ฆฌ ํ์(์ฒ์์๋ ์ฐจ์ ์์๋ง)์ ๊ฐ์ง ํ ์. ์ฐจ๋จ๋ ๋ ์ด์์์ ์ฐ์ ์์๊ฐ ๋ฎ์ง๋ง ์ฌ์ ํ ์ข์ต๋๋ค.
- ๋ฉ๋ชจ๋ฆฌ ํ์ ์ฟผ๋ฆฌ ๋ฐ ๋ณ๊ฒฝ์ ์ํ ์ฌ์ฉ์ ๋ ธ์ถ API
- ๋ค๋ฅธ ๋ฉ๋ชจ๋ฆฌ ํ์์ ๊ฐ์ง ์ ๋ ฅ ํ ์๋ฅผ ์ฒ๋ฆฌํ๊ณ ํด๋นํ๋ ๋ ๋น ๋ฅธ ๊ตฌํ์ผ๋ก ๋ผ์ฐํ ํ ์ ์๋ ํต์ฌ CNN ์์
- JIT ํจ์ค์ ๋ฉ๋ชจ๋ฆฌ ํ์์ ์ถ๋ก ํ๊ณ ์ต์ ํํ๋ ๊ธฐ๋ฅ
์ข์ ์ ์! ๋ด ์ดํด๊ฐ ์ฌ๋ฐ๋ฅธ์ง ํ์ธํ๊ฒ ์ต๋๋ค(MKL-DNN ํ์ ์ฒ๋ฆฌ์ ๋ํ ์ ์ ํฌํจ):
์ด ์ ์์ "ํ์" ํด๋์ค๋ก ๊ตฌํํ๋ค๊ณ ์๊ฐํ ์ ์์ต๋๋ค. API ์ฟผ๋ฆฌ ๋ฐ ๋ณ๊ฒฝ์ ๊ฐ์์ผ๋ก ์ ๊ณตํ๋ ํ MKL-DNN ๋ณตํฉ ํ์์ ๋ง๋ ์์/ํ์ฅ์ ์ํํ ์ ์์ต๋๋ค. ๋๋ ํ์ ์ฒ๋ฆฌ๋ฅผ ์ํ ํ๋ ์์ํฌ๋ฅผ ์ ๊ณตํ๋ ํ ๋ค๋ฅธ ๋ฐฉ๋ฒ์ ์ฌ์ฉํ์ฌ ์ค์ํ ์ธ๋ถ ์ฌํญ์ ์ฐ๋ฆฌ์๊ฒ ์ ๊ฐํฉ๋๋ค.
OP ๊ตฌํ์ ๋ํด ๊ฐ OP๋ ์ฑ๋ฅ์ ์ต๋ํํ๋ ์ ํธ ํ์๊ณผ ์๋ํ๋ ํธํ ํ์์ ๊ฐ์ง ์ ์์ต๋๋ค. ์์๋ณ ์ฐ์ฐ์(๋๋ ๋ ์ผ๋ฐ์ ์ผ๋ก ๋งํ๋ฉด ๋ฉ๋ชจ๋ฆฌ ์ ํ OP)๋ ๊ธฐ๋ณธ ์ค์ ์ด ์๋ค๊ณ ๊ฐ์ ํฉ๋๋ค. OP๋ "ํ์" ๊ฐ์ฒด๋ฅผ ์ฌ์ฉํ์ฌ ๊ฒฐ๊ณผ ํ ์๋ฅผ ์์ฑํฉ๋๋ค. ์ด ํ์ ๊ฐ์ฒด๋ ๊ธฐ๋ณธ pytorch ๊ธฐ๋์น์ ํธํ๋๋ ์ฟผ๋ฆฌ/๋ณ๊ฒฝ ์๋ฏธ ์ฒด๊ณ๋ฅผ ๋ณด์ฅํ ๋ฟ๋ง ์๋๋ผ ์ต์ ํ๋ ํจ์์ ์ผ๋ จ ๋ฒํธ(์: conv2d(ReLU(conv2d)))๋ก ํธ์ถ๋๋ ๊ฒฝ์ฐ ํน์ ํ์์ ์ฒ๋ฆฌํ ์ ์์ต๋๋ค. ์ฌ๋ก)
@uyongw ์ฒซ ๋ฒ์งธ ์์ ๋ํด ์ข ๋ ๋ช ํํ ํ๊ณ ์ถ์ต๋๋ค. "๋๋ NCHW ํ ์๋ฅผ ๊ฐ์ง๊ณ ์๋๋ฐ, ๊ทธ๊ฒ์ ์ด์ํ ๋ฐฉ์์ผ๋ก ์กฐ์ฎ๊นํ์ต๋๋ค. ๊ทธ๋์ ์ง๊ธ์ NWCH์ฒ๋ผ ๋ณด์ ๋๋ค. ์ด์ NHWC๊ฐ ์ฐ์์ ์ธ์ง ์๊ณ ์ถ์ต๋๋ค." ๊ทธ๋ฌ๋ ๊ทธ๊ฒ์ ์๋ชป๋ ๊ด์ ์ ๋๋ค. ๋ ๋์ ๊ณต์์ "๋๋ NHWC ํ ์๋ฅผ ๊ฐ์ง๊ณ ์์ผ๋ฉฐ NCHW ํ ์๋ก ์ ์นํ์ต๋๋ค."์ ๋๋ค.
๋ค๋ฅด๊ฒ ๋งํ๋ฉด ํ ์์ ๋ฌผ๋ฆฌ์ ์ฐจ์์๋ ๋ณธ์ง์ ์ธ ์๋ฏธ๊ฐ ์์ต๋๋ค(๋ณดํญ์ ๋ฌด์ํ ๋). ๋ณดํญ๊ณผ ๊ด๋ จํ์ฌ ์ฐธ์กฐํ๋ ๋ฐฉ๋ฒ์ ๊ณ ๋ คํ ๋๋ง ์๋ฏธ๋ฅผ ๋ถ์ฌํฉ๋๋ค.
ํ ์๋ฅผ ์ค๋ช ํ๊ธฐ ์ํด ํฌ๊ธฐ, ๋ณดํญ ๋ฐ stride_indexes๊ฐ ์์ต๋๋ค.
๋๋ stride_indexes
๊ฐ ๋ฌธ์ ์ ๋ํด ์๊ฐํ๋ ํธ๋ฆฌํ ๋ฐฉ๋ฒ์ด๋ผ๊ณ ์๊ฐํ์ง๋ง, strides์ ์๊ฒฉํ๊ฒ ์ค๋ณต๋ฉ๋๋ค. true strides.) @VitalyFedyunin ๊ณผ ์ ๋ strides ์์ฒด์์ ์ ๋ณด๋ฅผ ์ฌ๊ตฌ์ฑํ๋ ๊ฒ์ด
๋ฐ๋ผ์ ๊ธฐ๋ณธ์ ์ผ๋ก nchw๋ง ์ฐ์์ ์ด๋ผ๊ณ ๊ฐ์ ํฉ๋๋ค.
๋ค, ์ ๊ฐ ์ฝ์ ๊ณํ์ ๋๋ค.
@CaoZhongZ
์ด ์ ์์ "ํ์" ํด๋์ค๋ก ๊ตฌํํ๋ค๊ณ ์๊ฐํ ์ ์์ต๋๋ค. API ์ฟผ๋ฆฌ ๋ฐ ๋ณ๊ฒฝ์ ๊ฐ์์ผ๋ก ์ ๊ณตํ๋ ํ MKL-DNN ๋ณตํฉ ํ์์ ๋ง๋ ์์/ํ์ฅ์ ์ํํ ์ ์์ต๋๋ค. ๋๋ ํ์ ์ฒ๋ฆฌ๋ฅผ ์ํ ํ๋ ์์ํฌ๋ฅผ ์ ๊ณตํ๋ ํ ๋ค๋ฅธ ๋ฐฉ๋ฒ์ ์ฌ์ฉํ์ฌ ์ค์ํ ์ธ๋ถ ์ฌํญ์ ์ฐ๋ฆฌ์๊ฒ ์ ๊ฐํฉ๋๋ค.
๋๋ ๊ทธ๊ฒ์ด ์ ์์ ๋ํ ์ ํํ ์ค๋ช ์ด๋ผ๊ณ ์๊ฐํ์ง ์์ต๋๋ค. ์ฌ๊ธฐ์ ์ ์ํ๋ ๋ฉ๋ชจ๋ฆฌ ๋ ์ด์์ ์ง์์ ์คํธ๋ผ์ด๋๋ก ํํํ ์ ์๋ ๋ ์ด์์์ผ ๋ฟ์ ๋๋ค. ์ด ๋ฐฉ๋ฒ์ผ๋ก ํํํ ์ ์๋ ๋ชจ๋ ๊ฒ(์: ๋ธ๋ก ๋ ์ด์์)์ ์ด ๋ฐฉ๋ฒ์ผ๋ก ์๋ํ์ง ์์ผ๋ฉฐ ๋ ๋ฌด๊ฑฐ์ด "๋ ์ด์์" ๋ฉ์ปค๋์ฆ์ ์ํด ์ง์๋์ด์ผ ํฉ๋๋ค.
๋ค๋ฅด๊ฒ ๋งํ๋ฉด ํ ์์ ๋ฌผ๋ฆฌ์ ์ฐจ์์๋ ๋ณธ์ง์ ์ธ ์๋ฏธ๊ฐ ์์ต๋๋ค(๋ณดํญ์ ๋ฌด์ํ ๋). ๋ณดํญ๊ณผ ๊ด๋ จํ์ฌ ์ฐธ์กฐํ๋ ๋ฐฉ๋ฒ์ ๊ณ ๋ คํ ๋๋ง ์๋ฏธ๋ฅผ ๋ถ์ฌํฉ๋๋ค.
๋ถ๋ถ์ ์ผ๋ก ๋์ํฉ๋๋ค :-) ๊ทธ๋ฌ๋ ์ด ํน์ ํ ๋ฌธ์ ์ ๋ํด์๋ ์๋๋๋ค. ์ด๋ฏธ nhwc ํ ์๊ฐ ์๋ค๊ณ ๊ฐ์ ํด ๋ณด๊ฒ ์ต๋๋ค. ๊ทธ๋ฐ ๋ค์ nwhc๋ก ๋ณ๊ฒฝํฉ๋๋ค. nhwc๋ก ๋ ์์ดํ ๋ค์ contiguous()๋ฅผ ์ํํ๊ณ ์ถ์ต๋๋ค. ๊ทธ๋ฌ๋ ๋๋ ๊ทธ๊ฒ์ ์ด๋ฏธ ์ฐ์์ ์ผ๋ก ์ป์์ต๋๋ค. ํผ๋์ค๋ฝ์ง ์์ต๋๊น?
๋๋ stride_indexes๊ฐ ๋ฌธ์ ์ ๋ํด ์๊ฐํ๋ ํธ๋ฆฌํ ๋ฐฉ๋ฒ์ด๋ผ๊ณ ์๊ฐํ์ง๋ง, stride์ ์๊ฒฉํ๊ฒ ์ค๋ณต๋ฉ๋๋ค. ์๋ํ๋ฉด ๋น์ ์ด ๋งํ๋ ๋ชจ๋ ๊ฒ์ "์ด (์ญ?) ์์ด์ ๋ณดํญ์ ์ ์ฉํ๊ณ ๊ทธ๊ฒ์ ์ง์ ํ ๋ณดํญ์ผ๋ก ์ทจ๊ธํ๊ธฐ ๋๋ฌธ์ ๋๋ค.)
IMHO, nhwc(๋ฌผ๋ฆฌ์ )์ ๋ณดํญ์ด ์๋ ๊ฒฝ์ฐ ๋ณดํญ์ด ์ค๋ณต๋์ง ์์ต๋๋ค. ํฌ๊ธฐ(๋ ผ๋ฆฌ)๊ฐ ์๋ ์ฌ๋ฐ๋ฅธ ๋งคํ์ด ํ์ํ๊ธฐ ๋๋ฌธ์ ๋๋ค. ๊ทธ๋ ์ง ์์ผ๋ฉด ์ค์ ์์๋ฅผ ๋งํ ๋ฐฉ๋ฒ์ด ์์ต๋๋ค.
BTW ์ญ ๋งคํ์ ์ฌ์ฉํ๋ ๋ ๊ฐ๋จํ ์ ๊ทผ ๋ฐฉ์์ด ์์ต๋๋ค. nchw์ ๊ฒฝ์ฐ (0, 1, 2, 3)์ด๊ณ nhwc์ ๊ฒฝ์ฐ (0, 2, 3, 1) ๋์ (0, 3, 1, 2)์ ๋๋ค. ์ฆ, stride_index ์์ฒด๋ ํญ์ NCHW์ ๋๋ค. ๊ทธ๋ฌ๋ ๋ฌธ์ ๋ nChw16c ๋๋ OIhw16i16o์ ๊ฐ์ ์ฐจ๋จ๋ ํ์์ผ๋ก ํ์ฅํ ์ ์๋ค๋ ๊ฒ์ ๋๋ค.
์ฐจ๋จ๋ ํ์์๋ ์์ ํ ๋ค๋ฅธ ์ฐ์ฐ์ ๊ตฌํ ์งํฉ์ด ํ์ํฉ๋๋ค. ์ด๋ฌํ ์ด์ ๋ก ์ฐ๋ฆฌ๋ ์ ์์ ๋ชจ๋ ๊ธฐ์กด ์ฐ์ฐ์์ ์น์ํ๊ณ ๋์ผํ๊ฑฐ๋ ๋ ๋์ ์ฑ๋ฅ์ผ๋ก ์๋ํ๋ '๋ฉ๋ชจ๋ฆฌ ํ์'๊ณผ ํผํฉํ์ง ์๋ ๊ฒ์ ์ ํธํฉ๋๋ค.
๋ถ๋ถ์ ์ผ๋ก ๋์ํฉ๋๋ค :-) ๊ทธ๋ฌ๋ ์ด ํน์ ํ ๋ฌธ์ ์ ๋ํด์๋ ์๋๋๋ค. ์ด๋ฏธ nhwc ํ ์๊ฐ ์๋ค๊ณ ๊ฐ์ ํด ๋ณด๊ฒ ์ต๋๋ค. ๊ทธ๋ฐ ๋ค์ nwhc๋ก ๋ณ๊ฒฝํฉ๋๋ค. nhwc๋ก ๋ ์์ดํ ๋ค์ contiguous()๋ฅผ ์ํํ๊ณ ์ถ์ต๋๋ค. ๊ทธ๋ฌ๋ ๋๋ ๊ทธ๊ฒ์ ์ด๋ฏธ ์ฐ์์ ์ผ๋ก ์ป์์ต๋๋ค. ํผ๋์ค๋ฝ์ง ์์ต๋๊น?
์ผ๋ถ ์ฉ์ด๋ฅผ ๊ตฌ์ด์ฒด๋ก ์ฌ์ฉํ๊ณ ์ ํ์ฑ์ด ํ์ํ๊ธฐ ๋๋ฌธ์ ๊ทํ์ ์๋ฅผ ์ดํดํ๊ธฐ ์ด๋ ต์ต๋๋ค. ๋ง์ํ์ ๋ด์ฉ์ ์ ๊ฐ ํด์ํ๋ ๋ฐฉ๋ฒ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
y = x.permute(0, 2, 3, 1)
๋ฅผ ์คํํ๋ ๊ฒ์
๋๋ค. ์๋ํ๋ฉด ๋ฌผ๋ฆฌ์ ๋ ์ด์์์ด ์๋๋ผ ๋
ผ๋ฆฌ์ ๋ ์ด์์์ ์นํํ๊ณ ์๊ธฐ ๋๋ฌธ์
๋๋ค. (์๋ ๊ฒ์๋ฌผ์์ ์์ด x.permute(0, 3, 1, 2)
์ ์ธ๊ธํ๊ธฐ ๋๋ฌธ์ ์ด๊ฒ์ด ์๋ฏธํ๋ ๋ฐ๊ฐ ์๋๋ผ๊ณ ์๊ฐํฉ๋๋ค.z = y.permute(0, 2, 3, 1)
์ ์ ์ฉํ๋ ๊ฒ์
๋๋ค. ์ด์ ๋
ผ๋ฆฌ์ ๋ ์ด์์์ด ๋ฌผ๋ฆฌ์ ๋ ์ด์์๊ณผ ์ผ์นํ๋ ํ
์๊ฐ ์์ต๋๋ค. ์ด๊ฒ์ ์ฐ๋ฆฌ๊ฐ z.contiguous()
๋ฌป๋๋ค๋ฉด ์ฐ๋ฆฌ๋ ์ฐธ์ด ๋ ๊ฒ์ด๋ผ๋ ๊ฒ์ ์๋ฏธํฉ๋๋ค(๊ทธ๋ฆฌ๊ณ ํผ๋์ค๋ฝ๊ฒ๋ z.contiguous(memory_layout=NCHW)
๋ ์ฐธ์ด ๋ ๊ฒ์
๋๋ค.) ๊ทธ๋ฌ๋ NHWC ์ฐ์์ ์ด์ง๋ ์์ ๊ฒ์
๋๋ค.๋๋ ์ด๊ฒ์ด ๋น์ ์ด ์ผ๋์ ๋ ์๋ผ๊ณ ์๊ฐํ์ง ์์ต๋๋ค. ์ด ๊ฒฝ์ฐ "์์ด"์ด ์๋ฏธํ๋ ๋ฐ์ ๋ํด ๋ ์ ํํด์ผ ํฉ๋๋ค.
IMHO, nhwc(๋ฌผ๋ฆฌ์ )์ ๋ณดํญ์ด ์๋ ๊ฒฝ์ฐ ๋ณดํญ์ด ์ค๋ณต๋์ง ์์ต๋๋ค. ํฌ๊ธฐ(๋ ผ๋ฆฌ)๊ฐ ์๋ ์ฌ๋ฐ๋ฅธ ๋งคํ์ด ํ์ํ๊ธฐ ๋๋ฌธ์ ๋๋ค. ๊ทธ๋ ์ง ์์ผ๋ฉด ์ค์ ์์๋ฅผ ๋งํ ๋ฐฉ๋ฒ์ด ์์ต๋๋ค.
์ด ์ ์์ ํต์ฌ์ ๋๋ค : ๋ ผ๋ฆฌ์ ์ธ ๋ ์ด์์์ผ๋ก ์ฐ๋ฆฌ ํน๊ถ NCHW, ํญ์. ๋ฐ๋ผ์ ๋ด๊ฐ ๋ชจ๋ฅด๋ 4D ํ ์๊ฐ ์๋ ๊ฒฝ์ฐ ๋ ผ๋ฆฌ์ ๋ ์ด์์์ด NCHW๋ผ๊ณ ๊ฐ์ ํฉ๋๋ค. ๊ทธ๊ฒ์ ๋ชจํธ์ฑ์ ์ ๊ฑฐํฉ๋๋ค. ๋ ผ๋ฆฌ์ ๋ ์ด์์์ด NCHW๊ฐ ์๋ ํ ์๋ฅผ ์ฒ๋ฆฌํ๊ณ ์ถ๋ค๋ฉด ๋ช ์๋ API๊ฐ ์ถ์ ์กฐ๊ธ ์ด๋ ต๊ฒ ๋ง๋ ๋ค๊ณ ์๊ฐํฉ๋๋ค.
@dzhulgakov
์์ ์ ๋ฉ๋ชจ๋ฆฌ ํ์ ๋์์ ์ ์งํฉ๋๋ค.
๋ฌผ๋ฆฌ์ NHWC ํ ์๊ฐ ์์ ํ ์คํธ๋ผ์ด๋๋ฅผ ํตํด ๋ฐ์ํ ์ ์๋ ๊ฒฝ์ฐ ๋ฉ๋ชจ๋ฆฌ ํ์ ํ๊ทธ๊ฐ ์์ ๋๋ง ๋ฉ๋ชจ๋ฆฌ ํ์์ ๋ณด์กดํ๋๋ก ํ์ง ์๋ ํ ์ด๊ฒ์ ๊ธฐ์ ์ ์ผ๋ก BC ๊นจ๋ ๊ฒ์ ๋๋ค. ์ ์์ด ํ์ฌ ๋ฌด์์ ์ ์ํ๊ณ ์๋์ง ์ ๋ชจ๋ฅด๊ฒ ์ต๋๋ค.) ์ด๊ฒ์ด ์ค์ ๋ก ์ค์ ๋ก ๋๊ตฐ๊ฐ์ ์ฝ๋๋ฅผ ์์์ํค๋์ง ํ์คํ์ง ์์ต๋๋ค.
๋ฌผ๋ฆฌ์ NHWC ํ ์๊ฐ ์์ ํ ์คํธ๋ผ์ด๋๋ฅผ ํตํด ๋ฐ์ํ ์ ์๋ ๊ฒฝ์ฐ ๋ฉ๋ชจ๋ฆฌ ํ์ ํ๊ทธ๊ฐ ์์ ๋๋ง ๋ฉ๋ชจ๋ฆฌ ํ์์ ๋ณด์กดํ๋๋ก ํ์ง ์๋ ํ ์ด๊ฒ์ ๊ธฐ์ ์ ์ผ๋ก BC ๊นจ๋ ๊ฒ์ ๋๋ค. ์ ์์ด ํ์ฌ ๋ฌด์์ ์ ์ํ๊ณ ์๋์ง ์ ๋ชจ๋ฅด๊ฒ ์ต๋๋ค.) ์ด๊ฒ์ด ์ค์ ๋ก ์ค์ ๋ก ๋๊ตฐ๊ฐ์ ์ฝ๋๋ฅผ ์์์ํค๋์ง ํ์คํ์ง ์์ต๋๋ค.
๋ฉ๋ชจ๋ฆฌ ํ์์ '๊ณ ์ '์ผ๋ก ๋ง๋ค ์ ์๋ค๊ณ ๊ฐ์ ํฉ๋๋ค. ๋ฉ๋ชจ๋ฆฌ ํ์ ํ ์์ ๋ํ ์ฐ์ฐ์ ๋ฉ๋ชจ๋ฆฌ ํ์ ํ ์๋ฅผ ์์ฑํฉ๋๋ค. ๊ทธ๋ฌ๋ฉด BC ๋ฌธ์ ๊ฐ ํด๊ฒฐ๋ฉ๋๋ค.
๊ทธ๋ฌ๋ ํ ์์ ๋ฉ๋ชจ๋ฆฌ ํ์์ด ๋ค๋ฅธ ๊ฒฝ์ฐ ์ด์ง(๋๋ ๋ ๋ง์ ๋ฉค๋ฒ) ์์ ์ ๋์์ ์ ์ํด์ผ ํฉ๋๋ค.
@ezyang ์ค ๋ฐฉ๊ธ ์์ ๋ต๋ณ์ ์คํ๊ฐ ์์์ ๋ฐ๊ฒฌํ์ต๋๋ค. (์ฃ์กํฉ๋๋ค. ๊ทธ๋ฌ๋ ์๋์ ์๋ ์ฌ์ ํ ์ ํํฉ๋๋ค.) ์๋์ ๊ฐ์ด ๋ค์ ์ค๋ช ํ๊ฒ ์ต๋๋ค.
ํ์ง๋ง 2๋จ๊ณ ์ดํ์ ์ด๋ฏธ NHWC ์ฐ์์ฑ์ ์ป์์ต๋๋ค. ๊ทธ๋ฐ ๋ค์ 3๋จ๊ณ๋ฅผ ๊ฑด๋๋ฐ๊ณ 4๋จ๊ณ์์ ์ง์ NHWC๋ก ์ฌ์ฉํ ์ ์์ต๋๋ค. ํ์ง๋ง ์ด๊ฒ์ ํ ์์ ๋ฌผ๋ฆฌ์ ์์๊ฐ ์ ํ ๋ณ๊ฒฝ๋์ง ์๊ธฐ ๋๋ฌธ์ ํ์คํ ์ณ์ง ์์ต๋๋ค.
์ฐจ๋จ๋ ํ์์๋ ์์ ํ ๋ค๋ฅธ ์ฐ์ฐ์ ๊ตฌํ ์งํฉ์ด ํ์ํฉ๋๋ค. ์ด๋ฌํ ์ด์ ๋ก ์ฐ๋ฆฌ๋ ์ ์์ ๋ชจ๋ ๊ธฐ์กด ์ฐ์ฐ์์ ์น์ํ๊ณ ๋์ผํ๊ฑฐ๋ ๋ ๋์ ์ฑ๋ฅ์ผ๋ก ์๋ํ๋ '๋ฉ๋ชจ๋ฆฌ ํ์'๊ณผ ํผํฉํ์ง ์๋ ๊ฒ์ ์ ํธํฉ๋๋ค.
์, ์ฒซ ๋ฒ์งธ ๋จ๊ณ๋ก NHWC๋ฅผ ํ์ฑํํ ์ ์์ต๋๋ค. ๊ทธ๋ฌ๋ ์ค์ ๋ก ์ฐจ๋จ๋ ํ์์ด ์์ ํ ๋ค๋ฅธ ๊ฒ์ด๋ผ๊ณ ์๊ฐํ์ง ์์ต๋๋ค. ๊ทธ๊ฒ์ ์์ฐ์ค๋ฝ๊ฒ ํํ๋ ์ ์์ต๋๋ค(์ฝ๊ฐ์ ์ข์ ์ถ์ํ๋ก). ์ผ๋ฐ์ ์ธ ํ์ ์ค๋ช ์ด ์๋ ๊ฒฝ์ฐ ๋ค๋ฅธ ์ฌ๋๋ค์ ์์์ ์ฐจ๋จ/๋ณดํญ์ผ๋ก ์ ํ์์ ๋ฑ๋กํ ์ ์์ต๋๋ค.
๋๊ตฐ๋ค๋ ์ด๋ฏธ ์ง์์ ์ฐจ๋จํ๋ค๋ฉด ๊ธฐ๋ณธ์ด ๋๋ ๋ชจ๋ ๊ฒ์ ์คํํ๊ธฐ ์ํด ์จ๊ฒจ์ง ๊ตฌ์ฑ์ ๋ง๋๋ ๋ฐ ์ ๊ฒฝ ์ฐ์ง ์์๋ ๋ฉ๋๋ค. ๋ด๋ถ์ ์์์ ์ธ๊ณ๊ฐ ์์ฑ๋๊ณ ๋ ์ธ๊ณ ์ฌ์ด์ ์์/๋์ด ๋ฌธ์ ๊ฐ ๋ ์ ์์ต๋๋ค.
์ด์จ๋ ์ฐจ๋จ๋ ํ์์ ๋ํด ์๊ฐํ๊ธฐ์๋ ๋๋ฌด ๋ฉ๋ฆฌ ๋จ์ด์ ธ ์์ ์ ์์ต๋๋ค. ํ์ง๋ง ๊ฐ๋ฅํ๋ฉด ๋์์ธ์ ํ์ฅ ๊ฐ๋ฅํ๊ฒ ๋ง๋๋ ๊ฒ์ด ๋ ๋ซ๋ค๊ณ ์๊ฐํฉ๋๋ค.
ํ์ง๋ง 2๋จ๊ณ ์ดํ์ ์ด๋ฏธ NHWC ์ฐ์์ฑ์ ์ป์์ต๋๋ค. ๊ทธ๋ฐ ๋ค์ 3๋จ๊ณ๋ฅผ ๊ฑด๋๋ฐ๊ณ 4๋จ๊ณ์์ ์ง์ NHWC๋ก ์ฌ์ฉํ ์ ์์ต๋๋ค. ํ์ง๋ง ์ด๊ฒ์ ํ ์์ ๋ฌผ๋ฆฌ์ ์์๊ฐ ์ ํ ๋ณ๊ฒฝ๋์ง ์๊ธฐ ๋๋ฌธ์ ํ์คํ ์ณ์ง ์์ต๋๋ค.
์๊ฒ ์ต๋๋ค. ์ด์ ๊ทํ์ ์๋ฅผ ์ดํดํฉ๋๋ค. ์ค์ ๋ก 2๋จ๊ณ์์ ๋ฉ์ถ๊ณ NCHW ํ ์์ธ ๊ฒ์ฒ๋ผ ์ฌ์ฉํ ์ ์์ต๋๋ค. ์ด ๊ฒฝ์ฐ W๋ฅผ C ๋ฑ์ผ๋ก ๋ถ์ ์ ํ๊ฒ ํด์ํ๊ฒ ๋ฉ๋๋ค. ์ด๊ฒ์ ํ์คํ ๋ณดํญ ๊ธฐ๋ฐ ๊ตฌํ์ ๋จ์ ์ ๋๋ค( @dzhulgakov , ์๋ง๋ ์ด๊ฒ์ ์ ์์ ์ถ๊ฐํด์ผ ํ ๊ฒ์ ๋๋ค). ์ ์์์๋ ์ด ๊ฒฝ์ฐ์ ๋ํ ๋ช ๊ฐ์ง ์กฐํญ์ด ์์ต๋๋ค.
์์ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ์ด๊ธฐ ์ ์์ ํ ์์์ ์ํ๋ ๋ง์ง๋ง to(memory_format) ํธ์ถ์ ๊ธฐ๋กํ๋ "์ํํธ" ๋ฉ๋ชจ๋ฆฌ ํ์ ํ๊ทธ๋ฅผ ํ ์์ ๋์ ํ๋ ๊ฒ์ ๋๋ค. ์ด์์๋ ์ด ์ฃผ์์ ์ถ๋ ฅ์ ์ ํํด์ผ ํฉ๋๋ค. ์ฃผ์์ "์ํํธ"์ด๋ฏ๋ก ๋ถ์ผ์น ์ฃผ์์ ๋ํ ํ๋ ์ค๋ฅ๊ฐ ๋ฐ์ํ์ง ์๊ณ ํ๋กํ์ผ๋ง ๋ชจ๋์์ ๊ฒฝ๊ณ ๊ฐ ์์ฑ๋ฉ๋๋ค.
์ํํธ ๋ฉ๋ชจ๋ฆฌ ํ์ ํ๊ทธ๋ฅผ ์ฌ์ฉํ๋ฉด ์์ดํ NCHW ํ ์์ ์ค์ ๋ก ๋ฌผ๋ฆฌ์ ์ผ๋ก NHWC์ธ ํ ์๋ฅผ ๊ตฌ๋ณํ ์ ์์ต๋๋ค. ๊ทธ๋ฌ๋ ํ์ฌ ํ์์ ์ํํธ ํ๊ทธ๋ ๊ตฌ์๋ ฅ์ด ์์ผ๋ฏ๋ก ์ค์ ๋ก ์ด ๊ฒฝ์ฐ์ ์ผ๋ง๋ ์ ์ฉํ์ง ์ ๋ชจ๋ฅด๊ฒ ์ต๋๋ค.
๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๋ ๋ ๋ค๋ฅธ ๋ฐฉ๋ฒ์ ๋ช ๋ช ๋ ํ ์๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ ๋๋ค. ๋ช ๋ช ๋ ํ ์๋ฅผ ์ฌ์ฉํ๋ฉด (๋ ผ๋ฆฌ์ ) ์ฐจ์์ ์ด๋ฆ์ ์ฌ์ฉํ์ฌ ํ ์๋ฅผ NCHW(๊ฐ์ ๊ธฐ๋ณธ๊ฐ)๋ก ๋ณด๊ณ ์๋์ง ์๋๋ฉด ๋ค๋ฅธ ๊ฒ์ผ๋ก ๋ณด๊ณ ์๋์ง ํ์ ํ ์ ์์ต๋๋ค.
๊ทธ๋ฌ๋ ์ค์ ๋ก ์ฐจ๋จ๋ ํ์์ด ์์ ํ ๋ค๋ฅธ ๊ฒ์ด๋ผ๊ณ ์๊ฐํ์ง ์์ต๋๋ค. ๊ทธ๊ฒ์ ์์ฐ์ค๋ฝ๊ฒ ํํ๋ ์ ์์ต๋๋ค(์ฝ๊ฐ์ ์ข์ ์ถ์ํ๋ก). ์ผ๋ฐ์ ์ธ ํ์ ์ค๋ช ์ด ์๋ ๊ฒฝ์ฐ ๋ค๋ฅธ ์ฌ๋๋ค์ ์์์ ์ฐจ๋จ/๋ณดํญ์ผ๋ก ์ ํ์์ ๋ฑ๋กํ ์ ์์ต๋๋ค.
์ฌ๊ธฐ์ ์ฃผ์ ์ ๋ํ ๋ ๋ง์ ์ค๋ช ์ด ์์ต๋๋ค: https://github.com/pytorch/pytorch/issues/16038#issuecomment -454490374
@ezyang ๋ต๋ณ ๊ฐ์ฌํฉ๋๋ค. ์, ์ํํธ ํ์ ํ๊ทธ๊ฐ ๋์์ด ๋ ์ ์์ต๋๋ค. ๋ฌธ์ ๋ ์ฐจ์ ์์๊ฐ ์์์ ์ผ ์ ์์ผ๋ฏ๋ก ์ถฉ๋ถํ ์ ์ฐํ์ง ์์ ์ ์๋ค๋ ๊ฒ์ ๋๋ค. ๋ํ ์์ฒด์ ์ผ๋ก ๊ณ์ฐํ ์ ์์ต๋๋ค. ๋ช ๋ช ๋ ํ ์๋ ๊ฐ ์ฐจ์์ ๋ํ ์๋ฏธ๋ก ์ ์๋ฏธ๋ฅผ ๊ฐ์ง๋ง ์ง์ํ๊ธฐ ์ํด ๋ ๋ง์ ๊ธฐ๋ฅ์ด ํ์ํ ์ ์์ต๋๋ค.
๊ฐ์ธ์ ์ผ๋ก ๋๋ ์ด๊ฒ์ด ๋ณดํญ ์์(๋ฌผ๋ฆฌ์ )์์ NCHW ํฌ๊ธฐ ์์(๋
ผ๋ฆฌ์ )๋ก์ ๋งต์ ๋์
ํ์ฌ ํด๊ฒฐํ ์ ์๋ค๊ณ ์๊ฐํฉ๋๋ค. ์์์ ์ ์ํ ๊ฒ์ฒ๋ผ NCHW์ ๊ฒฝ์ฐ ํ์ฌ ๋์์ธ๊ณผ ๊ฑฐ์ ๋์ผํฉ๋๋ค. NHWC์ ๊ฒฝ์ฐ sizes
๋ ์ฌ์ ํ NCHW์ด๊ณ strides
๋ (N, H, W, C) ์์์
๋๋ค. ๊ทธ๋ฆฌ๊ณ stride_index
= (0, 2, 3, 1)์ ์ฌ์ฉํ์ฌ ๋ณดํญ์ ์ฐจ์ ์ธ๋ฑ์ค๋ฅผ ์ง์ ํฉ๋๋ค.
๋ํ strides
๋ฐ stride_index
์ ์กฐํฉ์ ์ฌ์ฉํ์ฌ ๋ชจ๋ ํ
์ ํ์์ ๋ํ๋ผ ์ ์์ต๋๋ค. ์ด๊ฒ์ ๋ค๋ฅธ ์ฌ๋๋ค์ด ์๋ก์ด ๋ฐ์ดํฐ ํ์์ ๋ฑ๋กํ ์ ์๋ ์ ์ฐ์ฑ์ ์ ๊ณตํ ์ ์์ต๋๋ค.
@ezyang
์์ ์ ๋ฉ๋ชจ๋ฆฌ ํ์ ๋์์ ์ ์งํฉ๋๋ค.
๋ฌผ๋ฆฌ์ NHWC ํ ์๊ฐ ์์ ํ ์คํธ๋ผ์ด๋๋ฅผ ํตํด ๋ฐ์ํ ์ ์๋ ๊ฒฝ์ฐ ๋ฉ๋ชจ๋ฆฌ ํ์ ํ๊ทธ๊ฐ ์์ ๋๋ง ๋ฉ๋ชจ๋ฆฌ ํ์์ ๋ณด์กดํ๋๋ก ํ์ง ์๋ ํ ์ด๊ฒ์ ๊ธฐ์ ์ ์ผ๋ก BC ๊นจ๋ ๊ฒ์ ๋๋ค. ์ ์์ด ํ์ฌ ๋ฌด์์ ์ ์ํ๊ณ ์๋์ง ์ ๋ชจ๋ฅด๊ฒ ์ต๋๋ค.) ์ด๊ฒ์ด ์ค์ ๋ก ์ค์ ๋ก ๋๊ตฐ๊ฐ์ ์ฝ๋๋ฅผ ์์์ํค๋์ง ํ์คํ์ง ์์ต๋๋ค.
์ฐ์ ์ฐ์ฐ๊ณผ ์๊ณ๊ฐ์ด TensorIterator๋ก ์ด๋ํ์ ๋ ์ด๋ ๊ธฐ์ ์ ์ผ๋ก BC ํ๊ดด์์ต๋๋ค(ํผ์ฐ์ฐ์์ ๋ฉ๋ชจ๋ฆฌ ํ์์ด ๋ณด์กด๋์ง ์๊ณ TensorIterator๊ฐ ์ด๋ฅผ ๋ณด์กดํ๊ธฐ ๋๋ฌธ์
๋๋ค). ํ์ฌ ์ํ๋ ๋งค์ฐ ์ผ๊ด์ฑ์ด ์์ต๋๋ค. ์๊ณ๊ฐ์ ๋ ์ด์์์ ์ ์งํ์ง๋ง, ๋ค๋ฅธ ๋ชจ๋ ๋จํญ ์ฐ์ฐ์ ๊ทธ๋ ์ง ์์ผ๋ฉฐ, torch.where๋ ๊ทธ๋ ์ง ์์ต๋๋ค. ๋ ํผ์ฐ์ฐ์์ ๋ ์ด์์์ด ๋์ผํ ๊ฒฝ์ฐ ์ฐ์ ์ฐ์ฐ์ ๋ ์ด์์์ ์ ์งํ์ง๋ง ๊ธฐ๋ณธ๊ฐ์ "nchw" ๋๋ contiguous
ํ
์์
๋๋ค contiguous
๋ฏธ์ค๋งค์นญ์ด ์์ ๊ฒฝ์ฐ ๋ฐฉ์ก์ ์ด๋ป๊ฒ ๋๋์ง ์ ๋ชจ๋ฅด๊ฒ ์ต๋๋ค.
๋ํ BC๊ฐ ์๋ ๋ ์ด์์์ ์ ์งํ๋ empty_like
๋ํด ์ข์ ์ง์ ์ ํ๊ณ ์์ต๋๋ค. ์๋ง๋ ์ ์์์ is_contiguous์ ๊ฐ์ ๋ ์ด์์ ์ธ์๊ฐ ํ์ํ ๊ฒ์
๋๋ค.
x.is_contiguous(torch.memory_format.channels_first)
@ezyang @ngimel
empty_like์๋ ํ ๊ฐ์ง ๋ฌธ์ ๊ฐ ์์ต๋๋ค. ํ์ฌ ์ ์๋ ์๋ฏธ๋ ๋ชจ๋ ๋ณดํญ ์ ๋ณด๋ฅผ ์ญ์ ํ๋ค๋ ๊ฒ์ด๋ฏ๋ก ๋ ์ด์์์ ์ ์งํ๊ณ BC๊ฐ ๋ ์ ์์ต๋๋ค.
๋ํ BC๊ฐ ์๋ ๋ ์ด์์์ ์ ์งํ๋ empty_like ๋ฑ์ ๋ํด ์ข์ ์ง์ ์ ํ๊ณ ์์ต๋๋ค.
์ฐ๋ฆฌ๊ฐ ๋ฌผ๋ฆฌ์ ์ง์๋ฅผ ํํํ๊ธฐ ์ํด ๋ณดํญ์ ์์กดํ์ง ์๋๋ค๋ฉด, empty_like
๋ BC๋ฅผ ๊นจ๋จ๋ฆด ํ์๊ฐ ์์ต๋๋ค. ํ
์์๋ 3๊ฐ์ง ์ฐจ์ ์ ๋ณด๊ฐ ์์ต๋๋ค.
ํ์ฌ ๋ฌผ๋ฆฌ์ ์์๋ ๋ชจ์/ํฌ๊ธฐ์ ๋์ผํฉ๋๋ค. ๊ทธ๋์ ์ฐ๋ฆฌ๋ ๋
ผ๋ฆฌ ์์๋ฅผ ๋ณดํญ์ผ๋ก ๋จ์ด๋จ๋ฆฝ๋๋ค. ๋ชจ์๊ณผ ๋ฌผ๋ฆฌ์ ์์๋ฅผ ๋ถ๋ฆฌํ๋ค๊ณ ์๊ฐํ๋ฉด ๋
ผ๋ฆฌ ์์๋ฅผ ์ญ์ ํ ์๋ ์์ง๋ง empty_like
๋ํ ๋ชจ์๊ณผ ๋ฌผ๋ฆฌ์ ์์๋ ๋ณด์กดํ ์ ์์ต๋๋ค. ์ฆ, size()
๋ฐ stride_index()
๋ ๋ชจ๋ ๋ณด์กด๋์ง๋ง stride()
๋ ์ฌ์ค์ ๋ฉ๋๋ค. ํนํ NHWC ํ
์์ empty_like
๋ ๋์ผํ ๋ชจ์ ์ ๋ณด๊ฐ ์ง์ ๋ NHWC ์ฐ์ ํ
์๋ฅผ ๋ฐํํฉ๋๋ค.
@uyongw empty_like
๋ณ๊ฒฝํ๋ ๊ฒ์ด ์ข์ ์๊ฐ์ธ์ง ์ ๋ชจ๋ฅด๊ฒ ์ต๋๋ค. ์ง๊ธ ๊ทธ ์๋ฏธ๋ numpy์ empty_like
์ ์ผ์นํฉ๋๋ค.
ํ์ฌ ์ํ๋ ๋งค์ฐ ์ผ๊ด์ฑ์ด ์์ต๋๋ค. ์๊ณ๊ฐ์ ๋ ์ด์์์ ์ ์งํ์ง๋ง, ๋ค๋ฅธ ๋ชจ๋ ๋จํญ ์ฐ์ฐ์ ๊ทธ๋ ์ง ์์ผ๋ฉฐ, torch.where๋ ๊ทธ๋ ์ง ์์ต๋๋ค. ๋ ํผ์ฐ์ฐ์๊ฐ ๋์ผํ ๋ ์ด์์์ ๊ฐ๊ณ ์๋ ๊ฒฝ์ฐ ์ฐ์ ์ฐ์ฐ์ ๋ ์ด์์์ ์ ์งํ์ง๋ง ๊ธฐ๋ณธ๊ฐ์ "nchw" ๋๋ ์ธ์ ํ๋ ํ ์์ ๋๋ค. ๋ถ์ผ์น๊ฐ ์๋ ๊ฒฝ์ฐ ํ์ฌ ์ดํดํ๊ณ ์์ง๋ง ๋ฐฉ์ก์ ์ด๋ป๊ฒ ๋๋์ง ์ ๋ชจ๋ฅด๊ฒ ์ต๋๋ค.
@ngimel , ์, ์ง๊ธ์ ์ผ๊ด์ฑ์ด ์์ต๋๋ค. ๋ฉ๋ชจ๋ฆฌ ํ์์ ํํํ๋ ๋ฐฉ๋ฒ์ ์ผ๋ถ๋ ์ฐ์ฐ์๋ฅผ ์ผ๊ด๋ ์ํ๋ก ๋ง๋๋ ๊ฒ์ด๋ผ๊ณ ์๊ฐํฉ๋๋ค.
@zou3519 ๋งํฌํ numpy์ empty_like์๋ ๊ธฐ๋ณธ์ ์ผ๋ก "ํ๋กํ ํ์
์ ๋ ์ด์์๊ณผ ์ต๋ํ ๊ฐ๊น๊ฒ ์ผ์น"ํ๋ order
์ธ์๊ฐ ์์ต๋๋ค. ๊ทธ๊ฒ์ pytorch์ empty_like
๊ฐ ํ์ฌ ์ํํ๋ ์์
์ด ์๋๋๋ค(ํ๋กํ ํ์
์ด ์ฐ์์ ์ด์ง ์์ ๊ฒฝ์ฐ์๋ "nchw"- ์ฐ์ ํ
์๋ฅผ ๋ฐํํจ)
์, ๋๋ฌด ๋นจ๋ฆฌ ์ฝ์๋ค์. ๊ทธ ๊ฒฝ์ฐ์ ์ฐ๋ฆฌ์ empty_like ์ผ์น numpy๋ฅผ ๊ฐ๋ ๊ฒ์ด ์ข์ ๊ฒ์ด๊ณ ์ฌ๊ธฐ์์ ๋ฉ๋ชจ๋ฆฌ ๋ ์ด์์๋ ๊ฐ๋ ๊ฒ์ด ์ข์ ๊ฒ์ ๋๋ค(์๋ง๋?)
@zou3519 ๋ค, ์ ๊ฐ ๋งํ๋ ค๋ ๊ฒ์ ํ์ฌ ์๋ฏธ๋ก ( @ezyang ๋ฐ @ngimel์ด ์ธ๊ธํ ๊ฒ์ฒ๋ผ ๋ ผ๋ฆฌ์ ์์ ์ญ์ )์ ์ ์งํ๊ณ ๋์์ numpy์ ๊ธฐ๋ณธ๊ฐ๊ณผ ๊ฐ์ ๋ฌผ๋ฆฌ์ ๋ ์ด์์์ ์ ์งํ๋ ๊ฒ์ ๋๋ค. ๋ฐ๋ผ์ NCHW ํ๋กํ ํ์ ์ ๊ฒฝ์ฐ ๋์์ ์ด์ ๊ณผ ๋์ผํฉ๋๋ค. NHWC ํ๋กํ ํ์ ์ ๊ฒฝ์ฐ ๋์์ ์ฌ์ ํ โโํธํ๋ฉ๋๋ค. ์ฆ, ํ์ฌ ๊ตฌํ์ ๋ณ๊ฒฝํ์ง ์์ผ๋ฉด ์ ํ ์๋ NCHW ์ฐ์ ๋์ NHWC ์ฐ์์ด ๋ฉ๋๋ค.
๋ ๊ฐ์ง ์ง๋ฌธ:
๋ง์ง๋ง ๊ธ๋จธ๋ฆฌ ๊ธฐํธ๋ก (B)์ ๋จ์ ์ ํด๊ฒฐํ๋ฉด (B)๊ฐ ๋์๊ฒ ๋ ๋์ ๊ฒ ๊ฐ์ต๋๋ค. ์ง๊ด์ ์ผ๋ก ๋ช ํํ๊ณ ๋ ผ๋ฆฌ์ ์ค๋ฅ๋ฅผ ๊ฐ์งํ๊ธฐ ์ฌ์์ผ ํฉ๋๋ค. ๊ธฐ์กด์ ๋ชจ๋ ์ฐ์ฐ์ ๋ค๋ฅธ ์ธ์ ํ ์์ฒ๋ผ ๋ณด์ด๊ธฐ ๋๋ฌธ์ ํ ์์์๋ ์๋ํ ์ ์์ต๋๋ค. ์๋งจํฑ(๋ช ๋ช ๋ ํ ์ ์ ์๊ณผ ์ ์ฌ)์ ์ดํดํ ์ ์๋ ์์ ๋ ์์๋๋ก ์ํ๋ฉ๋๋ค.
@zou3519 ๋งํฌํ numpy์ empty_like์๋ ๊ธฐ๋ณธ์ ์ผ๋ก "ํ๋กํ ํ์ ์ ๋ ์ด์์๊ณผ ์ต๋ํ ๊ฐ๊น๊ฒ ์ผ์น"ํ๋
order
์ธ์๊ฐ ์์ต๋๋ค. ๊ทธ๊ฒ์ pytorch์empty_like
๊ฐ ํ์ฌ ์ํํ๋ ์์ ์ด ์๋๋๋ค(ํ๋กํ ํ์ ์ด ์ฐ์์ ์ด์ง ์์ ๊ฒฝ์ฐ์๋ "nchw"- ์ฐ์ ํ ์๋ฅผ ๋ฐํํจ)
์ด๋ฌํ ๊ฒฝ์ฐ ํ์์ ์ ์งํ ๊ณํ์ ๋๋ค(๋ฉ๋ชจ๋ฆฌ ํ์ ํ ์์ ๊ฒฝ์ฐ).
NHWC ํ ์๋ฅผ NCHW ํ ์์ ์ถ๊ฐํ๋ฉด ์ด๋ป๊ฒ ๋ฉ๋๊น?
๋ฉ๋ชจ๋ฆฌ ํ์ ํ ์๋ฅผ ์ฌ์ฉํ ์์ ์ ๋ฉ๋ชจ๋ฆฌ ํ์ ํ ์๋ฅผ ๋ฐํํฉ๋๋ค. ๋ ํ ์๊ฐ ๋ชจ๋ ๋ฉ๋ชจ๋ฆฌ ํ์์ธ ๊ฒฝ์ฐ ์ถ๋ ฅ ํ์์ ์ฒซ ๋ฒ์งธ ํ ์์ ์ํด ๊ฒฐ์ ๋ฉ๋๋ค.
๋ด๊ฐ ์ถ๊ฐํ ๋ ๊ฐ์ง:
์ด๋ฌํ ๊ฒฝ์ฐ ํ์์ ์ ์งํ ๊ณํ์ ๋๋ค(๋ฉ๋ชจ๋ฆฌ ํ์ ํ ์์ ๊ฒฝ์ฐ).
์ข
์ข
์ด์์๊ฐ empty_like
ํธ์ถํ ๋ค์ NCHW ์ฐ์์ด๋ผ๊ณ ๊ฐ์ ํ๊ธฐ ๋๋ฌธ์ ๊ธฐ์กด ์ฌ์ฉ์ ๊ฐ์ฌํด์ผ ํฉ๋๋ค. ๊ทธ๋ฆฌ๊ณ ์ 3์ ์ฝ๋๋ฅผ ์ด๋ป๊ฒ ์ฒ๋ฆฌํ ์ง ๋ชจ๋ฅด๊ฒ ์ต๋๋ค. BC๋ฅผ ๋ณด์กดํ๋ ค๋ฉด numpy์ ๋ค๋ฅธ ๊ธฐ๋ณธ๊ฐ์ด ํ์ํ ๊ฒ ๊ฐ์ต๋๋ค.
๋ฉ๋ชจ๋ฆฌ ํ์ ํ ์๋ฅผ ์ฌ์ฉํ ์์ ์ ๋ฉ๋ชจ๋ฆฌ ํ์ ํ ์๋ฅผ ๋ฐํํฉ๋๋ค. ๋ ํ ์๊ฐ ๋ชจ๋ ๋ฉ๋ชจ๋ฆฌ ํ์์ธ ๊ฒฝ์ฐ ์ถ๋ ฅ ํ์์ ์ฒซ ๋ฒ์งธ ํ ์์ ์ํด ๊ฒฐ์ ๋ฉ๋๋ค.
๋ํ ์ถ๋ ฅ ํ์์ด ๋ฌด์์ธ์ง ์ ๋ง ์ค์ํ๋ค๋ฉด ์ถ๋ ฅ ํ ์๋ฅผ ์ ๋ฌํฉ๋๋ค.
empty_like์ ๋์ํฉ๋๋ค. empty_like/zeros_like ๋ฑ์ ๊ฒฐ๊ณผ๊ฐ nchw-contiguous๋ก ๊ฐ์ฃผ๋๋ ๊ฒฝ์ฐ๊ฐ ๊ฝค ์์ต๋๋ค(๋ฌผ๋ฆฌ์ ์ผ๋ก ์ฐ์์ ์ด๋ผ๊ณ ๋งํด์ผ ํฉ๋๋ค. ๋ง์ ๊ฒฝ์ฐ ์ด๋ฏธ์ง ์์
์ด ์๋).
out
kwarg๊ฐ ์๋ ํจ์๋ ๋ฏธ๋ถํ ์ ์๊ธฐ ๋๋ฌธ์ ์ถ๋ ฅ ํ
์๋ฅผ ์ ๋ฌํ๋ ๊ฒ์ ๋๋ถ๋ถ์ ๊ฒฝ์ฐ ์ต์
์ด ์๋๋๋ค.
์ฐ๋ฆฌ์ ๋ง์ ๋ฌธ์ ๋ ์์๋๋ ์ถ๋ ฅ ๋ ์ด์์์ ๋ถ์ผ์น์์ ๋น๋กฏ๋ฉ๋๋ค. ํ ๋ฒ์ ๋ชจ๋ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ ์๋ ์์ง๋ง ํ์ฌ ์ํ๋ฅผ ์ ๊ทธ๊ณ (์ ์ด๋ ๋ณดํญ์ ๋ํด์๋) ํ๋์ฉ ํด๊ฒฐํด ๋ณผ ์ ์์ต๋๋ค. ์ฌ๊ธฐ ์ ์์ด ์์ต๋๋ค.
ํ์ด์ฌ API
์๋ก์ด torch.memory_format์ ์๊ฐํฉ๋๋ค.
torch_memory_format.any # default value
torch_memory_format.preserve
torch.memory_format.contiguous # what most of the functions now behave as default
torch.memory_format.nchw # requires 4D tensor, contiguous memory
torch.memory_format.nhwc # requires 4D tensor, restrided/permuted memory
ํ ์๋ ๋ช ์์ ๋ฉ๋ชจ๋ฆฌ ํ์ ๋ณํ์ด ํ์ํฉ๋๋ค.
x = torch.zeros((10,3,32,32)) # NCHW
x.permute(0,2,3,1).is_contiguous(memory_format=torch.memory_format.nhwc) == False # because memory still layed out as NCHW
ํน์ ํ์์ผ๋ก 'ํ๊ทธ'ํ๋ ค๋ฉด:
y = x.to(memory_format=torch.memory_format.nhwc)
y.is_contiguous(memory_format=torch.memory_format.nhwc) == True # We got new tensor with proper memory layout
y.is_contiguous() == False # Required for back compatibility
y.stride() == (3072, 3, 1, 96)
์ด์ empty_like ๋ฐ ์ ์ฌ์ ๋ํด:
z = torch.empty_like(y)
z.is_contiguous() == True # For BC
์ค์ ๋ก ๋ค์๊ณผ ๊ฐ๊ธฐ ๋๋ฌธ์ ๋๋ค.
z = torch.empty_like(y, memory_format=torch.memory_format.any )
ํ์์ ์ ์งํ๋ ค๋ฉด:
z = torch.empty_like(y, memory_format=torch_memory_format.preserve)
z.is_contiguous() == False
z.is_contiguous(memory_format=torch.memory_format.nhwc) == True
๋น์ทํ๊ฒ:
z = torch.empty_like(y, memory_format=memory_format=torch.memory_format.nhwc)
z.is_contiguous() == False
z.is_contiguous(memory_format=torch.memory_format.nhwc) == True
์ฆ, ๊ฐ ํจ์ memory_format ๊ธฐ๋ณธ๊ฐ์ ์ธ๊ณ์ ํ์ฌ ์ํ๋ก ์ฒ์ฒํ ์ ์ํ๊ณ ๋ถ๋ฅํ๊ณ ๋ฏธ๋์ ๋ณ๊ฒฝํ ๋ฐฉ๋ฒ์ ์ผ๋์ ๋ ์ ์์ต๋๋ค.
ํ
์๋ฅผ ์ง์ ํ๋ฉด TensorOptions๊ฐ ํ์ฌ ๋ฌด์๋ฉ๋๋ค(๊ฐ์ฅ ์ข์ ๊ฒฝ์ฐ ์์ธ๊ฐ ๋ฐ์ํ๋ ๊ฒ์ ์๋ฅผ ๋ค์ด out
ํ
์ ์ฅ์น์ ์ฅ์น ์ต์
๋ถ์ผ์น๋ฅผ ์ ๋ฌํ๋ ๊ฒ์
๋๋ค).
๋ฉ๋ชจ๋ฆฌ ํ์์ ๊ฐ๋ฒผ์์ผ ํ๋ฏ๋ก ์์ด์ด ์์ค๋ ๊ฒ์ ๋๋ค.
x.zeros((10,3,32,32), memory_format=torch.memory_format.nhwc)
x = x.permute(0,1,3,2).permute(0,1,3,2)
x.is_contiguous(memory_format=torch.memory_format.nhwc) == False (even if strides are similar)
ํจ๋ฉ์ด ํ์คํ์ง ์์ ๊ฒฝ์ฐ ์ฌ๊ธฐ์์ ๋์์ ์ฃผ์๋ฉด ๊ฐ์ฌํ๊ฒ ์ต๋๋ค.
๊ทธ๋ฌ๋ ์ ์ ํ ํ์์ผ๋ก x.to(memory_format=torch.memory_format.nhwc) 'tag' ํ ์๋ฅผ ๋ง๋ค๊ณ ์์ฒด๋ฅผ ๋ฐํํ ์ ์์ต๋๋ค.
๋ค์ค ์ฒ๋ฆฌ
๋ฉ๋ชจ๋ฆฌ ํ์ 'ํ๊ทธ'๋ฅผ ์ ์งํฉ๋๋ค.
๋ธ๋ก ๋ฉ๋ชจ๋ฆฌ ํ์
์์ API๋ ์ฐจ์/๋ณดํญ/ํฌ๊ธฐ์ ์์กดํ์ง ์์ผ๋ฏ๋ก ํฅํ ๋์ผํ API๋ฅผ ์ ์งํ๋ฉด์ ๊ธฐ๋ฅ์ ํ์ฅํ ์ ์์ต๋๋ค.
๋ด๋ถ API
์ฐ์ฐ์๋ ๋ฉ๋ชจ๋ฆฌ ํ์์ ๋ฐ๋ผ ๋ถ๊ธฐํ ์ ์์ต๋๋ค.
if (self.memory_format(nhwc)) {
// fast path
} else
{
// classic implementation
}
memory_format์ TensorOptions๋ก ํ๋ฉด ๋์คํจ์น ์์ค์์ ๋ถ๊ธฐํ๋ ๊ฒ์ ์๊ฐํ ์ ์์ต๋๋ค(๋๋ฐ์ด์ค, ๋ ์ด์์๊ณผ ์ ์ฌ).
@VitalyFedyunin ์ ์ ์์ ๋ํ ์์ ํผ๋๋ฐฑ - ์ฌ๊ธฐ์ 4D ํ ์๊ฐ ํ์ํ๋ค๊ณ ์๊ฐํฉ๋๋ค.
torch.memory_format.nchw # requires 4D tensor, contiguous memory
torch.memory_format.nhwc # requires 4D tensor, restrided/permuted memory
๋๋ฌด ์ ํ์ ์ด๋ฉฐ(2D ์ธ์ 1D ๋ฐ 3D๋ ์ฒ๋ฆฌํด์ผ ํ๊ธฐ ๋๋ฌธ์) ์๋ ์ ์์ channels_first/channels_last
๊ฐ ์ด ๋ชฉ์ ์ ๋ ์ ํฉํ์ต๋๋ค.
๋์ํฉ๋๋ค. ๋ ๋์ ์ด๋ฆ ์ง์ ์ด ํ์ํฉ๋๋ค. channels_first
๋ ์ผ๊ด ์ฒ๋ฆฌ๊ฐ ๋จผ์ ์งํ๋๋ค๋ ์ ์ ์ ์ธํ๊ณ ๊ฑฐ์ ์ณ๊ฒ ๋ค๋ฆฝ๋๋ค =)
๋๋ ๋น์ ์ ์ต์ ์ ์์ ์ข์ํฉ๋๋ค. .contiguous() ์ฒ๋ฆฌ๊ฐ ๋ณ๊ฒฝ๋ฉ๋๊น? .contiguous(memory_format=<...>)๊ฐ ํ์ํฉ๋๊น? ๊ทธ๋ ๋ค๋ฉด ๋ง์ ์์ ์ด ๋จ์ํ .contiguous()๋ฅผ ํธ์ถํ์ง๋ง ์ฌ์ ํ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๋ถ์ ์ ํ๊ฒ ํฌ๋งทํ๊ณ ์์ ์ ์์ต๋๋ค. ์ค๋๋ ๋ง์ ์์ ์์๋ ๋์ผํ ํจ๊ณผ๋ฅผ ๋ผ ์ ์๋ empty_like()๋ก ์ถ๋ ฅ์ ํ ๋นํฉ๋๋ค. ์ ๋ ฅ์ ๋ฉ๋ชจ๋ฆฌ ํ์์ ๊ฐ์งํ๊ณ ์ฌ๋ฐ๋ฅธ ์ฐ์์ ์ด๊ณ empty_like ํธ์ถ์ ์ํํ๋๋ก ์ ๋ฐ์ดํธํ ๊ณํ์ ๋๊น?
์ง๊ธ ๋น์ฅ์ .contiguous()
๊ฐ ๋ฉ๋ชจ๋ฆฌ ์ฐ์ ํ
์๋ฅผ ๋ด๋ฆผ์ฐจ์์ผ๋ก ๋ณดํญ์ผ๋ก ๋ฐํํ ๊ฒ์ผ๋ก ๊ธฐ๋ํ๋ ์ฌ์ฉ์(๋ฐ ๋ชจ๋ ๋ผ์ด๋ธ๋ฌ๋ฆฌ)์
๋๋ค.
์ฐ๋ฆฌ๋ ์ด ๊ณ์ฝ์ ๊นฐ ์ ์์ต๋๋ค. ๊ทธ๋ฌ๋ ์ข์ ์์์ memory_format ์ต์
์ ์ง์ํ๋ ์ฆ์ JIT๊ฐ ํด๋์ ํ์ ๋์ .contiguous(memory_format=...)
๋ฅผ ํธ์ถํ๋ ๊ฒ์ด ๋ ํจ์จ์ ์ธ ๋๋ฅผ ์ดํดํ ์ ์๋ค๋ ๊ฒ์
๋๋ค.
@VitalyFedyunin ์๋์ ๊ฐ์ ์์ ์ ํ์ฉ๋์ง ์๋๋ค๊ณ ๊ฐ์ ํฉ๋๊น?
x.zeros(10,3,32,32)
# x is in nchw (default)
# x.size() is [10,3,32,32]
# x.stride() is [3*32*32, 32*32, 32,1]
x = x.permute(0,2,3,1)
# At this point
# x.size() is [10,32,32,3], size is not in nchw order
# x.stride() is [3*32*32, 32,1,32*32]
# How can this be supported?
y = x.to(memory_format=torch.memory_format.nhwc)
๋ ๋ค๋ฅธ ๋ณํ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
x.zeros(10,3,32,32)
# `x` is in nchw (default)
# x.size() is [10,3,32,32]
# x.stride() is [3*32*32, 32*32, 32,1]
x = x.permute(0,2,3,1)
x=x.contiguous()
# At this point
# x.size() is [10,32,32,3], size is not in nchw order
# x.stride() is [32*32*3, 32*3,3,1]
# How can this be supported?
y = x.to(memory_format=torch.memory_format.nhwc)
@raghuramank100 - ์ฌ์ฉ์๊ฐ ์ฒ์์ .permute(0,2,3,1)
๋ฅผ ํธ์ถํ๋ ์ด์ ๋ ๋ฌด์์
๋๊น? ์ด ์ ์์ ๋ชจ๋ ํ
์๋ ์๋ฏธ๋ก ์ ํฌ๊ธฐ๊ฐ (n,c,h,w)์ด๋ฉฐ, ์ด๋ size(1)์ด ์ฑ๋์ ๋ฐํํจ์ ์๋ฏธํฉ๋๋ค. ๊ทธ๊ฒ์ด ์ค๋๋ PT์ ํ์ค ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ๊ฐ์ ํ๊ณ ์ด ์ ์์์๋ ๊ฐ์ ํ๋ ๊ฒ์
๋๋ค. ๋ฐ๋ผ์ .permute๋ฅผ ์ ํ ํธ์ถํ์ง ์์ ๊ฒ์
๋๋ค.
์ปจํ ์คํธ ๊ด๋ฆฌ์๊ฐ ์ฌ์ฉ์๊ฐ ๊ด๋ฆฌ์ ๋ฒ์ ๋ด์์ ํ ๋น๋ ํ ์์ ๋ฉ๋ชจ๋ฆฌ ํ์์ ํน์ ํ์์ผ๋ก ์ฌ์ ์ํ ์ ์๋๋ก ํ๋ ๋ฐ ์ ์ฉํ ์ ์์ต๋๊น?
with torch.memory_format(torch.memory_format.nhwc):
# a will be allocated with the context managed memory format
a = torch.randn(...)
# b will be allocated matching some assumed default format
b = torch.randn(...)
memory_format์ ์ ์ด๋ฅผ ๋์จํ๊ฒ ํ๊ธฐ ๋๋ฌธ์ ์ปจํ ์คํธ ๊ด๋ฆฌ์์ ์์ด๋์ด๊ฐ ๋ง์์ ๋ค์ง ์์ต๋๋ค.
์๋ฅผ ๋ค์ด:
with torch.memory_format(torch.channels_last):
x = torch.randn(10,3,32,32) # this one is NHWC
y = torch.randn(10,10) @ this one is not
๋ช ์์ memory_format์ด ๋ช ํํ๊ฒ ํ๋ ๊ฒฝ์ฐ:
x = torch.randn(10,3,32,32).to(memory_format=torch.channels_last) # this one is NHWC
y = torch.randn(10,10).to(memory_format=torch.channels_last) # This is errors out as dim == 2
ํ์ํ ๊ฒฝ์ฐ ๋ค์์ ํ์ฉํ๋ ๊ตฌ๋ฌธ์ ์ถ๊ฐํ ์ ์์ต๋๋ค.
x = torch.randn(10,3,32,32, memory_format=torch.channels_last)
@raghuramank100 ์์ดํ ํ์๊ฐ ์์ต๋๋ค.
y = x.to(memory_format=torch.channels_last)
x์์์ ๊ฐ์ด ํฌ๋ฏธํ ์์๋ฅผ ์ ์งํ๋ฉด์ ๋ชจ๋ ๋๋ฌ์ด ์์ ์ ์ํํฉ๋๋ค.
๊ทธ๋์:
x = torch.randn(10, 3, 32, 32)
nhwc = x.to(memory_format=torch.channels_last)
self.assertFalse(nhwc.is_contiguous())
self.assertTrue(nhwc.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(nhwc, x)
๊ทธ๋ฆฌ๊ณ ์ด ํ์์ผ๋ก nhwc๋ฅผ ๊ณ์ ์ฒ๋ฆฌํ ์ ์์ต๋๋ค.
nhwc[N][C][H][W]
@VitalyFedyunin ๊ทธ๊ฒ์ ์๋ฏธ๊ฐ ์์ต๋๋ค.
์ฌ์ฉ์์ ๊ด์ ์์ ๋ณผ ๋ ๋ฉ์๋ ์ด๋ฆ(์ด๋๋ก ์ ์ง๋๋ ๊ฒฝ์ฐ)์ "to"๊ฐ ์ด๋ฏธ Tensor๋ฅผ ๋ค๋ฅธ ์ฅ์น๋ก ์ ์กํ๋ ๋ฐ ๊ถ์ฅ๋๋ ๋ฐฉ๋ฒ์ด๋ฏ๋ก ์คํด์ ์์ง๊ฐ ์๋ ๊ฒ ๊ฐ์ต๋๋ค.
๋ํ C_ORDER ๋ฐ F_ORDER ๋ฐฐ์ด์ ๋ณํํ๋ Numpy์ ๊ฒ๊ณผ ๊ฐ์ ๊ฒ์ ์ด๋ป์ต๋๊น?
numpy.asfortranarray()
numpy.ascontiguousarray()
๋ค์๊ณผ ๊ฐ์ ๊ฒ์ ์ฝ๊ฒ ์์ํ ์ ์์ต๋๋ค.
torch.randn(32, 3, 64, 64).to(device).as_nhwc()
@VitalyFedyunin : ๋ค๋ฅธ memory_format์ผ๋ก ๋ณํํ๋ฉด ์ฌ์ฉ์๊ฐ ์๋์ผ๋ก ๋ณ๊ฒฝํ ํ์๊ฐ ์๋ค๋ ๊ฒ์ ์ดํดํฉ๋๋ค. ๊ทธ๋ฌ๋ ์ด ๊ธฐ๋ฅ์ ํ ์น์์ ์ฌ์ฉํ ์ ์๊ฒ ๋๋ฉด ์์์ ์ค๋ช ํ ์์๋๋ก ์ฌ์ฉ์๊ฐ ํจ์๋ฅผ ํธ์ถํ๋ฉด ์ด๋ป๊ฒ ๋ ๊น์? ์ต์ํ ๋ ์ด์์ ๋ณํ์ด ์คํจํ๋ค๋ ๊ฒฝ๊ณ /์ค๋ฅ ๋ฉ์์ง๊ฐ ์์ด์ผ ํฉ๋๋ค.
@VitalyFedyunin : ๋ค๋ฅธ memory_format์ผ๋ก ๋ณํํ๋ฉด ์ฌ์ฉ์๊ฐ ์๋์ผ๋ก ๋ณ๊ฒฝํ ํ์๊ฐ ์๋ค๋ ๊ฒ์ ์ดํดํฉ๋๋ค. ๊ทธ๋ฌ๋ ์ด ๊ธฐ๋ฅ์ ํ ์น์์ ์ฌ์ฉํ ์ ์๊ฒ ๋๋ฉด ์์์ ์ค๋ช ํ ์์๋๋ก ์ฌ์ฉ์๊ฐ ํจ์๋ฅผ ํธ์ถํ๋ฉด ์ด๋ป๊ฒ ๋ ๊น์? ์ต์ํ ๋ ์ด์์ ๋ณํ์ด ์คํจํ๋ค๋ ๊ฒฝ๊ณ /์ค๋ฅ ๋ฉ์์ง๊ฐ ์์ด์ผ ํฉ๋๋ค.
์ด๊ฒ์ ๋ช ๋ช ๋ ํ ์๋ฅผ ๊ตฌํํ ๋๋ง ๊ฐ๋ฅํฉ๋๋ค. ์ง๊ธ ๋น์ฅ:
x.zeros(10,10,10,10)
x = x.permute(0,2,3,1)
๋ด๊ฐ ๋ฐฉ๊ธ nchw ๋๋ nhwc๋ฅผ ๋ง๋ค์๋์ง ์๋ฌด๋ ์ ์ ์์ต๋๋ค.
๋ด๊ฐ ์๋ ์ ์์ ์๋ชป ์ดํดํ์ ์๋ ์์ง๋ง ๊ธฐ๋ก๋ ๋ฉ๋ชจ๋ฆฌ ํ์ ํ๊ทธ๊ฐ ์ด ์ํฉ์ ๋ช ํํ๊ฒ ํด์ผ ํ๋ ๊ฒ ์๋๊ฐ์?
@VitalyFedyunin ์ด API๊ฐ ์์ ํ๋๋ฉด ์ต์ข ์ฌ์ฉ์์๊ฒ ์ด๋ฅผ ์ ๋ฌํด์ผ ํฉ๋๋ค.
@dzhulgakov @VitalyFedyunin #19975๋ฅผ ๊ฒํ ํ ํ ํ
์์ ๊ธฐ๋ก๋ ๋ฉ๋ชจ๋ฆฌ ํ์ ํ๊ทธ์ ๋ํ ๋ช ๊ฐ์ง ์๋ก์ด ์ฐ๋ ค๊ฐ ์์ต๋๋ค. ๋ด ๊ธฐ๋ณธ์ ์ธ ๋ฌธ์ ๋ ์์
์ด ๋ฉ๋ชจ๋ฆฌ ํ๊ทธ๋ฅผ ๋ณด์กดํด์ผ ํ๋์ง ์ฌ๋ถ๋ฅผ ์ด๋ป๊ฒ ๊ฒฐ์ ํด์ผ ํฉ๋๊น? ์๋๋ "๋์ฒด ๋ ์ด์์ ์ธ์" ์ด์์๋ง ์ด๋ฌํ ์๋ฆฌํจ์ ๊ฐ์ถ์ด์ผ ํ๋ค๊ณ ์๊ฐํ์ต๋๋ค. ํ์ง๋ง Vitaly์ ํจ์น๋ฅผ ๋ณด๋ฉด ์ผ๋ถ ํต์ฌ ์คํผ๋ ์ดํฐ๋ ์กฐ์ ์ด ํ์ํ๋ค๊ณ ์๊ฐํฉ๋๋ค. ์๋ฅผ ๋ค์ด x[0]
; x๊ฐ ์ด์ ์ NHWC ํ
์๋ผ๋ฉด ์ด ์์
์ ์ํํ ํ HWC ํ
์๋ฅผ ๊ฐ์ ธ์์ผ ํฉ๋๋ค. ๋๋ Vitaly์ ํจ์น๊ฐ ์ด๊ฒ์ ์ฌ๋ฐ๋ฅด๊ฒ ์ฒ๋ฆฌํ์ง ๋ชปํ๋ค๊ณ ํ์ ํ๋ฉฐ, ์ด๋ ์ฌ์ฉ์์๊ฒ ๋งค์ฐ ํผ๋์ค๋ฌ์ธ ๊ฒ์
๋๋ค. ์๋ง๋ ์ํฅ์ ๋ฐ๋ ์ ์ผํ ์ฐ์ฐ์๋ ๋ณดํญ์ ์ฒ๋ฆฌํ๋ ์ฐ์ฐ์(์ด ๊ฒฝ์ฐ ๋๋ฌด ๋ง์ง ์๊ณ ์๋์ผ๋ก ๊ฐ์ฌํ ์ ์์)์ด์ง๋ง ์ฐ๋ฆฌ๊ฐ ํด์ผ ํ ์ผ์ธ ๊ฒ ๊ฐ์ต๋๋ค. ์ด๋ป๊ฒ ์๊ฐํ๋์?
์ ๊น, ํ ์๋ ์ฌ์ ํ ๋ค์ ์์๋ก ์ธ๋ฑ์ฑ๋ ์ํ๋ฅผ ์ ์งํฉ๋๋ค. 0-dim N; 1์ฐจ์ C; 2์ฐจ์ H; 3rd-dim W. ๋ฐ๋ผ์ x[0]์ 0-dim C๋ฅผ ๊ฐ์ง ํ ์๋ฅผ ๋ฐํํฉ๋๋ค. 1์ฐจ์ H; 2nd-dim W. x๊ฐ channel_first ๋๋ channel_last ๋ฉ๋ชจ๋ฆฌ ๋ ์ด์์์ธ์ง ์ฌ๋ถ์ ๊ด๊ณ์์ด.
๊ทธ๋ ์ง ์์ผ๋ฉด memory_format์ด ์๋ฏธ๊ฐ ์์ผ๋ฉฐ ํ ์๋ฅผ ์นํํ๊ธฐ๋ง ํ๋ฉด ๋ฉ๋๋ค.
๋ด ์์ ์ ๋ฉ๋ชจ๋ฆฌ ํ์ ํ๊ทธ๊ฐ ๋ณด์กด๋์ง ์๋๋ค๋ ๊ฒ์
๋๋ค. ์
๋ ฅ ํ
์์ channels_last
ํ๊ทธ๊ฐ ์ง์ ๋ ๊ฒฝ์ฐ ์ ํ
์๋ any
ํ๊ทธ๊ฐ ์ง์ ๋ฉ๋๋ค.
cc @zou3519 , ์ฌ๊ธฐ ๋ ์ด์์ ์ ํ ๋ ผ๋ฆฌ๋ ๋ช ๋ช ๋ ํ ์ ์์ ์์ ๋ช ๋ช ๋ ์ฐจ์ ์ ํ๋ฅผ ๋ง์ด ์๊ฐ๋๊ฒ ํฉ๋๋ค.
๋๋ ์ฌ์ ํ ์ด ์ ์์ ๋ฐ๋ผ์ก๊ณ ์๋ค. ๊ทธ๋ฌ๋ @ezyang ์ ์ฐจ์๋ณ ํ๋๊ทธ(๋๋ ์ด๋ฆ)๋ฅผ ์ ํํ์ฌ ๋ ์ด์์ ์ ํ ๋ ผ๋ฆฌ๋ฅผ ์ถ์ ํ ์ ์์ผ๋ฉฐ, ๊ทธ๋ฌ๋ฉด ์ด๋ฆ ๊ท์น์ ์ฌ์ฉํ์ฌ ๋ช ๋ช ๋ ํ ์๋ฅผ ๊ฐ๋ ๊ฒ๊ณผ ๊ฐ์ต๋๋ค.
๋ฉ๋ชจ๋ฆฌ ํ๊ทธ ๋ก์ง๊ณผ ๋ช ๋ช ๋ ํ ์ ๋ก์ง์ ์ ํํ๊ฒ ์ ๋ ฌํ ์ ์๋ค๋ฉด, ๋น๋ก ์ฒ์์ ๋ ๊ฐ์ ๊ฐ๋ณ ๊ตฌํ ๊ฒฝ๋ก๊ฐ ์๋๋ผ๋ ๊น๋ํ ๊ฒ์ ๋๋ค.
๋ ๊ฐ์ ํ
์ ํจ์ .is_contiguous
๋ฐ .contiguous
(python ๋ฐ C++ API ๋ชจ๋)์ ๊ธฐ๋ฅ์ ํ์ฅํฉ๋๋ค.
์ฐธ๊ณ : .to(memory_format)
๊ธฐ๋ฅ์ ๋ํ ๋ช ๊ฐ์ง ๋ถ๋ง์ด ์์๊ณ ์ง์ํ์ง ์๊ธฐ๋ก ๊ฒฐ์ ํ์ต๋๋ค.
.contiguous
์ด์ ์ ํ์ ํค์๋ ์ ์ฉ ์ธ์์ธ memory_format
ํฉ๋๋ค. torch.contiguous_format
๋๋ torch.channels_last
์์ต๋๋ค.
torch.contiguous_format
ํ๋ฉด ๊ธฐ์กด .contiguous()
๋์์ด ์ ์ง๋ฉ๋๋ค.
x.contiguous(memory_format=torch.channels_last)
ํธ์ถํ๋ฉด ๋์ผํ ์๋ฏธ์ ๋ ์ด์์(NCHW)์ ์ ์งํ์ง๋ง ๋ฉ๋ชจ๋ฆฌ ํ ๋น ํจํด์ด ๋ค๋ฅธ ์ ํ
์๋ฅผ ๋ฐํํฉ๋๋ค.
x.contiguous(memory_format=torch.channels_last)
๋ ์
๋ ฅ ํ
์๊ฐ 3d, 4d ๋๋ 5d์ผ ๊ฒ์ผ๋ก ์์ํฉ๋๋ค. ๊ทธ๋ ์ง ์์ผ๋ฉด ์คํจํฉ๋๋ค.
.is_contiguous
์ด์ ์ ํ์ ํค์๋ ์ ์ฉ ์ธ์์ธ memory_format
ํฉ๋๋ค. torch.contiguous_format
๋๋ torch.channels_last
์์ต๋๋ค.
x.is_contiguous(memory_format=torch.contiguous_format)
๋ x.is_contiguous()
์ ๋์ผํ ๊ธฐ๋ฅ์ ์ ์งํ๋ฉฐ ๋ณ๊ฒฝ๋์ง ์์ ์ํ๋ก ์ ์ง๋ฉ๋๋ค.
x.is_contiguous(memory_format=torch.channels_last)
๋ A) ์
๋ ฅ ํ
์๊ฐ ๋ฉ๋ชจ๋ฆฌ์์ ์ฐ์์ ์ด๊ณ B) NWHC(๋๋ 3d,5d์ ๊ฒฝ์ฐ ์ ์ฌ) ํ์์ผ๋ก ๋ฉ๋ชจ๋ฆฌ์ ํ ๋น๋ ๊ฒฝ์ฐ true๋ฅผ ๋ฐํํฉ๋๋ค.
์ฐธ๊ณ : 1๋จ๊ณ๊ฐ ๋๋ ๋ x.is_contiguous(memory_format=torch.channels_last)
๋ ๋ชจ๋ ํธ์ถ์์ Tensor์ ์ํ๋ฅผ ๊ณ์ฐํฉ๋๋ค. ์ด ๊ธฐ๋ฅ์ ๋์ค์ ์
๋ฐ์ดํธ๋ ์์ ์
๋๋ค.
ํน์ ์์ ์ ๋ํ ๋ฉ๋ชจ๋ฆฌ ํ์ ์ ์ง:
๋จํญ ์์๋ณ ์ฐ์ฐ์๋ channel_last ๋ฉ๋ชจ๋ฆฌ ํ์์ ์ ์งํฉ๋๋ค.
a = torch.randn(N,C,H,W)
b = a.contiguous(memory_format=torch.channels_last)
c = b.sin()
c.is_contiguous(memory_format=torch.channels_last) == True
์ด์ง ์์๋ณ ์ฐ์ฐ์( add
, sub
, mul
, div
)๋ channels_last ๋ฉ๋ชจ๋ฆฌ ํ์์ ์ ์งํฉ๋๋ค.
a = torch.randn(N,C,H,W)
b = a.contiguous(memory_format=torch.channels_last)
c = b * torch.randn(H,W)
c.is_contiguous(memory_format=torch.channels_last) == True
ํฌ๊ธฐ, ๋ณดํญ ๋ฐ ํ๋ฆผ์ ๋ํ ๋ชจ๋ ์์ ์ ๋ฉ๋ชจ๋ฆฌ ํ์์ ์ฌ์ค์ ํฉ๋๋ค.
a = torch.randn(N,C,H,W)
b = a.contiguous(memory_format=torch.channels_last)
c = b.permute(0,2,3,1).permute(0,3,1,2)
c.is_contiguous(memory_format=torch.channels_last) == False
๋ฏธ์
์ถ๋ ฅ์ด ์ฝ์ ์ ์๋ 'channels_last'์ธ ๊ฒฝ์ฐ ๋ชจ์ ๋ณ๊ฒฝ(๋ฐ ์ ์ฌ) ์์ ์ ๊ฒฐ๊ณผ
import torch
a = torch.randn(N,C,H,W)
b = a.contiguous(memory_format=torch.channels_last)
c = b.reshape(N,C,-1)
c.is_contiguous(memory_format=torch.channels_last) # ?
์ฐธ๊ณ : ํ์ฌ memory_format์ด ๋ณด์กด๋์ง ์์์ต๋๋ค.
NHWC + NCHW ์์ ์ ๊ฒฐ๊ณผ์ ๋๋ค. NHWC์ธ๊ฐ์?
์ฐธ๊ณ : ํ์ฌ NHWC + NCHW -> NHWC ๋ฐ NCHW + NHWC -> NHWC
cat/split๊ณผ ๊ฐ์ ์์ ์ ์ด๋ป์ต๋๊น? ๋ฉ๋ชจ๋ฆฌ ํ์์ ์ ์งํ๋ ๋ฐ ์ ์ฉํฉ๋๋ค.
@ezyang - ์ธ๋ฑ์ฑ๊ณผ ๊ด๋ จํ์ฌ ์ด๋๊ฐ์์ ๋ฉ์ถฐ์ผ ํ๋ค๊ณ ์๊ฐํฉ๋๋ค. ๋ค๋ฅธ ๋ฉ๋ชจ๋ฆฌ ๋ ์ด์์์ ์์ ํ ํฌ๋ช
ํ์ง ์์ผ๋ฉฐ ์ผ๋ถ ์์
์์๋ ์ด๋ฅผ ๋ฌด์ํ๋๋ก ํ์ฉํด์ผ ํฉ๋๋ค. x[0]
๋ x[0].unsqueeze(0)
ํฌํจํ์ฌ ํ๊ทธ๋ฅผ ์ง์ธ ์ ์์ด์ผ ํ๋ค๊ณ ์ฃผ์ฅํฉ๋๋ค.
Raghu๊ฐ ์ธ๊ธํ๋ฏ์ด cat/split์ ๋งค์ฐ ์ผ๋ฐ์ ์ธ ์ฌ์ฉ๋ฒ์ด์ง๋ง ๊ฐ๋ฅํ๋ฉด ํ๊ทธ๋ฅผ ๋ณด์กดํด์ผ ํฉ๋๋ค. ์ผ๋ฐ์ ์ธ ๊ฒฝํ ๋ฒ์น์ ์ด์์ด ์์๋ฅผ ๋ณ๊ฒฝํ๊ฑฐ๋ ์ถ์ ์ด์ํ๊ฒ ์ฌ์ ๋ ฌํ์ง ์๋ ํ ํ๊ทธ๋ฅผ ์ ์งํด์ผ ํ๋ค๋ ๊ฒ์ ๋๋ค. ์์๊ฐ ๋ณ๊ฒฝ๋๋ฉด ๋ชจ๋ ๋ฒ ํ ์ด ํด์ ๋ฉ๋๋ค.
์ด๋ค ๊ฒฝ์ฐ์๋ ํ๊ทธ๋ฅผ ์๊ฒ ๋๋ค๋ ๋ฐ ๋์ํฉ๋๋ค. ํ์ง๋ง x[0]
๋ํด์๋ ๋์ํ์ง ์์ต๋๋ค. ๋์๊ฒ ๊ทธ๊ฒ์ NCHW
์์ CHW
๋ก ๊ฐ๋ ๋งค์ฐ ์ผ๋ฐ์ ์ธ ๋ฐฉ๋ฒ์ธ ๊ฒ ๊ฐ์ต๋๋ค.
ํ ์๊ฐ channel_last 'ํ๊ทธ'๋ฅผ ์ ๋ฌ(๋๋ ํฌํจํ์ง ์์)ํ๋ ๊ฒ์ด ์ผ๋ง๋ ํผ๋์ค๋ฌ์ด์ง์ ๋ํด ๋ช ์ฐจ๋ก ๋ํ๋ฅผ ๋๋ ํ ์ฐ๋ฆฌ๋ bc-๋ธ๋ ์ดํน ๋ณ๊ฒฝ์ ๋์ ํ๊ณ ํ ์๋ฅผ channel_last ํ์์ผ๋ก ์๋ ์น๊ฒฉํ๋ ์ํ์ ๊ฐ์ํ๊ธฐ๋ก ๊ฒฐ์ ํ์ต๋๋ค.
API์ ๋ํ ์๋ฏธ:
N,1,H,[W,[D]]์ ๊ฐ์ ์คํธ๋ผ์ด๋๊ฐ ์๋ ๋ชจ๋ 3d,4d,5d ํ ์๋ ์๋์ผ๋ก channel_last ๋ฉ๋ชจ๋ฆฌ ํ์์ ์ป์ต๋๋ค.
์๋ํ๋๋ก ํ๊ธฐ ์ํด channel_last ํ ์๋ฅผ ์ถ๋ ฅํ๋ channel_last ํ ์์ ์ฐ์ฐ์๊ฐ ์ธ์ ํ ํ ์์ ์ฐ์ฐ์์ ์ต์ํ ๋น์ทํ ์ฑ๋ฅ์ ๊ฐ๋๋ก ๋ณด์ฅํ๊ธฐ ์ํด ํน๋ณํ ์๋ฐฉ ์กฐ์น๋ฅผ ์ทจํ ๊ฒ์ ๋๋ค.
์ต์
์ ์๋๋ฆฌ์ค์ ๊ฒฝ์ฐ:
1) ์ฌ์ฉ์๋ ์ถ๋ ฅ์์ โโ.contiguous()๋ฅผ ํธ์ถํ ์ ์์ต๋๋ค.
2) ์ฐ๋ฆฌ๋ ์ด ๋์์ ๋ณ๊ฒฝํ๋ ๊ฒ์ด ๊ฑฐ์ ์ฌ์ํ ๋ฐฉ์์ผ๋ก ์๋ ์น๊ฒฉ ์ฝ๋๋ฅผ ์์ฑํ ๊ฒ์
๋๋ค.
์ด๋ฌํ ์๋ ํ๋ก๋ชจ์ ์ ๋ถ์์ฉ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
import torch
x = torch.randn(10,16,16,3).permute(0,3,1,2)
x.is_contiguous(memory_format=torch.channels_last) == True
๋ค๋ฅธ ํํธ์ผ๋ก (๊ฐ๋ฒผ์ด ์์ ํ) ๊ฒฝ์ฐ๋ฅผ ํด๊ฒฐํ ์ ์์ต๋๋ค.
import torch
x = torch.randn(10,3,16,16).contiguous(memory_format=torch.channels_last)
x = x[0].unsqueeze(0)
x.is_contiguous(memory_format=torch.channels_last) == True
@ezyang ์ ์์ฒญ์ ๋ฐ๋ผ slack ๋ณํ์์
๋ํ๋ฆฌ์ ๊ธฐ๋ฉ์ค์ธ [2:19 PM]
๊ทธ๋์ ๋๋ ํ๊ทธ์ ๊ฐ๋
์ด ์์ ๊ฒ์ด๋ผ๊ณ ์๊ฐํฉ๋๋ค.
import torch
#batch = 10, channels = 4, spatial dimensions = 16
x = torch.randn(10,16,16,4).permute(0,3,1,2)
x.is_contiguous(memory_format=torch.channels_last) == True
y = torch.randn(10,16,16,2).permute(0,3,1,2)
x1,x2 = x.chunk(2, dim=1) #chunk along channels dimension, no longer contiguous
x1.is_contiguous(memory_format=torch.channels_last) == False #right? So, if a tensor like this comes into e.g. convolution, what am I supposed to do with it? Did it want to be NHWC? Did it want to be nchw?
z=y+x1 #y is channels_last, x1 is something, what is the z layout?```
๋นํ๋ฆฌ ํ๋๋ [์ค์ 8์ 23๋ถ]
z๋ channel_last๊ฐ ๋ ๊ฒ์
๋๋ค.
๋นํ๋ฆฌ ํ๋๋ [์ค์ 8์ 25๋ถ]
x1์ด ์ ์๋ ๋ณํ์์ channel_last๊ฐ ์๋ ๊ฒฝ์ฐ(์ฒญํฌ ๊ธฐ๋ฅ์ ๋ณ๊ฒฝํ์ฌ ๋ทฐ๋ฅผ ๋ฐํํ์ง ์๋ ํ), ์ปจ๋ณผ๋ฃจ์
์ ์ด๋ฅผ ์ฐ์(channel_first) ํ์์ผ๋ก ๋ณํํ๊ณ ์ฐ์๋ ๋ฐํํฉ๋๋ค.
๋นํ๋ฆฌ ํ๋๋ [์ค์ 9:12]
@ngimel ํผ๋๋ฐฑ ์ฃผ์
์ ๊ฐ์ฌํฉ๋๋ค. ๋ณด๊ธฐ์ ๊ฐ์ ์์
์ด ๊ด๋ จ๋ ๋๋ถ๋ถ์ ๊ฒฝ์ฐ๋ฅผ ๋ค๋ฃจ๊ธฐ ์ํด ๋ณด๋ค ์๋ฏธ ์๋
๋ํ๋ฆฌ์ ๊ธฐ๋ฉ์ค์ธ [์ค์ 9์ 36๋ถ]
์ค๋ ๋์ ๋ต๋ณ:
๊ทธ๋์ ๋ฌธ์ ์ธ ๊ฒ ๊ฐ์ฃ ? ์ฑ๋ ์ฐจ์์์ ์ฒญํฌํ๋ ๊ฒ์ ์๋ฅผ ๋ค์ด ์์๊ณผ ๊ฐ์ ๋คํธ์ํฌ์์ ๋น๊ต์ ์ผ๋ฐ์ ์ธ ๊ฒ์
๋๋ค. ๋ฐ๋ผ์ ํ
์๊ฐ ์ฒญํฌ๋ ์ฑ๋ ์ฒซ ๋ฒ์งธ ํ
์์ธ ๊ฒฝ์ฐ ํ์ ์ถ๋ ฅ์ ์ฑ๋ ์ฐ์ (์ง๊ด์ ์ธ ๋์์ด๋ฉฐ ์ฌ์ฉ์๊ฐ ์ํ๋ ๊ฒ์ผ ์ ์์)์ด ๋๊ณ , ํ
์๊ฐ ์ฒญํฌ๋ ์ฑ๋ ๋ง์ง๋ง์ธ ๊ฒฝ์ฐ ํ์ ์ถ๋ ฅ์ ๋ค์ ํ ๋ฒ ์ฑ๋ ์ฒซ ๋ฒ์งธ๊ฐ ๋ ๊น์?
๋ํ๋ฆฌ์ ๊ธฐ๋ฉ์ค์ธ [์ค์ 9:39]
์ค๋ ๋์ ๋ต๋ณ:
๊ทธ๋ฌ๋ ๋น ๊ตํ ๋ง์
๋์๊ณผ y
๊ฐ ์ฒซ ๋ฒ์งธ ์ธ์์ด๊ณ ์ฑ๋์ด ๋ง์ง๋ง์ด๊ธฐ ๋๋ฌธ์
๋๋ค. ๋ง์ต๋๊น? x1+y
์ ๊ฒฐ๊ณผ๋ ๋ฌด์์
๋๊น? ์ด๋๊ฐ์ ์ด์ง ์ฐ์ฐ์ ๋ํ ๋ ์ด์์ ์ ํ ๊ท์น์ด ์์ต๋๊น?
๋นํ๋ฆฌ ํ๋๋ [์ค์ 10:44]
1) ๋ค, ๋์ ์ ์์ผ๋ก ํด๊ฒฐํ ๋ฌธ์ ์
๋๋ค. ๋๋ ์ง๊ธ ๋ช ๊ฐ์ง ํ
์คํธ๋ฅผ ํ๊ณ ์๊ณ ์ด๋ฒ ์ฃผ์ ๊ธฐ๋กํ ๊ฒ์
๋๋ค(ํ๋ฃจ๋ ์ดํ ํ์).
2) x1+y - ๋ํ channel_last๋ฅผ ์์ฑํด์ผ ํฉ๋๋ค. ๊ทธ๋ ์ง ์์ผ๋ฉด ํผ๋์ค๋ฝ์ต๋๋ค. ์, ๋ ์ด์์ ์ ํ ๊ท์น์ด ๊ธฐ๋ก๋ฉ๋๋ค.
์ฐ๋ฆฌ๊ฐ ์ด ๋๋ฉด์ ๋ํด ์ด์ผ๊ธฐํ ๋ @VitalyFedyunin ์๊ฒ ๊ด์ฐฐํ ๋ด์ฉ์ (ํ์ง๋ง ์ด๋์๋ ์ด๊ฒ์ ์ ์ด๋๋ ๊ฒ์ ๊ธฐ์ตํ์ง ๋ชปํ๋ค๊ณ ์๊ฐํฉ๋๋ค) ์ปจ๋ณผ๋ฃจ์ ์๋ ์ด๋ ์ ๋์ ์์ ๊ฐ ์๋ค๋ ๊ฒ์ ๋๋ค. ๋ฉ๋ชจ๋ฆฌ ๋ ์ด์์์ด ํจ์จ์ ์ผ๋ก ๊ตฌํํ๋ ๋ฐฉ๋ฒ์ ์๊ณ ์๋ ์ด๋ค ๊ฒ๊ณผ๋ ์ผ์นํ์ง ์๋ ์ธ์, ์ด๋ค ๋ ์ด์์์ ์ฐ๊ฒฐํด์ผ ํฉ๋๊น? BC ์ด์ ๋ก ์ฑ๋์ ๋จผ์ ์ฐ๊ฒฐํด์ผ ํ์ง๋ง ์ฌ๊ธฐ์ ์์์ ๊ฒฐ์ ์ ๋ด๋ ธ์ต๋๋ค. ์๋ง๋ ์ฑ๋์ ๋ง์ง๋ง์ผ๋ก ์ฐ๊ฒฐํ ์๋ ์์ ๊ฒ์ ๋๋ค. ๊ธฐ๋ณธ๊ฐ์ด ๋ฌด์์ธ์ง ์๋ ค์ฃผ๋ ์ผ์ข ์ ์ค๋ ๋ ๋ก์ปฌ ํ ๊ธ์ด ์์ด์ผ ํ ๊น์?
๊ทธ๋ฌ๋ ์ฌ๊ธฐ์๋ ๋ง์ ์ธ๋ถ ์ฌํญ์ด ์๋ ๊ฒ ๊ฐ์ผ๋ฉฐ ๊ฒฐ๊ตญ ์ ๋ ์ง ๋ชจ๋ฅด๊ฒ ์ต๋๋ค.
๋ฐ๋ผ์ ์ปจ๋ณผ๋ฃจ์ ์ ํ๋ฆฟํจ(๋ฐ ๊ธฐํ ๋ ์ด์์ ์ธ์ ์ฐ์ฐ์, ์๋ฅผ ๋ค์ด ์ต๊ทผ์ ๋ณธ ์ ์ํ๋ง์ ์ ๋ ฅ์์ .contiguous()์ ํธ์ถํ์ฌ ์์ํฉ๋๋ค. ๊ทธ๋์ ๊ทธ๊ฒ์ด ์๋ฏธํ๋ ๋ฐ๊ฐ ๋ฌด์์ ๋๊น?)์ด ์ฃผ๋ ์ด์ ์์ต๋๋ค. ํ๊ทธ๋ฅผ ์๊ฐํ๊ธฐ ์ํด iirc.
๋ค, ๊ทธ๋์ ํ๊ทธ ๋์์ธ์ ๋ค์ ์ด์ด๋ ๊ด์ฐฎ์ต๋๋ค. ํ์ง๋ง ์ฐ๋ฆฌ๋
์ด๋ฌํ ํ๊ทธ๋ฅผ ์ ํํ๋ ๋ฐฉ๋ฒ์ ๋ฌธ์ ๋ฅผ ์ฌ๊ฐํ๊ฒ ํด๊ฒฐํด์ผ ํฉ๋๋ค.
๋ ์ด์์์ ์์ด๋ (์ฒญํน์ ๊ฒฝ์ฐ์ ๊ฐ์ด
์ฑ๋). ๋๋ "ํ์ฌ ๋ ์ด์์"์ ๋ง๋๋ ๊ฒ์ ํจ์ฌ ๋ ์ข์ํฉ๋๋ค.
๋ฐ์ดํฐ ์ข
์์ฑ์ ๋ง๋๋ ๊ฒ๋ณด๋ค ์ผ์ข
์ ์ปจํ
์คํธ ๊ด๋ฆฌ์์
๋๋ค.
2019-06-19 12:43:45 -0700์ ngimel ๋ฉ์์ง์์ ๋ฐ์ท:
๋ฐ๋ผ์ ์ปจ๋ณผ๋ฃจ์ ์ ํ๋ฆฟํจ(๋ฐ ๊ธฐํ ๋ ์ด์์ ์ธ์ ์ฐ์ฐ์, ์๋ฅผ ๋ค์ด ์ต๊ทผ์ ๋ณธ ์ ์ํ๋ง์ ์ ๋ ฅ์์ .contiguous()์ ํธ์ถํ์ฌ ์์ํฉ๋๋ค. ๊ทธ๋์ ๊ทธ๊ฒ์ด ์๋ฏธํ๋ ๋ฐ๊ฐ ๋ฌด์์ ๋๊น?)์ด ์ฃผ๋ ์ด์ ์์ต๋๋ค. ํ๊ทธ๋ฅผ ์๊ฐํ๊ธฐ ์ํด iirc.
BTW ์ ์ฐ๋ฆฌ๋ layout
์ง์ฐฉํ๋ ๋์ ์๋ก์ด ๊ฐ๋
์ ๋ง๋ค์ด์ผ ํฉ๋๊น? ํฌ์ ํํ์ "channels_last"์ ๊ฐ์ ๋ ์ด์์ ๊ฐ๋
์ด ์ ์ ์๋์ด ์์ง ์๋ค๊ณ ์๊ฐํ๋ฏ๋ก memory_formats * layouts
์ ์ ํ์ ๋ํ๋ผ ํ์๊ฐ ์์ต๋๋ค( layouts
๋ ํ์ฌ ์ฌ์ฉ์ ๋ํ๋
๋๋ค. ), ํ์ง๋ง memory_format + layouts
์ฌ์ฉํ๋ฉด ์ด์ ๊ณผ ๋์ผํ ์ธ์๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ด ์ข์ต๋๋ค. ๋์๊ฒ ๊ทธ๊ฒ์ ๋ ์งง๊ณ ๋ ์ข์ผ๋ฉฐ ํฉํ ๋ฆฌ ์๋ช
์ ์์ฒ ๊ฐ์ ์ธ์๋ก ํ์ฅํ๋ ๊ฒ์ ํผํ ์ ์์ต๋๋ค.
๋ ์ด์์ ์ต์ ์ด ๊ณ ๋ ค๋์์ง๋ง(๋ถ๋ก ํ์ธ) ๋ง์ ์ฝ๋ ์ค๋ณต์ด ๋ฐ์ํ๊ณ ํ ์๋ฅผ ์ฆ์ ๋ค๋ฅธ memory_format์ผ๋ก ์๋ ๋ณํํ๋ ๊ฒ์ ํ์ฉํ์ง ์์ต๋๋ค.
๊ฒฐ๊ตญ memory_format์ stride ํ ์๋ฅผ ์ฌ์ฉํ๊ณ ์์ ํ ๋ค๋ฅธ ํด๋์ค๊ฐ ์๋๋ผ strided ํ ์์ ์์ฑ์ธ ์ต์ ํ๋ ์ปค๋๊ณผ ์ถ๋ ฅ์ ์ฝ๊ฒ ์ ํํ๋ ๋ฐฉ๋ฒ์ ๋๋ค.
์ด๋ค ์๋ฏธ์์ ํฌ์ ๋ ์ด์์์ ๋๋ถ๋ถ์ด 0์ธ ๋ฐฐ์ด์ ๋ํด ์ต์ ํ๋ ์ปค๋์ ์ฝ๊ฒ ์ ํํ๋ ๋ฐฉ๋ฒ์ด๊ธฐ๋ ํฉ๋๋ค.
์ด๊ฒ์ ์์งํ ์ง๋ฌธ์ผ ์ ์์ง๋ง, PyTorch๊ฐ ์ด API๋ฅผ ๊ณ ๋ คํ๋ ๊ฒ๊ณผ ์ฌ์ฉ ๊ฐ๋ฅํ ๊ฒฝ์ฐ ๊ธฐ๋ณธ CuDNN ์ปค๋์ ์ง์ ํธ์ถํ๋ ์์ ์์ฒด์์ NHWC๋ฅผ ์ฌ์ฉํ๋ ์ต์ ์ ๋ ธ์ถํ๋ ์ด์ ๋ ๋ฌด์์ ๋๊น?
์ผ๋ฐ์ ์ธ ์ฌ์ฉ ์ฌ๋ก(conv ๋ฐ LM ์ํคํ
์ฒ์ ํ๋ง๊ณผ ๊ฐ์ ์ด๋ฏธ์ง ์์
ํผํฉ)์ ๊ฒฝ์ฐ ์ด๊ฒ์ด ์ฌ์ด ์๋ฃจ์
์ธ ๊ฒ ๊ฐ์ต๋๋ค. ๊ฐ๋ฐ์๋ก์ ๋ด๊ฐ ์ํ๋ ๊ฒ์ Conv2d(..., nhwc=True)
. ์ด๊ฒ ๋ง์ด ์ ๋๋ ์ด์ ๊ฐ ์๋์?
@rewonc ์ฐ๋ฆฌ๋ ์ ์ฌํ ์ ๊ทผ ๋ฐฉ์(
nhwc=True
์ต์
์ด ์๋ ํ ์
๋ ฅ์ ๋ค์ ์ฌ์ ์ํด์ผ ํฉ๋๋ค(์ฐ์์ ์ผ๋ก).nhwc=True
์ต์
์ด ํ์ํฉ๋๋ค.์ถ์ . CudNN Ex
ํจ์๊ฐ ๊ฑฑ์ ๋๋ค๋ฉด cudnn_batch_norm_nhwc
๋ฐ ์ ์ฌํ ์ฐ์ฐ์๋ฅผ ๋
ธ์ถํ๋ ค๊ณ ํฉ๋๋ค.
์๋ ํ์ธ์ @VitalyFedyunin๋ , ์ฐ๋ฆฌ๋ ๋ช ๋ช ๋ ํ ์๊ฐ PyTorch 1.3์์ ์ง์๋๋ ๊ฒ์ ๋ณด์์ต๋๋ค. NHWC(๋๋ ์ฐจ๋จ๋) ํ์ ์ง์์ ๋ํ ๋ฌธ์ ๋ฅผ ํด๊ฒฐ(๋๋ ๋ถ๋ถ์ ์ผ๋ก ํด๊ฒฐํ ์ ์์)ํ ์ ์์ต๋๊น? ๋ช ๋ช ๋ ํ ์๋ฅผ ๊ธฐ๋ฐ์ผ๋ก NHWC ์ํ๋ก ๋์๊ฐ ๊ณํ์ด ์์ต๋๊น?
์ฐ๋ฆฌ๋ ์ฑ๋์ ๋ง์ง๋ง ์ง์์ ๊ณ์ ์งํํ๊ณ ์์ผ๋ฉฐ, ์ด๋ฒ ์ฃผ์ ๋ก๋๋งต์ ์ฌ๊ธฐ์ slack ์ฑ๋์ ๊ฒ์ํ ์์ ์ ๋๋ค. (๋ชจ๋ ์ฐ์ฐ์๋ฅผ ๋ค์ ์์ฑํด์ผ ํ๋ฏ๋ก) ์ฐจ๋จ๋ ํ์์ ๊ณง ์ถ๊ฐํ๋ ๊ฒ์ ๊ณ ๋ คํ์ง ์์ต๋๋ค.
๊ฐ์ฌ ํด์. ์ ๋ ๊ฑฐ์ผ!
https://github.com/pytorch/pytorch/issues/28619 ๋ด๋ถ์ ํํน ์์ ๋ฐ ์งํ ์ํฉ
๊ฐ์ฅ ์ ์ฉํ ๋๊ธ
BTW ์ ์ฐ๋ฆฌ๋
layout
์ง์ฐฉํ๋ ๋์ ์๋ก์ด ๊ฐ๋ ์ ๋ง๋ค์ด์ผ ํฉ๋๊น? ํฌ์ ํํ์ "channels_last"์ ๊ฐ์ ๋ ์ด์์ ๊ฐ๋ ์ด ์ ์ ์๋์ด ์์ง ์๋ค๊ณ ์๊ฐํ๋ฏ๋กmemory_formats * layouts
์ ์ ํ์ ๋ํ๋ผ ํ์๊ฐ ์์ต๋๋ค(layouts
๋ ํ์ฌ ์ฌ์ฉ์ ๋ํ๋ ๋๋ค. ), ํ์ง๋งmemory_format + layouts
์ฌ์ฉํ๋ฉด ์ด์ ๊ณผ ๋์ผํ ์ธ์๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ด ์ข์ต๋๋ค. ๋์๊ฒ ๊ทธ๊ฒ์ ๋ ์งง๊ณ ๋ ์ข์ผ๋ฉฐ ํฉํ ๋ฆฌ ์๋ช ์ ์์ฒ ๊ฐ์ ์ธ์๋ก ํ์ฅํ๋ ๊ฒ์ ํผํ ์ ์์ต๋๋ค.