Pytorch: [RFC] ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹(์ผ๋ช… ๋ ˆ์ด์•„์›ƒ์ด๋ผ๊ณ ๋„ ํ•จ) ์ง€์›

์— ๋งŒ๋“  2019๋…„ 04์›” 10์ผ  ยท  68์ฝ”๋ฉ˜ํŠธ  ยท  ์ถœ์ฒ˜: pytorch/pytorch

๋ฌธ์ œ ์„ค๋ช…

CNN ์—ฐ์‚ฐ์ž๋Š” ํ…์„œ ์ฐจ์›์˜ ํ‘œ์ค€ ์ˆœ์„œ๋ฅผ ํ™œ์šฉํ•˜๊ณ  ์˜๋ฏธ๋ก ์  ์˜๋ฏธ๋ฅผ ํ• ๋‹นํ•ฉ๋‹ˆ๋‹ค. ์˜ค๋Š˜๋‚  PyTorch์˜ 2D ์‚ฌ๋ก€์˜ ๊ฒฝ์šฐ torch.nn.Conv2d์— ๋Œ€ํ•œ ์ž…๋ ฅ์€ NCHW ์ˆœ์„œ์˜ 4d ํ…์„œ์—ฌ์•ผ ํ•ฉ๋‹ˆ๋‹ค..

์„ฑ๋Šฅ์ƒ์˜ ์ด์œ ๋กœ ํŠน์ • ์ž‘์—…์—์„œ ์•ก์„ธ์Šคํ•˜๋Š” ๋ฉ”๋ชจ๋ฆฌ๊ฐ€ ์—ฐ์†์ ์œผ๋กœ ๋ฐฐ์น˜๋˜๊ณ  ์ง€์—ญ์„ฑ์ด ๋” ์ž˜ ํ™œ์šฉ๋˜๋„๋ก ์ฐจ์›์„ ๋‹ค๋ฅด๊ฒŒ ์žฌ์ •๋ ฌํ•˜๋Š” ๊ฒƒ์ด ์ข…์ข… ์œ ๋ฆฌํ•ฉ๋‹ˆ๋‹ค. ๊ฐ€์žฅ ์ผ๋ฐ˜์ ์ธ ์˜ต์…˜์€ ์น˜์ˆ˜๋ฅผ ๋์œผ๋กœ ์ด๋™ํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค - NHWC. ํ•œ ์ฐจ์›์„ ๋ธ”๋ก์œผ๋กœ ๋ฐ”๋‘‘ํŒ์‹์œผ๋กœ ๋ฐฐ์—ดํ•˜๋Š” ํ›จ์”ฌ ๋” ๋ณต์žกํ•œ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์ด ์žˆ์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค..

์ด๋ฅผ ํ™œ์šฉํ•˜๋Š” ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ์˜ ์˜ˆ๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค.

  • cudnn์€ NHWC์˜ Volta์—์„œ ๋” ๋น ๋ฅธ ์„ฑ๋Šฅ์„ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค.
  • fbgemm ๋ฐ qnnpack์€ NCHW๋ฅผ ์ง€์›ํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.
  • libxsmm๋Š” NCHW๋ฅผ ์ง€์›ํ•˜์ง€๋งŒ ์„ฑ๋Šฅ ์ €ํ•˜๊ฐ€ 50%(IIRC)์™€ ๋น„์Šทํ•ฉ๋‹ˆ๋‹ค.

๋ฌธ์ œ๋Š” ์ฐจ์› ์ˆœ์„œ๋ฅผ ๋ณ€ํ™˜ํ•˜๋Š” ๊ฒƒ ์ž์ฒด๊ฐ€ ๋น„์šฉ์ด ๋งŽ์ด ๋“ค๊ธฐ ๋•Œ๋ฌธ์— ์—ฌ๋Ÿฌ CNN ์ž‘์—…์ด ์—ฐ์†์œผ๋กœ ์ˆ˜ํ–‰๋˜๋Š” ๊ฒฝ์šฐ(์˜ˆ: conv(relu(conv))) ) ๋‹ค๋ฅธ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์œผ๋กœ ํ•œ ๋ฒˆ ๋ณ€ํ™˜ ํ•˜๊ณ  ์ž‘์—…์„ ์ˆ˜ํ–‰ํ•˜๊ณ  ์žฌ์ •๋ ฌํ•˜๋Š” ๊ฒƒ์ด ์ข‹์Šต๋‹ˆ๋‹ค. ๋’ค.

๋”ฐ๋ผ์„œ PyTorch๊ฐ€ ๋‹ค์–‘ํ•œ ์ฐจ์› ์ˆœ์„œ๋ฅผ ์ธ์‹ํ•˜๊ณ  Eager ๋ชจ๋“œ์™€ JIT ๋ชจ๋“œ ๋ชจ๋‘์—์„œ ์ž‘์—… ๊ฐ„์— ๋‹ค๋ฅธ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์„ ๊ฐ€์ง„ ํ…์„œ๋ฅผ ์ „๋‹ฌํ•  ์ˆ˜ ์žˆ๋„๋ก

์šฐ๋ฆฌ๋Š” ๋‹ค์Œ์„ ํ‘œํ˜„ํ•  ์ˆ˜ ์žˆ๋Š” API๋ฅผ ๊ตฌ์ถ•ํ•˜๊ธฐ ์œ„ํ•ด ๋…ธ๋ ฅํ•ฉ๋‹ˆ๋‹ค.

  • Eager ๋ฐ JIT์˜ PyTorch์— ์žˆ๋Š” ๋‹ค๋ฅธ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹(์ฒ˜์Œ์—๋Š” ์ฐจ์› ์ˆœ์„œ๋งŒ)์„ ๊ฐ€์ง„ ํ…์„œ. ์ฐจ๋‹จ๋œ ๋ ˆ์ด์•„์›ƒ์€ ์šฐ์„  ์ˆœ์œ„๊ฐ€ ๋‚ฎ์ง€๋งŒ ์—ฌ์ „ํžˆ ์ข‹์Šต๋‹ˆ๋‹ค.
  • ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ์ฟผ๋ฆฌ ๋ฐ ๋ณ€๊ฒฝ์„ ์œ„ํ•œ ์‚ฌ์šฉ์ž ๋…ธ์ถœ API
  • ๋‹ค๋ฅธ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์„ ๊ฐ€์ง„ ์ž…๋ ฅ ํ…์„œ๋ฅผ ์ฒ˜๋ฆฌํ•˜๊ณ  ํ•ด๋‹นํ•˜๋Š” ๋” ๋น ๋ฅธ ๊ตฌํ˜„์œผ๋กœ ๋ผ์šฐํŒ…ํ•  ์ˆ˜ ์žˆ๋Š” ํ•ต์‹ฌ CNN ์ž‘์—…
  • JIT ํŒจ์Šค์˜ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์„ ์ถ”๋ก ํ•˜๊ณ  ์ตœ์ ํ™”ํ•˜๋Š” ๊ธฐ๋Šฅ

์šฉ์–ด : ์œ„์˜ ๋ฌธ์ œ๋Š” ์ข…์ข… "layout"(mxnet), "data_format"(tf), "image_format"(keras), "order"(caffe2)๋ผ๊ณ  ํ•ฉ๋‹ˆ๋‹ค. ์šฐ๋ฆฌ๋Š” PyTorch์—์„œ "memory format" ๋˜๋Š” "memory_format"์ด๋ผ๋Š” ์ด๋ฆ„์„ ์‚ฌ์šฉํ•  ๊ฒƒ์„ ์ œ์•ˆํ•ฉ๋‹ˆ๋‹ค. "๋ ˆ์ด์•„์›ƒ"์ด๋ผ๋Š” ์ด๋ฆ„์€ ๋ถˆํ–‰ํžˆ๋„ PyTorch์—์„œ 'strided' ๋Œ€ 'sparse_coo' ๊ฐ’์„ ์‚ฌ์šฉํ•˜๋ฏ€๋กœ ์ด๋ฆ„ ์ง€์ • ์˜ต์…˜์„ ์‚ฌ์šฉํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.

์˜ํ–ฅ์„ ๋ฐ›๋Š” ์šด์˜์ž

๋‹ค์Œ ์—ฐ์‚ฐ์ž๋Š” ์ตœ์†Œํ•œ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์„ ์ธ์‹ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ์ •ํ™•ํ•œ ๊ฒฐ๊ณผ๋ฅผ ์ƒ์„ฑํ•˜๋Š” ๊ฒƒ ์™ธ์—๋„ ๊ธฐ๋ณธ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ์—์„œ ์ตœ์ƒ์˜ ์„ฑ๋Šฅ ์„ ์ œ๊ณตํ•˜๊ณ  ๋ช…์‹œ์ ์œผ๋กœ ์ง€์ •๋œ ์‚ฌ์šฉ์ž ์˜๋„๋ฅผ ์ „ํŒŒํ•˜๊ธฐ ์œ„ํ•ด ์ถœ๋ ฅ์˜ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ์„

  • ํšŒ์„ 
  • ๋‹ค์–‘ํ•œ ํ’€๋ง
  • ๋ฐฐ์น˜ ๋†ˆ, ๋ ˆ์ด์–ด ๋†ˆ, ์ธ์Šคํ„ด์Šค ๋†ˆ (์ผ๋ฐ˜์ ์œผ๋กœ, ๋†ˆ์— ์ƒ๊ด€์—†์ด)
  • ์—…์ƒ˜ํ”Œ๋ง/๋ณด๊ฐ„
  • ํ”ผ์ณ ๋“œ๋กญ์•„์›ƒ
  • softmax ์ˆ˜์ค€์€ ๋‚ฎ์Œ - ์ฐจ์›์„ ์ˆ˜๋™์œผ๋กœ ์ง€์ •ํ•  ์ˆ˜ ์žˆ์ง€๋งŒ ํšจ์œจ์ ์ธ ๊ตฌํ˜„์€ ์•”์‹œ์  nchw ๋ ˆ์ด์•„์›ƒ์— ๋Œ€ํ•ด์„œ๋งŒ ์กด์žฌํ•ฉ๋‹ˆ๋‹ค.
  • ์‹ฌ
  • ์š”์†Œ๋ณ„(๋‹จํ•ญ ๋ฐ ์ด์ง„) ์—ฐ์‚ฐ
  • ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์„ ์ƒ์†ํ•˜๋Š” ํ…์„œ์˜ ์ƒ์„ฑ์ž(์˜ˆ: empty_like).

API ๋ฐ ๋™์ž‘ ๋ณ€๊ฒฝ ์‚ฌํ•ญ

PyTorch์—์„œ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์˜ ๊ฐœ๋… ์ •์˜:

  • torch.memory_format.channels_first ์™€ ๊ฐ™์€ ์ƒ์ˆ˜์ž…๋‹ˆ๋‹ค. ์œ ํ˜•์ด ์ง€์ •๋˜์ง€ ์•Š์•˜์œผ๋ฉฐ ์ž„์˜์˜ ๋น„๊ต ๊ฐ€๋Šฅํ•œ ๊ฐ์ฒด๊ฐ€ ๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค(enum์œผ๋กœ ์‹œ์ž‘ํ•  ๊ฐ€๋Šฅ์„ฑ์ด ๋†’์ง€๋งŒ ๋‚˜์ค‘์—๋Š” ๋ช…๋ช…๋œ ํ…์„œ์˜ ๊ฐœ๋…๊ณผ ์ƒํ˜ธ ์šด์šฉ๋˜๋Š” ๋‹ค๋ฅธ ๊ฐ์ฒด๊ฐ€ ๋  ์ˆ˜ ์žˆ์Œ)

    • ๋Œ€์•ˆ: torch.channels_first ์ง์ ‘ ์‚ฌ์šฉ

  • ๊ฐ’์€ channels_first ๋ฐ channels_last (๋” ์ ์€ ์ˆ˜์˜ ์ƒ์ˆ˜ ํ—ˆ์šฉ).
  • 1D ์ด๋ฏธ์ง€ / 3D ํ…์„œ์˜ ๊ฒฝ์šฐ ๊ฐ’์€ ํ‰๊ท  NCW, NWC, 2D ์ด๋ฏธ์ง€ / 4D ํ…์„œ - NCHW, NHWC, 3D ์ด๋ฏธ์ง€ / 5D ํ…์„œ - NCDHW, NDHWC

Tensor์— ๋‹ค์Œ ๋ฉ”์„œ๋“œ๋ฅผ ์ถ”๊ฐ€ํ•ฉ๋‹ˆ๋‹ค.

  • x.is_contiguous(torch.memory_format.channels_first)
  • x.to(memory_format=torch.memory_format.channels_first)

์ฐธ๊ณ  : ์ง€๊ธˆ์€ x.get_memory_format() ๊ธฐ๋Šฅ์ด ์—†๊ณ  ๋ช…์‹œ์  ๊ฒ€์‚ฌ๋งŒ ๊ฐ€๋Šฅํ•ฉ๋‹ˆ๋‹ค. ๊ฐ€๋Šฅํ•œ ๊ตฌํ˜„ ๋ฒ”์œ„๊ฐ€ ๋” ๋„“์Šต๋‹ˆ๋‹ค. ์šฐ๋ฆฌ๋Š” ๊ทธ๊ฒƒ์„ ์ถ”๊ฐ€ํ•˜๊ณ  ์‹ถ์„ ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค.

ํ…์„œ ์˜๋ฏธ๋ก ์  ๋ ˆ์ด์•„์›ƒ์€ ํ•ญ์ƒ ๋™์ผํ•˜๊ฒŒ ์œ ์ง€๋ฉ๋‹ˆ๋‹ค - NCHW! x.size() ํ•ญ์ƒ (n,c,h,w) ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.

์ž‘์—…์€ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ๋™์ž‘์„ ์œ ์ง€ํ•ฉ๋‹ˆ๋‹ค.

  • ์ปจ๋ณผ๋ฃจ์…˜, ํ’€๋ง ๋“ฑ(์œ„ ์ฐธ์กฐ)์€ ์ž…๋ ฅ๊ณผ ๋™์ผํ•œ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ์œผ๋กœ ์ถœ๋ ฅ์„ ๋ฐ˜ํ™˜ํ•˜๊ณ  ๋‚ด๋ถ€์ ์œผ๋กœ ์ตœ์ƒ์˜ ๊ตฌํ˜„์œผ๋กœ ๋””์ŠคํŒจ์น˜ํ•ฉ๋‹ˆ๋‹ค.
  • ๋‹จํ•ญ ์š”์†Œ๋ณ„ ์—ฐ์‚ฐ์€ ๋™์ผํ•œ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์„ ์œ ์ง€ํ•˜๊ณ  ์ธ์ ‘ํ•œ ํ…์„œ๋งŒํผ ๋น ๋ฅด๊ฒŒ ์‹คํ–‰ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.
  • ์ด์ง„ ์š”์†Œ๋ณ„ ์—ฐ์‚ฐ์€ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์„ ์œ ์ง€ํ•˜๋Š” ๋ฐ ์žˆ์–ด ํ•ฉ๋ฆฌ์ ์ธ ๋ณด์žฅ์„ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค. ๋” ๊ด‘๋ฒ”์œ„ํ•˜๊ฒŒ ์ •์˜ํ•  ์ˆ˜ ์žˆ์ง€๋งŒ ์ตœ์†Œ๊ฐ’์€ ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค.

    • NHWC + ์Šค์นผ๋ผ โ†’ NHWC

    • NHWC + ์—ด ๋ฒกํ„ฐ โ†’ NHWC

  • ํ•ต์‹ฌ CNN ์—ฐ์‚ฐ์— ๋Œ€ํ•œ ์—ญ๋ฐฉํ–ฅ ์—ฐ์‚ฐ์€ ์ˆœ๋ฐฉํ–ฅ ๊ฒฝ๋กœ์—์„œ์™€ ๋™์ผํ•œ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์„ ์œ ์ง€ํ•ฉ๋‹ˆ๋‹ค. (์ถœ๋ ฅ์— ๋Œ€ํ•œ ์ˆ˜์‹  ๊ทธ๋ผ๋””์–ธํŠธ๊ฐ€ ๋‹ค๋ฅธ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์ผ ์ˆ˜ ์žˆ๊ธฐ ๋•Œ๋ฌธ์— ๋ช…์‹œ์ ์œผ๋กœ ์ ์šฉํ•ด์•ผ ํ•  ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค.)

๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์€ ์ง๋ ฌํ™”/์—ญ์ง๋ ฌํ™”๋ฅผ ํ†ตํ•ด ๋ณด์กด๋˜๋Š” ํ…์„œ์˜ ์†์„ฑ์ž…๋‹ˆ๋‹ค(ํ…์„œ๊ฐ€ ๋งค๊ฐœ๋ณ€์ˆ˜์ธ ๊ฒฝ์šฐ).

์ŠคํŠธ๋ผ์ด๋“œ ๊ตฌํ˜„

์˜ค๋Š˜๋‚  PyTorch์˜ Tensor์—๋Š” ๋…ผ๋ฆฌ์  ํ…์„œ๊ฐ€ ๋ฉ”๋ชจ๋ฆฌ์— ๋ฐฐ์น˜๋˜๋Š” ๋ฐฉ์‹์„ ์ง€์ •ํ•˜๋Š” strides ๊ฐœ๋…์ด ์žˆ์Šต๋‹ˆ๋‹ค. ํŠนํžˆ ๊ฐ ํ…์„œ๋Š” sizes ์™€ ๊ฐ™์€ ๊ธธ์ด์˜ strides ๋ฒกํ„ฐ๋ฅผ ๊ฐ€์ง‘๋‹ˆ๋‹ค. (i1, i2, .., ik) ๋…ผ๋ฆฌ์  ์ธ๋ฑ์‹ฑ์—์„œ ์š”์†Œ๋ฅผ ์ธ๋ฑ์‹ฑํ•˜๋ ค๋ฉด ๋ณดํญ์œผ๋กœ ๋‚ด์ ์„ ์ˆ˜ํ–‰ํ•˜๊ณ  offset + i0*stride0 + i1*stride1 + ... * ik * stridek ์—์„œ ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ์ฐพ์Šต๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ ์ธ์ ‘ํ•œ ํ…์„œ๋Š” ํฌ๊ธฐ์˜ ๋ˆ„์  ๊ณฑ์ด ์—ญ์ „๋˜๋Š” ๋ณดํญ์„ ๊ฐ–์Šต๋‹ˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด (n,c,h,w) ํฌ๊ธฐ์˜ 4D ํ…์„œ์—๋Š” (c*h*w, h*w, w, 1) ์žˆ์Šต๋‹ˆ๋‹ค.

์ŠคํŠธ๋ผ์ด๋“œ๋Š” ๋…ผ๋ฆฌ์  ๊ธฐ๋ณธ NCHW ์ˆœ์„œ๋ฅผ ์œ ์ง€ํ•˜๋ฉด์„œ ๋ฌผ๋ฆฌ์ ์œผ๋กœ ๋‹ค๋ฅธ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹(์ฐจ์› ์žฌ์ •๋ ฌ)์„ ๋‚˜ํƒ€๋‚ด๋Š” ๋ฐ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๋‹ค์Œ๊ณผ ๊ฐ™์ด ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ๋ณ€ํ™˜์— ๋Œ€ํ•œ ํšจ๊ณผ์ ์ธ ์ •์˜๋ฅผ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค.

# implementation of x.to(channels_last)
def to_mem_format_nhwc(x):
    return x.permute(0,2,3,1).contiguous().permute(0,3,1,2)

# implementation of x.to(channels_first)
def to_mem_format_nchw(x):
    return x.contiguous()

NHWC ํ˜•์‹์—์„œ ๋ณดํญ ๋ฒกํ„ฐ๋Š” (c*h*w, 1, c*w, c) ์ž…๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ ๋ฉ”๋ชจ๋ฆฌ ๋ฒ„ํผ์—์„œ ๊ฐ€์ค‘์น˜๋Š” NHWC์— ๋Œ€ํ•ด ์—ฐ์†์ ์ธ ์ˆœ์„œ์ž…๋‹ˆ๋‹ค.

Strides๋Š” ํ…Œ์ŠคํŠธ์— ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

def is_nhwc_contiguous(x):
    return x.permute(0,2,3,1).is_contiguous()

# or alteratively
def is_nhwc_contiguous(x):
    n,c,h,w = x.size() # in any case the sizes remain in NCHW order
    return x.stride() == (c*h*w, 1, c*w, c)

def is_nchw_contiguous(x):
    return x.is_contiguous()


# operator implementations can just check contiguity and carry on directly on data pointer
def my_sample_op(x):
    if x.is_contiguous(nhwc):
        float* p = x.data();
        # Do we need to go to c++ here? 
        # can we have an example in python?
        n,c,h,w = x.size()
        # operate on `p` as it's guaranteed to be (n,h,w,c) array
        y=my_nhwc_op(p)
        # Do we need to convert the layout of y?

    else:
        # Need to convert x to nhwc layout
        x = x.permute(0,2,3,1).contiguous()
        float *p = x.data();
        # Is this needed?
        y = my_nhwc_op(p)
        return y.permute(0,3,1,2).contiguous()

์ด ์ ‘๊ทผ ๋ฐฉ์‹์˜ ์žฅ์  :

  • ์ƒˆ๋กœ์šด ์ตœ์ƒ์œ„ ์•„์ด๋””์–ด๋‚˜ API ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ์ถ”๊ฐ€ํ•˜์ง€ ์•Š๊ณ  ๊ธฐ์กด PyTorch ๊ฐœ๋…์˜ ๋ณดํญ ํ™œ์šฉ
  • ํ‘œ์ค€ NCHW ์ˆœ์„œ๋กœ ํ…์„œ์˜ ๋…ผ๋ฆฌ์  ๋™์ž‘์„ ์œ ์ง€ํ•ฉ๋‹ˆ๋‹ค.
  • ์ž…๋ ฅ ์ฐจ์›์˜ ์ž„์˜ ์žฌ์ •๋ ฌ์— ๋Œ€ํ•ด ์ž‘๋™ํ•ฉ๋‹ˆ๋‹ค.
  • ๊ธฐ์กด ์ง๋ ฌํ™” ๋ฃจํ‹ด์€ ์ด๋ฏธ ํ…์„œ์˜ ๋ณดํญ์„ ์œ ์ง€ํ•ฉ๋‹ˆ๋‹ค.
  • ๋‹ค์–‘ํ•œ ๋ฉ”๋ชจ๋ฆฌ ๋ ˆ์ด์•„์›ƒ์—์„œ ์ž‘์—…ํ•˜๊ธฐ ์œ„ํ•ด ๋งŽ์€ ์ž‘์—…์„ ์žฌ์‚ฌ์šฉํ•˜๋Š” ๊ธฐ๋Šฅ

๋‹จ์  :

  • .contiguous() ํ˜ธ์ถœํ•˜๋Š” ๊ฒƒ์€ NCHW๋กœ ์ „ํ™˜ํ•˜๋Š” ๊ฒƒ๊ณผ ๋™์ผํ•˜๋ฉฐ ์‚ฌ์šฉ์ž ๋˜๋Š” ์ž‘์—… ์ค‘ ํ•˜๋‚˜ ๋‚ด๋ถ€์—์„œ ์šฐ์—ฐํžˆ ๋ฐœ์ƒํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

    • ์šด์˜์ž๊ฐ€ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์„ ๋ณด์กดํ•˜๋Š”์ง€ ํ™•์ธํ•˜๋ ค๋ฉด ์šด์˜์ž์— ๋Œ€ํ•œ ๋ช…์‹œ์  ๊ฐ์‚ฌ๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค.

  • ์ฐจ๋‹จ/ํƒ€์ผ ํ˜•์‹์—์„œ๋Š” ์ž‘๋™ํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ๋‹ค๋ฅธ ์ ‘๊ทผ ๋ฐฉ์‹์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค.

    • PyTorch์—์„œ ์ผ๊ธ‰ ์‹œ๋ฏผ์œผ๋กœ ์ถ”๊ฐ€ํ•˜๋Š” ๊ฒƒ์„ ๊ณ ๋ คํ•  ์ˆ˜ ์žˆ์ง€๋งŒ ํ›จ์”ฌ ๋” ํฐ ๋ณ€ํ™”์ž…๋‹ˆ๋‹ค.

    • ๋Œ€์•ˆ์€ ๋ถˆํˆฌ๋ช… ํ•ธ๋“ค(์˜ˆ: MKLDNN ํ…์„œ)๋กœ ์ฒ˜๋ฆฌํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.

  • ๊ธฐ๋ณธ ๊ตฌํ˜„์˜ ์„ฑ๋Šฅ ํŠน์„ฑ์ด ์ตœ์ข… ์‚ฌ์šฉ์ž์—๊ฒŒ ๋œ ๋ช…ํ™•ํ•จ

๊ฐ€์žฅ ํฐ ์ž ์žฌ์  ๋ฌธ์ œ๋Š” ๋ถˆ๋ถ„๋ช…ํ•œ ์‚ฌ์šฉ์ž ์˜๋„ ์ž…๋‹ˆ๋‹ค. ์‚ฌ์šฉ์ž๊ฐ€ ์ •๋ง๋กœ ๋‹ค๋ฅธ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์„ ์›ํ–ˆ๋Š”์ง€ ์•„๋‹ˆ๋ฉด ์ž…๋ ฅ ํ…์„œ๊ฐ€ ์šฐ์—ฐํžˆ ์ด๋Ÿฐ ์‹์œผ๋กœ ์ŠคํŠธ๋ผ์ด๋“œ(stride)๋˜์—ˆ๋Š”์ง€ ๊ตฌ๋ณ„ํ•  ๋ฐฉ๋ฒ•์ด ์—†์Šต๋‹ˆ๋‹ค. ํŠนํžˆ, ๊ธฐ์กด ์ž‘์—…์˜ ๋™์ž‘ ๋ณ€๊ฒฝ์œผ๋กœ ์ด์–ด์ง‘๋‹ˆ๋‹ค. ์˜ค๋Š˜๋‚  ์ปจ๋ณผ๋ฃจ์…˜์€ ์ž…๋ ฅ์ด ์ž„์˜์˜ ์ŠคํŠธ๋ผ์ด๋“œ์ธ ๊ฒฝ์šฐ์—๋„ NCHW ์—ฐ์† ํ…์„œ๋ฅผ ์ƒ์„ฑํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ƒˆ๋กœ์šด ์„ธ๊ณ„์—์„œ๋Š” ์ž…๋ ฅ์„ NHWC๋กœ ์ธ์‹ํ•˜์—ฌ NHWC๋„ ๋ฐ˜ํ™˜ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์˜๋ฏธ ์ฒด๊ณ„๋Š” ๋ณ€๊ฒฝํ•˜์ง€ ์•Š์ง€๋งŒ ๋””๋ฒ„๊ทธํ•˜๊ธฐ ์–ด๋ ค์šด ์„ฑ๋Šฅ ๋ฌธ์ œ๋กœ ์ด์–ด์ง‘๋‹ˆ๋‹ค. ๊ฐ€๋Šฅํ•œ ํ•ด๊ฒฐ์ฑ…์€ ์‚ฌ์šฉ์ž ์ง€์ • memory_format ํ”Œ๋ž˜๊ทธ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ํ…์„œ์— ๋ช…์‹œ์ ์œผ๋กœ ํƒœ๊ทธ๋ฅผ ์ง€์ •ํ•˜๊ณ  ์ด ์ฃผ์„(์ŠคํŠธ๋ผ์ด๋“œ ์™ธ์—)๋งŒ ๋”ฐ๋ฅด๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.

์œ„์˜ ๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•ด ์ดˆ๊ธฐ ์ œ์•ˆ์€ ํ…์„œ์—์„œ ์ˆ˜ํ–‰๋œ ๋งˆ์ง€๋ง‰ to(memory_format) ํ˜ธ์ถœ์„ ๊ธฐ๋กํ•˜๋Š” "์†Œํ”„ํŠธ" ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ํƒœ๊ทธ๋ฅผ ํ…์„œ์— ๋„์ž…ํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ์šด์˜์ž๋Š” ์ด ์ฃผ์„์„ ์ถœ๋ ฅ์— ์ „ํŒŒํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ์ฃผ์„์€ "์†Œํ”„ํŠธ"์ด๋ฏ€๋กœ ๋ถˆ์ผ์น˜ ์ฃผ์„์— ๋Œ€ํ•œ ํ•˜๋“œ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ•˜์ง€ ์•Š๊ณ  ํ”„๋กœํŒŒ์ผ๋ง ๋ชจ๋“œ์—์„œ ๊ฒฝ๊ณ ๊ฐ€ ์ƒ์„ฑ๋ฉ๋‹ˆ๋‹ค.

์—ฐ์‚ฐ์ž ๊ตฌํ˜„

๊ธฐ์กด ์šด์˜์ž์˜ ์„œ๋ช…์€ ๋ณ€๊ฒฝ๋˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ์šด์˜์ž๋Š” ์šด์˜์ž ๋‚ด๋ถ€์—์„œ ํ•˜๋“œ ์ฝ”๋”ฉ๋œ ๋””์ŠคํŒจ์น˜๋ฅผ โ€‹โ€‹์ˆ˜ํ–‰ํ•˜์—ฌ ๋” ๋น ๋ฅธ ๊ตฌํ˜„์œผ๋กœ ๋ผ์šฐํŒ…ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๊ตฌํ˜„์ด ๋ถˆ๊ฐ€๋Šฅํ•œ ๊ฒฝ์šฐ ๋‹ค๋ฅธ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์„ ํ†ตํ•œ ์™•๋ณต์ด ๊ฐ€๋Šฅํ•ฉ๋‹ˆ๋‹ค. ๋Œ€์•ˆ์€ ์˜ค๋ฅ˜ ๋ฉ”์‹œ์ง€๋ฅผ ๋ฐœ์ƒ์‹œํ‚ค๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.

def maxpool(x: Tensor):
    if x.is_contiguous(torch.layout.NHWC):
        return max_pool_impl_nhwc(x)
    return max_pool_impl_default(x.contiguous())

'conv_nhwc'์™€ ๊ฐ™์€ ๋ณ„๋„์˜ ์—ฐ์‚ฐ์ž๋ฅผ ๋งŒ๋“œ๋Š” ๋Œ€์‹  'conv'์™€ ๊ฐ™์€ ๋‹จ์ผ ๊ธฐํ˜ธ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ JIT IR์˜ ์—ฐ์‚ฐ์ž๋ฅผ ์ฐธ์กฐํ•˜๋Š” ๊ฒƒ์ด ์ข‹์Šต๋‹ˆ๋‹ค. ๊ทธ ์ด์œ ๋Š” ๋‹จ์ˆœ์„ฑ๊ณผ ์˜๋ฏธ๋ก ์  ํ‘œํ˜„ ์ˆ˜์ค€์—์„œ IR์„ ์œ ์ง€ํ•˜๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค.

์š”์†Œ๋ณ„ ์—ฐ์‚ฐ

element-wise์™€ ๊ฐ™์€ ํ•ต์‹ฌ ์ž‘์—…์ด ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์„ ์œ ์ง€ํ•˜๊ณ  ํšจ์œจ์ ์ž„์„ ๋ณด์žฅํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

๋‹จํ•ญ ์—ฐ์‚ฐ์€ ์ผ๋ฐ˜์ ์œผ๋กœ ๋ฉ”๋ชจ๋ฆฌ ๋ธ”๋ก์ด "๋ฐ€๋„"์ธ์ง€ ์—ฌ๋ถ€, ์ฆ‰ ์š”์†Œ๊ฐ€ ๊ฐ„๊ฒฉ์ด ์—†๋Š” ์˜์—ญ์— ๊ฑธ์ณ ์žˆ๊ณ  ๊ฐ ๋ฉ”๋ชจ๋ฆฌ ์œ„์น˜๊ฐ€ ์ •ํ™•ํžˆ ํ•œ ๋ฒˆ ์‚ฌ์šฉ๋˜๋Š”์ง€ ์—ฌ๋ถ€๋ฅผ ํ™•์ธํ•˜์—ฌ ์ฒ˜๋ฆฌํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๊ฐ„๋‹จํ•œ ์•Œ๊ณ ๋ฆฌ์ฆ˜์œผ๋กœ ๊ฒ€์ฆ ๊ฐ€๋Šฅ

def is_dense_format(x):
    p = 1
    for s, d in sorted(zip(x.stride(), x.size())):
        if s != p:
            return False
        p *= d
    return True

def my_unary(x):
    if is_dense_format(x):
        return contig_memory_impl(x.data(), x.numel())
    return default_strided_impl(x)

# is_dense_format can be used in implementations of e.g. empty_like too

์„ฑ๋Šฅ ๋„๊ตฌ

๋””๋ฒ„๊น… ์„ฑ๋Šฅ์„ ์œ„ํ•ด ํ”„๋กœํŒŒ์ผ๋Ÿฌ์— ๋‹ค์Œ ์ง€์›์„ ์ถ”๊ฐ€ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

  • ํ”„๋กœ๊ทธ๋žจ์—์„œ ์‹ค์ œ ๋ฉ”๋ชจ๋ฆฌ ์žฌ์ •๋ ฌ์ด ๋ฐœ์ƒํ•˜๋Š” ์œ„์น˜ ํ™•์ธ - ์ฆ‰ .contiguous()์— ๋Œ€ํ•œ ํ˜ธ์ถœ ์ถ”์ 
  • ํ˜ธ์ถœ๋œ ๊ตฌํ˜„ ์ถ”์ 
  • ์˜ˆ๋ฅผ ๋“ค์–ด ๋ฐ”์ด๋„ˆ๋ฆฌ ์—ฐ์‚ฐ("์†Œํ”„ํŠธ" ์ฃผ์„์ด ์œ ์šฉํ•œ ๊ฒฝ์šฐ)์—์„œ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ๋ณ€๊ฒฝ์— ๋Œ€ํ•œ ๊ฒฝ๊ณ ๋ฅผ ๋ฐœํ–‰ํ•ฉ๋‹ˆ๋‹ค.

์ด ๊ธฐ๋Šฅ์€ ์ฃผ๋ฌธํ˜• ํ”„๋กœํŒŒ์ผ๋ง ๋„๊ตฌ์— ๊ตฌ์ถ•ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

Autograd ์ฒ˜๋ฆฌ

์—ญ๋ฐฉํ–ฅ ํŒจ์Šค๊ฐ€ ์ •๋ฐฉํ–ฅ๊ณผ ๋™์ผํ•œ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์œผ๋กœ ์‹คํ–‰๋˜์–ด์•ผ ํ•œ๋‹ค๊ณ  ์˜ˆ์ƒํ•˜๋Š” ๊ฒƒ์ด ๋…ผ๋ฆฌ์ ์ž…๋‹ˆ๋‹ค. ๋“ค์–ด์˜ค๋Š” ๊ทธ๋ผ๋””์–ธํŠธ๊ฐ€ ์ž„์˜์ ์œผ๋กœ ์ŠคํŠธ๋ผ์ด๋“œ๋  ์ˆ˜ ์žˆ์œผ๋ฏ€๋กœ ํ•ญ์ƒ ์ž๋™์œผ๋กœ ๋ฐœ์ƒํ•˜์ง€๋Š” ์•Š์Šต๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ ์ •๋ฐฉํ–ฅ ํŒจ์Šค๋Š” ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์„ ๋ช…์‹œ์ ์œผ๋กœ ์ธ์‹ํ•˜๊ณ  autograd ํด๋กœ์ €์— ์ €์žฅํ•˜๊ณ  ์—ญ๋ฐฉํ–ฅ ๊ธฐ๋Šฅ ์ „์— grad ํ…์„œ์— ์ ์šฉํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

๊ฐ€๋Šฅํ•œ ๊ตฌํ˜„:

def conv_backward(input, weight, grad_output, grad_weight, grad_input):
  if input.is_contiguous(torch.memory_format.channels_last):
    grad_output = grad_output.to(torch.memory_format.channels_last)
    return conv_backward_nhwc(...)
  else:
    grad_output = grad_output.contiguous()
    return conv_backward_nchw(...)

JIT์—์„œ์˜ ํ‘œํ˜„

ํ˜„์žฌ ์ œ์•ˆ์€ ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค.

  • ์œ ํ˜• ์ฃผ์„์˜ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์— ๋Œ€ํ•œ ์ผ๊ธ‰ ์ฒ˜๋ฆฌ๋Š” ์•„์ง ์—†์Šต๋‹ˆ๋‹ค. ๋Œ€์‹  ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์„ ์กฐ์ž‘ํ•˜๋Š” ํŒจ์Šค์— ํ•„์š”ํ•œ ๋ชจ์–‘์œผ๋กœ lookside ๋งต์„ ์œ ์ง€ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
  • ๊ฐ’๋ณ„ ํ˜•์‹ ์ฃผ์„์„ ์ƒ์„ฑํ•˜๋Š” ์ถ”๋ก  ํŒจ์Šค(shape_inference์™€ ์œ ์‚ฌ)
  • ์ตœ์ ์˜ ์„ฑ๋Šฅ์„ ์œ„ํ•ด ํ•„์š”ํ•œ to(memory_format) ํ˜ธ์ถœ์„ ์‚ฝ์ž…ํ•ด์•ผ ํ•˜๋Š” ์œ„์น˜๋ฅผ ์ฐพ๋Š” ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ๋ณ€ํ™˜ ํŒจ์Šค(์ˆ˜๋™ ๋˜๋Š” ์ž๋™)

์ง‘ํ–‰ ๋ชฉ์ ์œผ๋กœ assert x.is_contiguous(channels_last) ์™€ ๊ฐ™์€ ๋ช…๋ น๋ฌธ์„ ์‚ฌ์šฉํ•  ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค.

์ฐธ๊ณ : ํŠน์ • ์žฅ์น˜์— ์„ ํ˜ธ๋˜๋Š” ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ์กฐํ•ฉ์ด ์žˆ๋‹ค๋Š” ์ •๋ณด๋ฅผ ์ €์žฅํ•  ์œ„์น˜์— ๋Œ€ํ•œ ์งˆ๋ฌธ์ด ์žˆ์Šต๋‹ˆ๋‹ค(์˜ˆ: x86์˜ qconv๋Š” NHWC๋งŒ ๊ตฌํ˜„ํ•˜๋Š” fbgemm ๊ฒฝ๋กœ). ํ•œ ๊ฐ€์ง€ ์˜ต์…˜์€ op ๋“ฑ๋ก ์ˆ˜์ค€์— ๋‘๋Š” ๊ฒƒ์ด์ง€๋งŒ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ์ฃผ์„์€ ๋” ๋งŽ์€ ๋ถ€๊ฐ€ ์ •๋ณด์ฒ˜๋Ÿผ ๋Š๊ปด์ง‘๋‹ˆ๋‹ค. ์„ ํ˜ธํ•˜๋Š” ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ๋ฐ ๊ด€๋ จ ํœด๋ฆฌ์Šคํ‹ฑ์„ ๋‚˜ํƒ€๋‚ด๋Š” ์ „์—ญ ๋งต์„ JIT ํŒจ์Šค์˜ ์–ด๋”˜๊ฐ€์— ์œ ์ง€ ๊ด€๋ฆฌํ•˜๋Š” ๊ฒƒ์œผ๋กœ ์‹œ์ž‘ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์–ด์ˆ˜์„ ํ•˜๋ฉด ๋“ฑ๋ก ๊ธฐ๋ฐ˜ ๋ฉ”์ปค๋‹ˆ์ฆ˜์œผ๋กœ ์ „ํ™˜ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

Beyond: ์ฐจ๋‹จ๋œ ๋ ˆ์ด์•„์›ƒ

๋” ๋ณต์žกํ•œ ํ…์„œ ํŒจํ‚น์„ ์ถ”๊ฐ€ํ•˜๊ธฐ๋กœ ๊ฒฐ์ •ํ•จ์— ๋”ฐ๋ผ ๋†’์€ ๊ตฌํ˜„ ๋น„์šฉ๊ณผ ๋ณต์žก์„ฑ์œผ๋กœ ์ธํ•ด 1๊ธ‰ PyTorch ํ…์„œ๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์ด ํƒ€๋‹นํ•˜์ง€ ์•Š์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๋‘ ๊ฐ€์ง€ ๋Œ€์•ˆ์ด ๊ฐ€๋Šฅํ•ฉ๋‹ˆ๋‹ค.

  • ์‚ฌ์šฉ์ž ์ •์˜ C ์œ ํ˜• ๋ฐ”์ธ๋”ฉ๊ณผ ๊ฐ™์€ ๋ถˆํˆฌ๋ช…ํ•œ ํ‘œํ˜„. ์„ฑ๋Šฅ ์ตœ์ ํ™” ์ธก๋ฉด์—์„œ ๋‹ค์–‘์„ฑ์ด ๋” ๋†’์€ ์ถ”๋ก ์—์„œ ํŒจํ‚น์„ ์œ„ํ•ด ์„ ํƒํ•˜๋Š” ์˜ต์…˜์ž…๋‹ˆ๋‹ค.
  • MKLDNNTensor์™€ ๊ฐ™์€ ์ผ๊ธ‰ ํ…์„œ ์œ ํ˜•์œผ๋กœ, ์ผ๋ถ€(์ „๋ถ€๋Š” ์•„๋‹˜) ์ž‘์—…์ด ์ด ์ƒˆ๋กœ์šด ์œ ํ˜•์— ๋ฐ”์ธ๋”ฉ๋ฉ๋‹ˆ๋‹ค.

๋˜ ๋‹ค๋ฅธ ๋Œ€์•ˆ์€ ํ•ต์‹ฌ PyTorch Tensor ํด๋ž˜์Šค์—์„œ ์ฐจ๋‹จ/ํƒ€์ผ๋ง์— ๋Œ€ํ•œ ๊ธฐ๋ณธ ์ง€์›์„ ๊ตฌํ˜„ํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.

๋ช…๋ช…๋œ ํ…์„œ ๊ด€๊ณ„

NamedTensor ์— ๋Œ€ํ•œ ๊ธฐ์กด ์ œ์•ˆ์€ ํ…์„œ์— ๋Œ€ํ•œ ์œ ํ˜• ๊ฒ€์‚ฌ ๋ฉ”์ปค๋‹ˆ์ฆ˜์œผ๋กœ ๊ตฌ์„ฑ๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค. ํ˜„์žฌ ์ฐจ์› ์ด๋ฆ„์— ์˜๋ฏธ๋ก ์  ์˜๋ฏธ๋ฅผ ํ• ๋‹นํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ ํ™œ์„ฑํ™” ํ…์„œ์˜ ์˜๋ฏธ๋ฅผ ์ถ”๋ก ํ•˜๋Š” ์œ ์ผํ•œ ๋ฐฉ๋ฒ•์€ ๋ฏธ๋ฆฌ ๊ฒฐ์ •๋œ NCHW ํ˜•์‹์„ ๊ณ„์† ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. NamedTensor์™€ ํ˜„์žฌ ์ œ์•ˆ์„ ์ง๊ตํ•˜๊ฒŒ ๋งŒ๋“ญ๋‹ˆ๋‹ค.

์ผ๋ถ€ ์ด๋ฆ„์˜ ์˜๋ฏธ(์˜ˆ: "์ฑ„๋„", "๋„ˆ๋น„")๋ฅผ ๊ธฐ๊บผ์ด ์ง€์ •ํ•˜๋ ค๋Š” ๊ฒฝ์šฐ ์šด์˜์ž๋Š” ์ด ์ •๋ณด๋ฅผ ํ™œ์šฉํ•˜์—ฌ ๋” ๋น ๋ฅธ ๊ตฌํ˜„์œผ๋กœ ๋ผ์šฐํŒ…ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ž…๋ ฅ ํ…์„œ๊ฐ€ ๋…ผ๋ฆฌ์ ์œผ๋กœ NHWC(์˜ค๋Š˜๋‚ ๊ณผ ๊ฐ™์€ NCHW๊ฐ€ ์•„๋‹˜) ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์„ ๊ฐ–๊ธฐ ๋•Œ๋ฌธ์— ์˜๋ฏธ๋ก ์  ๋ณ€ํ™”๊ฐ€ ๋  ๊ฒƒ์ž…๋‹ˆ๋‹ค.

์„ ํ–‰ ๊ธฐ์ˆ 

TensorFlow๋Š” data_format ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ํ†ตํ•ด ์šด์˜์ž ์ˆ˜์ค€์—์„œ NHWC์™€ NCHW๋ฅผ ๋ชจ๋‘ ์ง€์›ํ•ฉ๋‹ˆ๋‹ค. ํ—ˆ์šฉ๋˜๋Š” ๊ฐ’์€ 4์ฐจ์› ์ž…๋ ฅ์˜ ๊ฒฝ์šฐ ("NHWC", "NCHW"), 5์ฐจ์› ์ž…๋ ฅ์˜ ๊ฒฝ์šฐ ("NDHWC", "NCDHW") ๋˜๋Š” ์ž…๋ ฅ๊ณผ ๋ฌด๊ด€ํ•œ channels_first / channels_last ์ž…๋‹ˆ๋‹ค. ์ฐจ์›. ๋งค๊ฐœ๋ณ€์ˆ˜ ์„ค์ •์„ ์˜ฌ๋ฐ”๋ฅด๊ฒŒ ์ฒ˜๋ฆฌํ•˜๋Š” ๊ฒƒ์€ ์‚ฌ์šฉ์ž์—๊ฒŒ ๋‹ฌ๋ ค ์žˆ์Šต๋‹ˆ๋‹ค. ์ฆ‰, ํ…์„œ์— ์˜ํ•ด ์ž๋™์œผ๋กœ ์ถ”์ ๋˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.

Caffe2์ด ๋งค๊ฐœ ๋ณ€์ˆ˜๊ฐ€ ํ˜ธ์ถœ ํ˜ธ์ถœ order ๋ณด๋‹ค๋Š” data_format ,ํ•˜์ง€๋งŒ ์—ฌ์ „ํžˆ ๋ช…์‹œ ์ ์œผ๋กœ ๊ฐœ๋ณ„ ์šด์˜์ž ์ˆ˜์ค€์—์„œ ์ ์šฉํ•ฉ๋‹ˆ๋‹ค.


๋ถ€๋ก: ๊ณ ๋ ค๋˜๋Š” ๊ธฐํƒ€ ์˜ต์…˜

๋ฆฌํŠธ๋จธ์Šค ์งˆ๋ฌธ: ๋‹ค์Œ ์ฝ”๋“œ๋Š” ๋ฌด์—‡์„ ์ธ์‡„ํ•ฉ๋‹ˆ๊นŒ: tensor_in_nhwc_layout.size(1) - ์ฑ„๋„ ์ˆ˜(PyTorch์˜ ๊ธฐ๋ณธ๊ฐ’์€ NCHW์ด๊ธฐ ๋•Œ๋ฌธ์—) ๋˜๋Š” ๋†’์ด(์œ„์น˜ 1์˜ NHWC ๋ ˆ์ด์•„์›ƒ์— ์žˆ๊ธฐ ๋•Œ๋ฌธ์—).

์ด ๋‹ต๋ณ€์„ ๊ธฐ๋ฐ˜์œผ๋กœ ๋ช‡ ๊ฐ€์ง€ ์˜ต์…˜์ด ๊ฐ€๋Šฅํ•ฉ๋‹ˆ๋‹ค.

  • ์˜ต์…˜ A - ์ŠคํŠธ๋ผ์ด๋“œ(์œ„์— ์ œ์‹œ๋จ). Tensor ๋ ˆ์ด์•„์›ƒ์€ ์™„์ „ํžˆ ๋‚ด๋ถ€ ํ‘œํ˜„์ž…๋‹ˆ๋‹ค. ๊ตฌํ˜„๊ณผ ๋งˆ์ฐฌ๊ฐ€์ง€๋กœ ๋ณดํญ์œผ๋กœ ๊ฐ€์žฅ ํŽธ๋ฆฌํ•˜๊ฒŒ ์ˆ˜ํ–‰๋ฉ๋‹ˆ๋‹ค.

    • .size(1)์€ "์ฑ„๋„"์„ ๋ฐ˜ํ™˜ํ•˜์ง€๋งŒ ๋‚ด๋ถ€ ๋ฉ”๋ชจ๋ฆฌ๋Š” ๋‹ค๋ฅด๊ฒŒ ๋ฐฐ์น˜๋ฉ๋‹ˆ๋‹ค.

    • ์žฅ์ : ๋ชจ๋ธ์˜ ์ฝ”๋“œ๋ฅผ ๋ณ€๊ฒฝํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ๋‚ด ๋ชจ๋ธ์€ ์—ฌ์ „ํžˆ โ€‹โ€‹์ง์ ‘ ์ฐจ์› ์‚ฐ์ˆ ์„ ์ˆ˜ํ–‰ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์‹ค์ œ๋กœ ๊ณต๊ฐœ API ๋ณ€๊ฒฝ ์‚ฌํ•ญ์€ ์—†์Šต๋‹ˆ๋‹ค.

    • ๋‹จ์ : strides ๊ตฌํ˜„์—์„œ ๋งŽ์€ ์—ฐ์‚ฐ์ž๊ฐ€ .contiguous()๋ฅผ ํ˜ธ์ถœํ•˜๊ณ  ์‹ค์ˆ˜๋กœ ๋ ˆ์ด์•„์›ƒ์„ ๋‹ค์‹œ ๋˜๋Œ๋ฆด ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

    • ๋‹จ์ : ์‚ฌ์šฉ์ž ๊ด€์ ์—์„œ ์ˆ˜์ต๋ฅ  ๋ณด์žฅ์ด ๋ฌด์—‡๋ณด๋‹ค ์ค‘์š”ํ•œ์ง€ ์ดํ•ดํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ์ด IMO๋Š” ๋ณดํญ ์ „์šฉ ์ ‘๊ทผ ๋ฐฉ์‹์„ ์ œ๊ฑฐํ•ฉ๋‹ˆ๋‹ค. ์™œ๋ƒํ•˜๋ฉด ์ž‘์—…์ด ๋ฐ˜ํ™˜๋  ํ˜•์‹์„ ์ดํ•ดํ•˜๊ธฐ๊ฐ€ ๋งค์šฐ ์–ด๋ ค์›Œ์ง€๊ณ  "๋‚ด ๋ณดํญ์„ ๋ฌด์‹œํ•˜๊ณ  ์‹ค์ œ๋กœ NCHW์— ์ธ์ ‘ํ•œ ๊ฒƒ์„ ๋ฐ˜ํ™˜ํ•˜์‹ญ์‹œ์˜ค"๋ผ๊ณ  ๋งํ•˜๋Š” API๊ฐ€ ์—†๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค. ์ด๋Š” ์œ„์˜ ์ œํ•œ ์‚ฌํ•ญ์— ์ถ”๊ฐ€๋ฉ๋‹ˆ๋‹ค.

  • ์˜ต์…˜ B - ๋ช…์‹œ์  NHWC ํ…์„œ. ์‚ฌ์šฉ์ž๊ฐ€ ์ฐจ์› ์ˆœ์„œ๊ฐ€ ๋‹ค๋ฅธ ํ…์„œ๋ฅผ ๋ช…์‹œ์ ์œผ๋กœ ์กฐ์ž‘ํ•˜์ง€๋งŒ ํ…์„œ ์ž์ฒด๋Š” ์ด์— ๋Œ€ํ•ด ์•„๋ฌด๊ฒƒ๋„ ๋ชจ๋ฆ…๋‹ˆ๋‹ค. ์‚ฌ์šฉ์ž๊ฐ€ ๊ธฐ๋Œ€ํ•˜๋Š” ๊ฒƒ์„ ํŒŒ์•…ํ•˜๋ ค๋ฉด ์šด์˜์ž ์ˆ˜์ค€์—์„œ ์ฃผ์„์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค.

    • .size(1)๋Š” "๋†’์ด"๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.

    • ์žฅ์ : ๋งˆ์ˆ ์ด ์—†๊ณ  ๋งค์šฐ ์˜ˆ์ธก ๊ฐ€๋Šฅ

    • ๋‹จ์ : ํ•œ ๋ ˆ์ด์•„์›ƒ์—์„œ ๋‹ค๋ฅธ ๋ ˆ์ด์•„์›ƒ์œผ๋กœ ๋ชจ๋ธ์„ ๋ณ€๊ฒฝํ•˜๋Š” ๊ฒƒ์€ .size() ๋ฐ .reshape()์— ๋Œ€ํ•œ ๋ชจ๋“  ์•ก์„ธ์Šค๋ฅผ ์ถ”์ ํ•ด์•ผ ํ•˜๋Š” ๋ณต์žกํ•œ ์ž‘์—…์ด ๋ฉ๋‹ˆ๋‹ค(๋˜๋Š” API์—์„œ ๋ช…์‹œ์ ์œผ๋กœ ๋งŒ๋“ค์–ด์•ผ ํ•ฉ๋‹ˆ๊นŒ?)

  • ์˜ต์…˜ B' - ๋ ˆ์ด์•„์›ƒ ํ”Œ๋ž˜๊ทธ๊ฐ€ ์žˆ๋Š” ๋ช…์‹œ์  NHWC ํ…์„œ . ์œ„์™€ ๋™์ผํ•˜์ง€๋งŒ ํ…์„œ์— ์ฃผ์„์„ ์ถ”๊ฐ€ํ•˜์—ฌ ์ž‘์—…์ด ๊ตฌํ˜„์—์„œ ์†Œ๋น„ํ•˜๋Š” ์˜๋ฏธ๋ก ์  ๋ ˆ์ด์•„์›ƒ์„ ํ‘œ์‹œํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋ฉด ์šด์˜์ž ์ˆ˜์ค€ ์ฃผ์„์ด ํ•„์š”ํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ์šด์˜์ž๋Š” ์ž…๋ ฅ์˜ ๋ ˆ์ด์•„์›ƒ ํ”Œ๋ž˜๊ทธ๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ ๋””์ŠคํŒจ์น˜๋ฅผ โ€‹โ€‹์ˆ˜ํ–‰ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
  • ์˜ต์…˜ C - ๋ช…๋ช…๋œ Tensor . ( https://docs.google.com/document/d/1ynu3wA2hcjwOtEng04N904gJjEbZWcINXO_ardX6hxc/edit#heading =h.2gbe5xpga3w9)

    • .size(1)๋Š” "๋†’์ด"๋ฅผ ๋ฐ˜ํ™˜ํ•˜์ง€๋งŒ ์šฐ๋ฆฌ๋Š” ์‚ฌ๋žŒ๋“ค์—๊ฒŒ ์ด API๋ฅผ ์‚ฌ์šฉํ•˜์ง€ ๋ง๊ณ  ๋Œ€์‹  .size('channel')๋ฅผ ์‚ฌ์šฉํ•˜๋„๋ก ์š”์ฒญํ•ฉ๋‹ˆ๋‹ค.

    • ์žฅ์ : ๋งค์šฐ ๋ช…์‹œ์ ์ด๋ฉฐ ์‚ฌ์šฉ์ž๊ฐ€ ์›ํ•˜๋Š” ๊ฒƒ

    • con: ์ „ํ™˜ ๋ฌธ์ œ๊ฐ€ ํ•ด๊ฒฐ๋˜์ง€ ์•Š์œผ๋ฉด ๋ ˆ์ด์•„์›ƒ ์ธ์‹์œผ๋กœ ์ž‘์„ฑ๋œ ๋ชจ๋“  ์ฝ”๋“œ์—์„œ ๋ช…๋ช…๋œ ํ…์„œ๋ฅผ ์‚ฌ์šฉํ•˜๋„๋ก ๊ฐ•์ œํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋ ‡์ง€ ์•Š์€ ๊ฒฝ์šฐ - ์œ„์™€ ๋™์ผํ•œ ๋ฌธ์ œ๊ฐ€ ์ ์šฉ๋ฉ๋‹ˆ๋‹ค.

  • ์˜ต์…˜ D-๋ ˆ์ด์•„์›ƒ์€ ๋ถˆํˆฌ๋ช…ํ•œ ํ…์„œ ์œ ํ˜• ์ž…๋‹ˆ๋‹ค. MKLDNN ๋˜๋Š” SparseTensor - ๋‹ค๋ฅธ DispatchID๋ฅผ ๊ฐ€์ง„ ๋ณ„๋„์˜ ํ…์„œ ์œ ํ˜•์„ ์ทจ๊ธ‰ํ•˜๋Š” ๊ฒƒ์ฒ˜๋Ÿผ NHWC๋ฅผ ์ทจ๊ธ‰ํ•ฉ๋‹ˆ๋‹ค. ์˜ต์…˜ A์™€ ๋น„์Šทํ•˜์ง€๋งŒ ๊ธฐ๋ณธ ๋™์ž‘์— ๋Œ€ํ•ด ์„œ๋กœ ๋‹ค๋ฅธ ์ ˆ์ถฉ์•ˆ์ด ์žˆ์Šต๋‹ˆ๋‹ค. ๊ตฌํ˜„๋˜์ง€ ์•Š์€ ์ž‘์—…์€ NCHW๋กœ ๋˜๋Œ๋ฆฌ๋Š” ๋Œ€์‹  ์‹คํŒจํ•ฉ๋‹ˆ๋‹ค.

    • .size(1)์€ ์—ฌ์ „ํžˆ โ€‹โ€‹"์ฑ„๋„"์„ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.

    • ์žฅ์ : ๋งˆ์ˆ ์ด ์—†๊ณ  ๋ช…์‹œ์ ์ด๋ฉฐ ๋ณ„๋„์˜ ๋””์ŠคํŒจ์น˜๊ฐ€ ์šด์˜์ž๊ฐ€ ์›ํ•˜๋Š” ๊ฒƒ์„ ๊ฒฐ์ •ํ•  ์ˆ˜ ์žˆ๋„๋ก ํ•ฉ๋‹ˆ๋‹ค.

    • ์žฅ์ /๋‹จ์ : ํ•„์š”ํ•œ ๋ชจ๋“  ์—ฐ์‚ฐ์ž๋Š” ๋‹ค๋ฅธ ๋ ˆ์ด์•„์›ƒ์—์„œ ๊ตฌํ˜„ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ์ผ๋ถ€ ์—ฐ์‚ฐ์ด ๋ˆ„๋ฝ๋œ ๊ฒฝ์šฐ ์‚ฌ์šฉ์ž๋Š” ์ง€์›๋˜์ง€ ์•Š๋Š”๋‹ค๋Š” ๋ช…์‹œ์  ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ•ฉ๋‹ˆ๋‹ค.

    • ๋‹จ์ : ์˜ˆ์ƒ ๊ฒฐ๊ณผ๋ฅผ ์˜ˆ์ธกํ•˜๊ธฐ ์–ด๋ ต๊ธฐ ๋•Œ๋ฌธ์— ๋ทฐ์™€ ๊ฐ™์€ ๋งŽ์€ ์ž‘์—…์„ ๊ธˆ์ง€ํ•ด์•ผ ํ•  ๊ฒƒ์ž…๋‹ˆ๋‹ค.

internals mkldnn triaged

๊ฐ€์žฅ ์œ ์šฉํ•œ ๋Œ“๊ธ€

BTW ์™œ ์šฐ๋ฆฌ๋Š” layout ์ง‘์ฐฉํ•˜๋Š” ๋Œ€์‹  ์ƒˆ๋กœ์šด ๊ฐœ๋…์„ ๋งŒ๋“ค์–ด์•ผ ํ•ฉ๋‹ˆ๊นŒ? ํฌ์†Œ ํ‘œํ˜„์€ "channels_last"์™€ ๊ฐ™์€ ๋ ˆ์ด์•„์›ƒ ๊ฐœ๋…์ด ์ž˜ ์ •์˜๋˜์–ด ์žˆ์ง€ ์•Š๋‹ค๊ณ  ์ƒ๊ฐํ•˜๋ฏ€๋กœ memory_formats * layouts ์˜ ์ œํ’ˆ์„ ๋‚˜ํƒ€๋‚ผ ํ•„์š”๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค( layouts ๋Š” ํ˜„์žฌ ์‚ฌ์šฉ์„ ๋‚˜ํƒ€๋ƒ…๋‹ˆ๋‹ค. ), ํ•˜์ง€๋งŒ memory_format + layouts ์‚ฌ์šฉํ•˜๋ฉด ์ด์ „๊ณผ ๋™์ผํ•œ ์ธ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์ด ์ข‹์Šต๋‹ˆ๋‹ค. ๋‚˜์—๊ฒŒ ๊ทธ๊ฒƒ์€ ๋” ์งง๊ณ  ๋” ์ข‹์œผ๋ฉฐ ํŒฉํ† ๋ฆฌ ์„œ๋ช…์„ ์ˆ˜์ฒœ ๊ฐœ์˜ ์ธ์ˆ˜๋กœ ํ™•์žฅํ•˜๋Š” ๊ฒƒ์„ ํ”ผํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

๋ชจ๋“  68 ๋Œ“๊ธ€

empty_like ์—๋Š” ํ•œ ๊ฐ€์ง€ ๋ฌธ์ œ๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค. ํ˜„์žฌ ์ •์˜๋œ ์˜๋ฏธ๋Š” ๋ชจ๋“  ๋ณดํญ ์ •๋ณด๋ฅผ ์‚ญ์ œํ•œ๋‹ค๋Š” ๊ฒƒ์ด๋ฏ€๋กœ ๋ ˆ์ด์•„์›ƒ์„ ์œ ์ง€ํ•˜๊ณ  BC๊ฐ€ ๋  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.

@VitalyFedyunin ์€ .contiguous() ๋ฐ torch.memory_layout ๋น„ํŠธ๋ฅผ ๊ตฌํ˜„ํ•˜๋„๋ก ๋“ฑ๋ก๋˜์—ˆ์Šต๋‹ˆ๋‹ค.

ํ•œ ๊ฐ€์ง€ ์งˆ๋ฌธ - (n, c, h, w) ํฌ๊ธฐ์˜ 4D ํ…์„œ x (n, c, h, w)

x = torch.randn(n,c,h,w)
# x.size(): (n, c, h, w)
# x.stride(): (c*h*w, h*w, w, 1)

์šฐ๋ฆฌ๋Š” ์ด์ƒํ•œ ์ˆœ์—ด์„ ๊ฐ€์ง€๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค

y = x.permute(0, 3, 1, 2)
# y.size(): (n, w, c, h)
# y.stride(): (c*h*w, 1, h*w, w)

์ด์ œ NHWC ํ˜•์‹์— ๋Œ€ํ•ด ์—ฐ์†์ ์ธ์ง€ ํ™•์ธํ•ฉ๋‹ˆ๋‹ค. ์•„๋ž˜์™€ ๊ฐ™์€ ๋…ผ๋ฆฌ์— ๋”ฐ๋ผ

def is_nhwc_contiguous(x):
    return x.permute(0,2,3,1).is_contiguous()

# or alternatively
def is_nhwc_contiguous(x):
    n,c,h,w = x.size() # in any case the sizes remain in NCHW order
    return x.stride() == (c*h*w, 1, c*w, c)

๋‘ ๊ฒฝ์šฐ ๋ชจ๋‘ is_nhwc_contiguous(y) ๋Š” True๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๊นŒ?

์ด๊ฒƒ์€ ๋งž์Šต๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ ๋ณต์‚ฌ, to ๋ฐ ์œ ์‚ฌํ•œ ์ž‘์—… ์ค‘์— ์•ž๋’ค๋กœ ๋ณ€ํ™˜์„ ํ”ผํ•˜๊ธฐ ์œ„ํ•ด ๋ณดํญ์—๋งŒ ๋ฆด๋ ˆ์ดํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.

strides์˜ ์ˆœ์„œ๊ฐ€ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹๊ณผ ๊ฐ™์œผ๋ฉด ์–ด๋–ป๊ฒŒ ๋ ๊นŒ์š”? 4D ํ…์„œ๋ฅผ ์˜ˆ๋กœ ๋“ค์–ด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค. ํ…์„œ๋ฅผ ์„ค๋ช…ํ•˜๊ธฐ ์œ„ํ•ด sizes , strides ๋ฐ stride_indexes .

์˜ ํฌ๊ธฐ (N, C, H, w)
๋ฌผ๋ฆฌ์  ์ˆœ์„œ์— ๋”ฐ๋ฅธ ๋ณดํญ , ์ฆ‰

  • ํ˜•์‹์ด nchw์ธ ๊ฒฝ์šฐ (n, c, h, w)์˜ ๋ณดํญ
  • ํ˜•์‹์ด nhwc์ธ ๊ฒฝ์šฐ (n, h, w, c)์˜ ๋ณดํญ.

stride_indexes ๋Š”

  • (0, 1, 2, 3) ํ˜•์‹์ด nchw์ธ ๊ฒฝ์šฐ
  • (0, 2, 3, 1) ํ˜•์‹์ด nhwc์ธ ๊ฒฝ์šฐ.

nchw ํ˜•์‹์˜ ๊ฒฝ์šฐ ์ด์ „๊ณผ ๋™์ผํ•ฉ๋‹ˆ๋‹ค. nhwc์˜ ๊ฒฝ์šฐ ๋น„์Šทํ•  ๊ฒƒ์ž…๋‹ˆ๋‹ค.

def is_nhwc_contiguous(x):
     n,c,h,w = x.size()
     return x.stride() == (h*w*c, w*c, c, 1)

def is_nchw_contiguous(x):
    n,c,h,w = x.size()
    return x.stride() == (c*h*w, h*w, w, 1)

def is_nchw_format(x):
    return x.stride_index() == (0, 1, 2, 3) 

def is_nhwc_format(x):
    return x.stride_index == (0, 2, 3, 1)

def is_contiguous(x):
    if (is_nchw_format(x)):
        return is_nchw_contiguous(x)
    else if (is_nhwc_format(x)):
        return  is_nhwc_contiguous(x)
    else:
        warning_not_support()

# or, to use stride_index
def is_contiguous(x):
    return x.stride() == (x.size[x.stride_index[1]]*x.size[x.stride_index[2]]*x.size[x.stride_index[3]], x.size[x.stride_index[2]] * x.size[x.stride_index[3]], x.size[x.stride_index[3]], 1)

์ด๊ฒƒ์€ ๋˜ํ•œ ์ฐจ๋‹จ๋œ ํ˜•์‹์„ ์ง€์›ํ•˜๋„๋ก ํ™•์žฅ๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. nChw16c๋ฅผ ์˜ˆ๋กœ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.

sizes: (n, c, h, w)
block_sizes: (n, c/16, h, w, 16)
strides: strides of (n, c/16, h, w, 16)
stride_indexes: (0, 1, 2, 3, 1)  # assume blocked dimension is always in dense (i.e. on the right side of major dimension)

์ž์„ธํ•œ ๋‚ด์šฉ์€ ๋‚˜์ค‘์— ์ž์„ธํžˆ ์•Œ์•„๋ณผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

nchw ์—ฐ์† ํ…์„œ๋งŒ ํ—ˆ์šฉํ•˜๋Š” OP์˜ ๊ฒฝ์šฐ ์—ฌ๊ธฐ์—์„œ ์•ฝ๊ฐ„์˜ ์ž‘์—…์ด ํ•„์š”ํ•  ๊ฒƒ์ž…๋‹ˆ๋‹ค.

๋˜๋Š” ํ”„๋กœํ† ํƒ€์ž…์„ ์•ฝ๊ฐ„ ๋ณ€๊ฒฝํ•  ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค.

def is_contiguous(format=nchw):
    ...
def contiguous(format=nchw)
    ...

๋”ฐ๋ผ์„œ ๊ธฐ๋ณธ์ ์œผ๋กœ nchw๋งŒ ์—ฐ์†์ ์ด๋ผ๊ณ  ๊ฐ€์ •ํ•ฉ๋‹ˆ๋‹ค. ์ด๋Ÿฐ ์‹์œผ๋กœ ํ•ด๋‹น OP๋ฅผ ๋‹ค์‹œ ์ž‘์„ฑํ•  ํ•„์š”๊ฐ€ ์—†์œผ๋ฉฐ ์ž๋™์œผ๋กœ nchw๋กœ ์žฌ์ •๋ ฌ๋ฉ๋‹ˆ๋‹ค.

์šฐ๋ฆฌ๋Š” ๋‹ค์Œ์„ ํ‘œํ˜„ํ•  ์ˆ˜ ์žˆ๋Š” API๋ฅผ ๊ตฌ์ถ•ํ•˜๊ธฐ ์œ„ํ•ด ๋…ธ๋ ฅํ•ฉ๋‹ˆ๋‹ค.

  • Eager ๋ฐ JIT์˜ PyTorch์— ์žˆ๋Š” ๋‹ค๋ฅธ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹(์ฒ˜์Œ์—๋Š” ์ฐจ์› ์ˆœ์„œ๋งŒ)์„ ๊ฐ€์ง„ ํ…์„œ. ์ฐจ๋‹จ๋œ ๋ ˆ์ด์•„์›ƒ์€ ์šฐ์„  ์ˆœ์œ„๊ฐ€ ๋‚ฎ์ง€๋งŒ ์—ฌ์ „ํžˆ ์ข‹์Šต๋‹ˆ๋‹ค.
  • ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ์ฟผ๋ฆฌ ๋ฐ ๋ณ€๊ฒฝ์„ ์œ„ํ•œ ์‚ฌ์šฉ์ž ๋…ธ์ถœ API
  • ๋‹ค๋ฅธ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์„ ๊ฐ€์ง„ ์ž…๋ ฅ ํ…์„œ๋ฅผ ์ฒ˜๋ฆฌํ•˜๊ณ  ํ•ด๋‹นํ•˜๋Š” ๋” ๋น ๋ฅธ ๊ตฌํ˜„์œผ๋กœ ๋ผ์šฐํŒ…ํ•  ์ˆ˜ ์žˆ๋Š” ํ•ต์‹ฌ CNN ์ž‘์—…
  • JIT ํŒจ์Šค์˜ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์„ ์ถ”๋ก ํ•˜๊ณ  ์ตœ์ ํ™”ํ•˜๋Š” ๊ธฐ๋Šฅ

์ข‹์€ ์ œ์•ˆ! ๋‚ด ์ดํ•ด๊ฐ€ ์˜ฌ๋ฐ”๋ฅธ์ง€ ํ™•์ธํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค(MKL-DNN ํ˜•์‹ ์ฒ˜๋ฆฌ์— ๋Œ€ํ•œ ์ œ์•ˆ ํฌํ•จ):

์ด ์ œ์•ˆ์„ "ํ˜•์‹" ํด๋ž˜์Šค๋กœ ๊ตฌํ˜„ํ–ˆ๋‹ค๊ณ  ์ƒ๊ฐํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. API ์ฟผ๋ฆฌ ๋ฐ ๋ณ€๊ฒฝ์„ ๊ฐ€์ƒ์œผ๋กœ ์ œ๊ณตํ•˜๋Š” ํ•œ MKL-DNN ๋ณตํ•ฉ ํ˜•์‹์— ๋งž๋Š” ์ƒ์†/ํ™•์žฅ์„ ์ˆ˜ํ–‰ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๋˜๋Š” ํ˜•์‹ ์ฒ˜๋ฆฌ๋ฅผ ์œ„ํ•œ ํ”„๋ ˆ์ž„์›Œํฌ๋ฅผ ์ œ๊ณตํ•˜๋Š” ํ•œ ๋‹ค๋ฅธ ๋ฐฉ๋ฒ•์„ ์‚ฌ์šฉํ•˜์—ฌ ์ค‘์š”ํ•œ ์„ธ๋ถ€ ์‚ฌํ•ญ์„ ์šฐ๋ฆฌ์—๊ฒŒ ์ „๊ฐ€ํ•ฉ๋‹ˆ๋‹ค.

OP ๊ตฌํ˜„์— ๋Œ€ํ•ด ๊ฐ OP๋Š” ์„ฑ๋Šฅ์„ ์ตœ๋Œ€ํ™”ํ•˜๋Š” ์„ ํ˜ธ ํ˜•์‹๊ณผ ์ž‘๋™ํ•˜๋Š” ํ˜ธํ™˜ ํ˜•์‹์„ ๊ฐ€์งˆ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์š”์†Œ๋ณ„ ์—ฐ์‚ฐ์ž(๋˜๋Š” ๋” ์ผ๋ฐ˜์ ์œผ๋กœ ๋งํ•˜๋ฉด ๋ฉ”๋ชจ๋ฆฌ ์ œํ•œ OP)๋Š” ๊ธฐ๋ณธ ์„ค์ •์ด ์—†๋‹ค๊ณ  ๊ฐ€์ •ํ•ฉ๋‹ˆ๋‹ค. OP๋Š” "ํ˜•์‹" ๊ฐœ์ฒด๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๊ฒฐ๊ณผ ํ…์„œ๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค. ์ด ํ˜•์‹ ๊ฐœ์ฒด๋Š” ๊ธฐ๋ณธ pytorch ๊ธฐ๋Œ€์น˜์™€ ํ˜ธํ™˜๋˜๋Š” ์ฟผ๋ฆฌ/๋ณ€๊ฒฝ ์˜๋ฏธ ์ฒด๊ณ„๋ฅผ ๋ณด์žฅํ•  ๋ฟ๋งŒ ์•„๋‹ˆ๋ผ ์ตœ์ ํ™”๋œ ํ•จ์ˆ˜์˜ ์ผ๋ จ ๋ฒˆํ˜ธ(์˜ˆ: conv2d(ReLU(conv2d)))๋กœ ํ˜ธ์ถœ๋˜๋Š” ๊ฒฝ์šฐ ํŠน์ • ํ˜•์‹์„ ์ฒ˜๋ฆฌํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์‚ฌ๋ก€)

@uyongw ์ฒซ ๋ฒˆ์งธ ์˜ˆ์— ๋Œ€ํ•ด ์ข€ ๋” ๋ช…ํ™•ํžˆ ํ•˜๊ณ  ์‹ถ์Šต๋‹ˆ๋‹ค. "๋‚˜๋Š” NCHW ํ…์„œ๋ฅผ ๊ฐ€์ง€๊ณ  ์žˆ๋Š”๋ฐ, ๊ทธ๊ฒƒ์„ ์ด์ƒํ•œ ๋ฐฉ์‹์œผ๋กœ ์กฐ์˜ฎ๊น€ํ–ˆ์Šต๋‹ˆ๋‹ค. ๊ทธ๋ž˜์„œ ์ง€๊ธˆ์€ NWCH์ฒ˜๋Ÿผ ๋ณด์ž…๋‹ˆ๋‹ค. ์ด์ œ NHWC๊ฐ€ ์—ฐ์†์ ์ธ์ง€ ์•Œ๊ณ  ์‹ถ์Šต๋‹ˆ๋‹ค." ๊ทธ๋Ÿฌ๋‚˜ ๊ทธ๊ฒƒ์€ ์ž˜๋ชป๋œ ๊ด€์ ์ž…๋‹ˆ๋‹ค. ๋” ๋‚˜์€ ๊ณต์‹์€ "๋‚˜๋Š” NHWC ํ…์„œ๋ฅผ ๊ฐ€์ง€๊ณ  ์žˆ์œผ๋ฉฐ NCHW ํ…์„œ๋กœ ์ „์น˜ํ–ˆ์Šต๋‹ˆ๋‹ค."์ž…๋‹ˆ๋‹ค.

๋‹ค๋ฅด๊ฒŒ ๋งํ•˜๋ฉด ํ…์„œ์˜ ๋ฌผ๋ฆฌ์  ์ฐจ์›์—๋Š” ๋ณธ์งˆ์ ์ธ ์˜๋ฏธ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค(๋ณดํญ์„ ๋ฌด์‹œํ•  ๋•Œ). ๋ณดํญ๊ณผ ๊ด€๋ จํ•˜์—ฌ ์ฐธ์กฐํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ๊ณ ๋ คํ•  ๋•Œ๋งŒ ์˜๋ฏธ๋ฅผ ๋ถ€์—ฌํ•ฉ๋‹ˆ๋‹ค.

ํ…์„œ๋ฅผ ์„ค๋ช…ํ•˜๊ธฐ ์œ„ํ•ด ํฌ๊ธฐ, ๋ณดํญ ๋ฐ stride_indexes๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค.

๋‚˜๋Š” stride_indexes ๊ฐ€ ๋ฌธ์ œ์— ๋Œ€ํ•ด ์ƒ๊ฐํ•˜๋Š” ํŽธ๋ฆฌํ•œ ๋ฐฉ๋ฒ•์ด๋ผ๊ณ  ์ƒ๊ฐํ•˜์ง€๋งŒ, strides์™€ ์—„๊ฒฉํ•˜๊ฒŒ ์ค‘๋ณต๋ฉ๋‹ˆ๋‹ค. true strides.) @VitalyFedyunin ๊ณผ ์ €๋Š” strides ์ž์ฒด์—์„œ ์ •๋ณด๋ฅผ ์žฌ๊ตฌ์„ฑํ•˜๋Š” ๊ฒƒ์ด

๋”ฐ๋ผ์„œ ๊ธฐ๋ณธ์ ์œผ๋กœ nchw๋งŒ ์—ฐ์†์ ์ด๋ผ๊ณ  ๊ฐ€์ •ํ•ฉ๋‹ˆ๋‹ค.

๋„ค, ์ œ๊ฐ€ ์ฝ์€ ๊ณ„ํš์ž…๋‹ˆ๋‹ค.

@CaoZhongZ

์ด ์ œ์•ˆ์„ "ํ˜•์‹" ํด๋ž˜์Šค๋กœ ๊ตฌํ˜„ํ–ˆ๋‹ค๊ณ  ์ƒ๊ฐํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. API ์ฟผ๋ฆฌ ๋ฐ ๋ณ€๊ฒฝ์„ ๊ฐ€์ƒ์œผ๋กœ ์ œ๊ณตํ•˜๋Š” ํ•œ MKL-DNN ๋ณตํ•ฉ ํ˜•์‹์— ๋งž๋Š” ์ƒ์†/ํ™•์žฅ์„ ์ˆ˜ํ–‰ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๋˜๋Š” ํ˜•์‹ ์ฒ˜๋ฆฌ๋ฅผ ์œ„ํ•œ ํ”„๋ ˆ์ž„์›Œํฌ๋ฅผ ์ œ๊ณตํ•˜๋Š” ํ•œ ๋‹ค๋ฅธ ๋ฐฉ๋ฒ•์„ ์‚ฌ์šฉํ•˜์—ฌ ์ค‘์š”ํ•œ ์„ธ๋ถ€ ์‚ฌํ•ญ์„ ์šฐ๋ฆฌ์—๊ฒŒ ์ „๊ฐ€ํ•ฉ๋‹ˆ๋‹ค.

๋‚˜๋Š” ๊ทธ๊ฒƒ์ด ์ œ์•ˆ์— ๋Œ€ํ•œ ์ •ํ™•ํ•œ ์„ค๋ช…์ด๋ผ๊ณ  ์ƒ๊ฐํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ์—ฌ๊ธฐ์„œ ์ œ์•ˆํ•˜๋Š” ๋ฉ”๋ชจ๋ฆฌ ๋ ˆ์ด์•„์›ƒ ์ง€์›์€ ์ŠคํŠธ๋ผ์ด๋“œ๋กœ ํ‘œํ˜„ํ•  ์ˆ˜ ์žˆ๋Š” ๋ ˆ์ด์•„์›ƒ์ผ ๋ฟ์ž…๋‹ˆ๋‹ค. ์ด ๋ฐฉ๋ฒ•์œผ๋กœ ํ‘œํ˜„ํ•  ์ˆ˜ ์—†๋Š” ๋ชจ๋“  ๊ฒƒ(์˜ˆ: ๋ธ”๋ก ๋ ˆ์ด์•„์›ƒ)์€ ์ด ๋ฐฉ๋ฒ•์œผ๋กœ ์ž‘๋™ํ•˜์ง€ ์•Š์œผ๋ฉฐ ๋” ๋ฌด๊ฑฐ์šด "๋ ˆ์ด์•„์›ƒ" ๋ฉ”์ปค๋‹ˆ์ฆ˜์— ์˜ํ•ด ์ง€์›๋˜์–ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

๋‹ค๋ฅด๊ฒŒ ๋งํ•˜๋ฉด ํ…์„œ์˜ ๋ฌผ๋ฆฌ์  ์ฐจ์›์—๋Š” ๋ณธ์งˆ์ ์ธ ์˜๋ฏธ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค(๋ณดํญ์„ ๋ฌด์‹œํ•  ๋•Œ). ๋ณดํญ๊ณผ ๊ด€๋ จํ•˜์—ฌ ์ฐธ์กฐํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ๊ณ ๋ คํ•  ๋•Œ๋งŒ ์˜๋ฏธ๋ฅผ ๋ถ€์—ฌํ•ฉ๋‹ˆ๋‹ค.

๋ถ€๋ถ„์ ์œผ๋กœ ๋™์˜ํ•ฉ๋‹ˆ๋‹ค :-) ๊ทธ๋Ÿฌ๋‚˜ ์ด ํŠน์ •ํ•œ ๋ฌธ์ œ์— ๋Œ€ํ•ด์„œ๋Š” ์•„๋‹™๋‹ˆ๋‹ค. ์ด๋ฏธ nhwc ํ…์„œ๊ฐ€ ์žˆ๋‹ค๊ณ  ๊ฐ€์ •ํ•ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค. ๊ทธ๋Ÿฐ ๋‹ค์Œ nwhc๋กœ ๋ณ€๊ฒฝํ•ฉ๋‹ˆ๋‹ค. nhwc๋กœ ๋” ์ˆœ์—ดํ•œ ๋‹ค์Œ contiguous()๋ฅผ ์ˆ˜ํ–‰ํ•˜๊ณ  ์‹ถ์Šต๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ ๋‚˜๋Š” ๊ทธ๊ฒƒ์„ ์ด๋ฏธ ์—ฐ์†์ ์œผ๋กœ ์–ป์—ˆ์Šต๋‹ˆ๋‹ค. ํ˜ผ๋ž€์Šค๋Ÿฝ์ง€ ์•Š์Šต๋‹ˆ๊นŒ?

๋‚˜๋Š” stride_indexes๊ฐ€ ๋ฌธ์ œ์— ๋Œ€ํ•ด ์ƒ๊ฐํ•˜๋Š” ํŽธ๋ฆฌํ•œ ๋ฐฉ๋ฒ•์ด๋ผ๊ณ  ์ƒ๊ฐํ•˜์ง€๋งŒ, stride์™€ ์—„๊ฒฉํ•˜๊ฒŒ ์ค‘๋ณต๋ฉ๋‹ˆ๋‹ค. ์™œ๋ƒํ•˜๋ฉด ๋‹น์‹ ์ด ๋งํ•˜๋Š” ๋ชจ๋“  ๊ฒƒ์€ "์ด (์—ญ?) ์ˆœ์—ด์„ ๋ณดํญ์— ์ ์šฉํ•˜๊ณ  ๊ทธ๊ฒƒ์„ ์ง„์ •ํ•œ ๋ณดํญ์œผ๋กœ ์ทจ๊ธ‰ํ•˜๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค.)

IMHO, nhwc(๋ฌผ๋ฆฌ์ )์— ๋ณดํญ์ด ์žˆ๋Š” ๊ฒฝ์šฐ ๋ณดํญ์ด ์ค‘๋ณต๋˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ํฌ๊ธฐ(๋…ผ๋ฆฌ)๊ฐ€ ์žˆ๋Š” ์˜ฌ๋ฐ”๋ฅธ ๋งคํ•‘์ด ํ•„์š”ํ•˜๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค. ๊ทธ๋ ‡์ง€ ์•Š์œผ๋ฉด ์‹ค์ œ ์ˆœ์„œ๋ฅผ ๋งํ•  ๋ฐฉ๋ฒ•์ด ์—†์Šต๋‹ˆ๋‹ค.

BTW ์—ญ ๋งคํ•‘์„ ์‚ฌ์šฉํ•˜๋Š” ๋” ๊ฐ„๋‹จํ•œ ์ ‘๊ทผ ๋ฐฉ์‹์ด ์žˆ์Šต๋‹ˆ๋‹ค. nchw์˜ ๊ฒฝ์šฐ (0, 1, 2, 3)์ด๊ณ  nhwc์˜ ๊ฒฝ์šฐ (0, 2, 3, 1) ๋Œ€์‹  (0, 3, 1, 2)์ž…๋‹ˆ๋‹ค. ์ฆ‰, stride_index ์ž์ฒด๋„ ํ•ญ์ƒ NCHW์ž…๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ ๋ฌธ์ œ๋Š” nChw16c ๋˜๋Š” OIhw16i16o์™€ ๊ฐ™์€ ์ฐจ๋‹จ๋œ ํ˜•์‹์œผ๋กœ ํ™•์žฅํ•  ์ˆ˜ ์—†๋‹ค๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.

์ฐจ๋‹จ๋œ ํ˜•์‹์—๋Š” ์™„์ „ํžˆ ๋‹ค๋ฅธ ์—ฐ์‚ฐ์ž ๊ตฌํ˜„ ์ง‘ํ•ฉ์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ ์ด์œ ๋กœ ์šฐ๋ฆฌ๋Š” ์ •์˜์ƒ ๋ชจ๋“  ๊ธฐ์กด ์—ฐ์‚ฐ์ž์™€ ์นœ์ˆ™ํ•˜๊ณ  ๋™์ผํ•˜๊ฑฐ๋‚˜ ๋” ๋‚˜์€ ์„ฑ๋Šฅ์œผ๋กœ ์ž‘๋™ํ•˜๋Š” '๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹'๊ณผ ํ˜ผํ•ฉํ•˜์ง€ ์•Š๋Š” ๊ฒƒ์„ ์„ ํ˜ธํ•ฉ๋‹ˆ๋‹ค.

๋ถ€๋ถ„์ ์œผ๋กœ ๋™์˜ํ•ฉ๋‹ˆ๋‹ค :-) ๊ทธ๋Ÿฌ๋‚˜ ์ด ํŠน์ •ํ•œ ๋ฌธ์ œ์— ๋Œ€ํ•ด์„œ๋Š” ์•„๋‹™๋‹ˆ๋‹ค. ์ด๋ฏธ nhwc ํ…์„œ๊ฐ€ ์žˆ๋‹ค๊ณ  ๊ฐ€์ •ํ•ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค. ๊ทธ๋Ÿฐ ๋‹ค์Œ nwhc๋กœ ๋ณ€๊ฒฝํ•ฉ๋‹ˆ๋‹ค. nhwc๋กœ ๋” ์ˆœ์—ดํ•œ ๋‹ค์Œ contiguous()๋ฅผ ์ˆ˜ํ–‰ํ•˜๊ณ  ์‹ถ์Šต๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ ๋‚˜๋Š” ๊ทธ๊ฒƒ์„ ์ด๋ฏธ ์—ฐ์†์ ์œผ๋กœ ์–ป์—ˆ์Šต๋‹ˆ๋‹ค. ํ˜ผ๋ž€์Šค๋Ÿฝ์ง€ ์•Š์Šต๋‹ˆ๊นŒ?

์ผ๋ถ€ ์šฉ์–ด๋ฅผ ๊ตฌ์–ด์ฒด๋กœ ์‚ฌ์šฉํ•˜๊ณ  ์ •ํ™•์„ฑ์ด ํ•„์š”ํ•˜๊ธฐ ๋•Œ๋ฌธ์— ๊ท€ํ•˜์˜ ์˜ˆ๋ฅผ ์ดํ•ดํ•˜๊ธฐ ์–ด๋ ต์Šต๋‹ˆ๋‹ค. ๋ง์”€ํ•˜์‹  ๋‚ด์šฉ์„ ์ œ๊ฐ€ ํ•ด์„ํ•˜๋Š” ๋ฐฉ๋ฒ•์€ ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค.

  • ์ด ์ œ์•ˆ์— ๋”ฐ๋ผ "nhwc" ํ…์„œ๋Š” "๋ฌผ๋ฆฌ์  ๋ ˆ์ด์•„์›ƒ์ด NHWC์ด์ง€๋งŒ ๋…ผ๋ฆฌ์  ๋ ˆ์ด์•„์›ƒ์ด NCHW๊ฐ€ ๋˜๋„๋ก ์ŠคํŠธ๋ผ์ด๋“œ๋œ ํ…์„œ"์ž…๋‹ˆ๋‹ค.
  • "(๋…ผ๋ฆฌ์  ๋ ˆ์ด์•„์›ƒ์ด NCHW์ธ ํ…์„œ) ํ…์„œ๋ฅผ (๋…ผ๋ฆฌ์  ๋ ˆ์ด์•„์›ƒ) NWHC๋กœ ์น˜ํ™˜"ํ•˜๋Š” ๊ฒƒ์€ y = x.permute(0, 2, 3, 1) ๋ฅผ ์‹คํ–‰ํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ์™œ๋ƒํ•˜๋ฉด ๋ฌผ๋ฆฌ์  ๋ ˆ์ด์•„์›ƒ์ด ์•„๋‹ˆ๋ผ ๋…ผ๋ฆฌ์  ๋ ˆ์ด์•„์›ƒ์„ ์น˜ํ™˜ํ•˜๊ณ  ์žˆ๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค. (์›๋ž˜ ๊ฒŒ์‹œ๋ฌผ์—์„œ ์ˆœ์—ด x.permute(0, 3, 1, 2) ์„ ์–ธ๊ธ‰ํ–ˆ๊ธฐ ๋•Œ๋ฌธ์— ์ด๊ฒƒ์ด ์˜๋ฏธํ•˜๋Š” ๋ฐ”๊ฐ€ ์•„๋‹ˆ๋ผ๊ณ  ์ƒ๊ฐํ•ฉ๋‹ˆ๋‹ค.
  • ๊ทธ๋Ÿฐ ๋‹ค์Œ (๋…ผ๋ฆฌ์  ๋ ˆ์ด์•„์›ƒ) NWHC ํ…์„œ๋ฅผ (๋…ผ๋ฆฌ์  ๋ ˆ์ด์•„์›ƒ) NHWC๋กœ ์ถ”๊ฐ€ ์ˆœ์—ด์€ ์ˆœ์—ด z = y.permute(0, 2, 3, 1) ์„ ์ ์šฉํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ์ด์ œ ๋…ผ๋ฆฌ์  ๋ ˆ์ด์•„์›ƒ์ด ๋ฌผ๋ฆฌ์  ๋ ˆ์ด์•„์›ƒ๊ณผ ์ผ์น˜ํ•˜๋Š” ํ…์„œ๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๊ฒƒ์€ ์šฐ๋ฆฌ๊ฐ€ z.contiguous() ๋ฌป๋Š”๋‹ค๋ฉด ์šฐ๋ฆฌ๋Š” ์ฐธ์ด ๋  ๊ฒƒ์ด๋ผ๋Š” ๊ฒƒ์„ ์˜๋ฏธํ•ฉ๋‹ˆ๋‹ค(๊ทธ๋ฆฌ๊ณ  ํ˜ผ๋ž€์Šค๋Ÿฝ๊ฒŒ๋„ z.contiguous(memory_layout=NCHW) ๋„ ์ฐธ์ด ๋  ๊ฒƒ์ž…๋‹ˆ๋‹ค.) ๊ทธ๋Ÿฌ๋‚˜ NHWC ์—ฐ์†์ ์ด์ง€๋Š” ์•Š์„ ๊ฒƒ์ž…๋‹ˆ๋‹ค.

๋‚˜๋Š” ์ด๊ฒƒ์ด ๋‹น์‹ ์ด ์—ผ๋‘์— ๋‘” ์˜ˆ๋ผ๊ณ  ์ƒ๊ฐํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ์ด ๊ฒฝ์šฐ "์ˆœ์—ด"์ด ์˜๋ฏธํ•˜๋Š” ๋ฐ”์— ๋Œ€ํ•ด ๋” ์ •ํ™•ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

IMHO, nhwc(๋ฌผ๋ฆฌ์ )์— ๋ณดํญ์ด ์žˆ๋Š” ๊ฒฝ์šฐ ๋ณดํญ์ด ์ค‘๋ณต๋˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ํฌ๊ธฐ(๋…ผ๋ฆฌ)๊ฐ€ ์žˆ๋Š” ์˜ฌ๋ฐ”๋ฅธ ๋งคํ•‘์ด ํ•„์š”ํ•˜๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค. ๊ทธ๋ ‡์ง€ ์•Š์œผ๋ฉด ์‹ค์ œ ์ˆœ์„œ๋ฅผ ๋งํ•  ๋ฐฉ๋ฒ•์ด ์—†์Šต๋‹ˆ๋‹ค.

์ด ์ œ์•ˆ์˜ ํ•ต์‹ฌ์ž…๋‹ˆ๋‹ค : ๋…ผ๋ฆฌ์  ์ธ ๋ ˆ์ด์•„์›ƒ์œผ๋กœ ์šฐ๋ฆฌ ํŠน๊ถŒ NCHW, ํ•ญ์ƒ. ๋”ฐ๋ผ์„œ ๋‚ด๊ฐ€ ๋ชจ๋ฅด๋Š” 4D ํ…์„œ๊ฐ€ ์žˆ๋Š” ๊ฒฝ์šฐ ๋…ผ๋ฆฌ์  ๋ ˆ์ด์•„์›ƒ์ด NCHW๋ผ๊ณ  ๊ฐ€์ • ํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๊ฒƒ์€ ๋ชจํ˜ธ์„ฑ์„ ์ œ๊ฑฐํ•ฉ๋‹ˆ๋‹ค. ๋…ผ๋ฆฌ์  ๋ ˆ์ด์•„์›ƒ์ด NCHW๊ฐ€ ์•„๋‹Œ ํ…์„œ๋ฅผ ์ฒ˜๋ฆฌํ•˜๊ณ  ์‹ถ๋‹ค๋ฉด ๋ช…์‹œ๋œ API๊ฐ€ ์‚ถ์„ ์กฐ๊ธˆ ์–ด๋ ต๊ฒŒ ๋งŒ๋“ ๋‹ค๊ณ  ์ƒ๊ฐํ•ฉ๋‹ˆ๋‹ค.

@dzhulgakov

์ž‘์—…์€ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ๋™์ž‘์„ ์œ ์ง€ํ•ฉ๋‹ˆ๋‹ค.

๋ฌผ๋ฆฌ์  NHWC ํ…์„œ๊ฐ€ ์ˆœ์ „ํžˆ ์ŠคํŠธ๋ผ์ด๋“œ๋ฅผ ํ†ตํ•ด ๋ฐœ์ƒํ•  ์ˆ˜ ์žˆ๋Š” ๊ฒฝ์šฐ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ํƒœ๊ทธ๊ฐ€ ์žˆ์„ ๋•Œ๋งŒ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์„ ๋ณด์กดํ•˜๋„๋ก ํ•˜์ง€ ์•Š๋Š” ํ•œ ์ด๊ฒƒ์€ ๊ธฐ์ˆ ์ ์œผ๋กœ BC ๊นจ๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ์ œ์•ˆ์ด ํ˜„์žฌ ๋ฌด์—‡์„ ์ œ์•ˆํ•˜๊ณ  ์žˆ๋Š”์ง€ ์ž˜ ๋ชจ๋ฅด๊ฒ ์Šต๋‹ˆ๋‹ค.) ์ด๊ฒƒ์ด ์‹ค์ œ๋กœ ์‹ค์ œ๋กœ ๋ˆ„๊ตฐ๊ฐ€์˜ ์ฝ”๋“œ๋ฅผ ์†์ƒ์‹œํ‚ค๋Š”์ง€ ํ™•์‹คํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.

๋ฌผ๋ฆฌ์  NHWC ํ…์„œ๊ฐ€ ์ˆœ์ „ํžˆ ์ŠคํŠธ๋ผ์ด๋“œ๋ฅผ ํ†ตํ•ด ๋ฐœ์ƒํ•  ์ˆ˜ ์žˆ๋Š” ๊ฒฝ์šฐ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ํƒœ๊ทธ๊ฐ€ ์žˆ์„ ๋•Œ๋งŒ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์„ ๋ณด์กดํ•˜๋„๋ก ํ•˜์ง€ ์•Š๋Š” ํ•œ ์ด๊ฒƒ์€ ๊ธฐ์ˆ ์ ์œผ๋กœ BC ๊นจ๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ์ œ์•ˆ์ด ํ˜„์žฌ ๋ฌด์—‡์„ ์ œ์•ˆํ•˜๊ณ  ์žˆ๋Š”์ง€ ์ž˜ ๋ชจ๋ฅด๊ฒ ์Šต๋‹ˆ๋‹ค.) ์ด๊ฒƒ์ด ์‹ค์ œ๋กœ ์‹ค์ œ๋กœ ๋ˆ„๊ตฐ๊ฐ€์˜ ์ฝ”๋“œ๋ฅผ ์†์ƒ์‹œํ‚ค๋Š”์ง€ ํ™•์‹คํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.

๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์„ '๊ณ ์ •'์œผ๋กœ ๋งŒ๋“ค ์ˆ˜ ์žˆ๋‹ค๊ณ  ๊ฐ€์ •ํ•ฉ๋‹ˆ๋‹ค. ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ํ…์„œ์— ๋Œ€ํ•œ ์—ฐ์‚ฐ์€ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ํ…์„œ๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋ฉด BC ๋ฌธ์ œ๊ฐ€ ํ•ด๊ฒฐ๋ฉ๋‹ˆ๋‹ค.

๊ทธ๋Ÿฌ๋‚˜ ํ…์„œ์˜ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์ด ๋‹ค๋ฅธ ๊ฒฝ์šฐ ์ด์ง„(๋˜๋Š” ๋” ๋งŽ์€ ๋ฉค๋ฒ„) ์ž‘์—…์˜ ๋™์ž‘์„ ์ •์˜ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

@ezyang ์˜ค ๋ฐฉ๊ธˆ ์œ„์˜ ๋‹ต๋ณ€์— ์˜คํƒ€๊ฐ€ ์žˆ์Œ์„ ๋ฐœ๊ฒฌํ–ˆ์Šต๋‹ˆ๋‹ค. (์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ ์›๋ž˜์˜ ์˜ˆ๋Š” ์—ฌ์ „ํžˆ ์ •ํ™•ํ•ฉ๋‹ˆ๋‹ค.) ์•„๋ž˜์™€ ๊ฐ™์ด ๋‹ค์‹œ ์„ค๋ช…ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค.

  1. NCHW ํ…์„œ๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค(๋ฌผ๋ฆฌ์ ์œผ๋กœ, ์—ฐ์†์ ).
  2. ๊ทธ๋Ÿฐ ๋‹ค์Œ NWHC(๋…ผ๋ฆฌ์ ์œผ๋กœ)๋กœ ์น˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.
  3. contiguous() ํ˜ธ์ถœ์ด ๋’ค๋”ฐ๋ฅด๋Š” NHWC๋กœ ๋” ์น˜ํ™˜ํ•˜๊ณ  ์‹ถ์Šต๋‹ˆ๋‹ค.
  4. NHWC(๋ฌผ๋ฆฌ์ )๋กœ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.

ํ•˜์ง€๋งŒ 2๋‹จ๊ณ„ ์ดํ›„์— ์ด๋ฏธ NHWC ์—ฐ์†์„ฑ์„ ์–ป์—ˆ์Šต๋‹ˆ๋‹ค. ๊ทธ๋Ÿฐ ๋‹ค์Œ 3๋‹จ๊ณ„๋ฅผ ๊ฑด๋„ˆ๋›ฐ๊ณ  4๋‹จ๊ณ„์—์„œ ์ง์ ‘ NHWC๋กœ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ํ•˜์ง€๋งŒ ์ด๊ฒƒ์€ ํ…์„œ์˜ ๋ฌผ๋ฆฌ์  ์ˆœ์„œ๊ฐ€ ์ „ํ˜€ ๋ณ€๊ฒฝ๋˜์ง€ ์•Š๊ธฐ ๋•Œ๋ฌธ์— ํ™•์‹คํžˆ ์˜ณ์ง€ ์•Š์Šต๋‹ˆ๋‹ค.

์ฐจ๋‹จ๋œ ํ˜•์‹์—๋Š” ์™„์ „ํžˆ ๋‹ค๋ฅธ ์—ฐ์‚ฐ์ž ๊ตฌํ˜„ ์ง‘ํ•ฉ์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ ์ด์œ ๋กœ ์šฐ๋ฆฌ๋Š” ์ •์˜์ƒ ๋ชจ๋“  ๊ธฐ์กด ์—ฐ์‚ฐ์ž์™€ ์นœ์ˆ™ํ•˜๊ณ  ๋™์ผํ•˜๊ฑฐ๋‚˜ ๋” ๋‚˜์€ ์„ฑ๋Šฅ์œผ๋กœ ์ž‘๋™ํ•˜๋Š” '๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹'๊ณผ ํ˜ผํ•ฉํ•˜์ง€ ์•Š๋Š” ๊ฒƒ์„ ์„ ํ˜ธํ•ฉ๋‹ˆ๋‹ค.

์˜ˆ, ์ฒซ ๋ฒˆ์งธ ๋‹จ๊ณ„๋กœ NHWC๋ฅผ ํ™œ์„ฑํ™”ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ ์‹ค์ œ๋กœ ์ฐจ๋‹จ๋œ ํ˜•์‹์ด ์™„์ „ํžˆ ๋‹ค๋ฅธ ๊ฒƒ์ด๋ผ๊ณ  ์ƒ๊ฐํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ๊ทธ๊ฒƒ์€ ์ž์—ฐ์Šค๋Ÿฝ๊ฒŒ ํ‘œํ˜„๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค(์•ฝ๊ฐ„์˜ ์ข‹์€ ์ถ”์ƒํ™”๋กœ). ์ผ๋ฐ˜์ ์ธ ํ˜•์‹ ์„ค๋ช…์ด ์žˆ๋Š” ๊ฒฝ์šฐ ๋‹ค๋ฅธ ์‚ฌ๋žŒ๋“ค์€ ์ž„์˜์˜ ์ฐจ๋‹จ/๋ณดํญ์œผ๋กœ ์ƒˆ ํ˜•์‹์„ ๋“ฑ๋กํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

๋”๊ตฐ๋‹ค๋‚˜ ์ด๋ฏธ ์ง€์›์„ ์ฐจ๋‹จํ–ˆ๋‹ค๋ฉด ๊ธฐ๋ณธ์ด ๋˜๋Š” ๋ชจ๋“  ๊ฒƒ์„ ์‹คํ–‰ํ•˜๊ธฐ ์œ„ํ•ด ์ˆจ๊ฒจ์ง„ ๊ตฌ์„ฑ์„ ๋งŒ๋“œ๋Š” ๋ฐ ์‹ ๊ฒฝ ์“ฐ์ง€ ์•Š์•„๋„ ๋ฉ๋‹ˆ๋‹ค. ๋‚ด๋ถ€์— ์•”์‹œ์  ์„ธ๊ณ„๊ฐ€ ์ƒ์„ฑ๋˜๊ณ  ๋‘ ์„ธ๊ณ„ ์‚ฌ์ด์˜ ์‹œ์ž‘/๋์ด ๋ฌธ์ œ๊ฐ€ ๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

์–ด์จŒ๋“  ์ฐจ๋‹จ๋œ ํ˜•์‹์— ๋Œ€ํ•ด ์ƒ๊ฐํ•˜๊ธฐ์—๋Š” ๋„ˆ๋ฌด ๋ฉ€๋ฆฌ ๋–จ์–ด์ ธ ์žˆ์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ํ•˜์ง€๋งŒ ๊ฐ€๋Šฅํ•˜๋ฉด ๋””์ž์ธ์„ ํ™•์žฅ ๊ฐ€๋Šฅํ•˜๊ฒŒ ๋งŒ๋“œ๋Š” ๊ฒƒ์ด ๋” ๋‚ซ๋‹ค๊ณ  ์ƒ๊ฐํ•ฉ๋‹ˆ๋‹ค.

ํ•˜์ง€๋งŒ 2๋‹จ๊ณ„ ์ดํ›„์— ์ด๋ฏธ NHWC ์—ฐ์†์„ฑ์„ ์–ป์—ˆ์Šต๋‹ˆ๋‹ค. ๊ทธ๋Ÿฐ ๋‹ค์Œ 3๋‹จ๊ณ„๋ฅผ ๊ฑด๋„ˆ๋›ฐ๊ณ  4๋‹จ๊ณ„์—์„œ ์ง์ ‘ NHWC๋กœ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ํ•˜์ง€๋งŒ ์ด๊ฒƒ์€ ํ…์„œ์˜ ๋ฌผ๋ฆฌ์  ์ˆœ์„œ๊ฐ€ ์ „ํ˜€ ๋ณ€๊ฒฝ๋˜์ง€ ์•Š๊ธฐ ๋•Œ๋ฌธ์— ํ™•์‹คํžˆ ์˜ณ์ง€ ์•Š์Šต๋‹ˆ๋‹ค.

์•Œ๊ฒ ์Šต๋‹ˆ๋‹ค. ์ด์ œ ๊ท€ํ•˜์˜ ์˜ˆ๋ฅผ ์ดํ•ดํ•ฉ๋‹ˆ๋‹ค. ์‹ค์ œ๋กœ 2๋‹จ๊ณ„์—์„œ ๋ฉˆ์ถ”๊ณ  NCHW ํ…์„œ์ธ ๊ฒƒ์ฒ˜๋Ÿผ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด ๊ฒฝ์šฐ W๋ฅผ C ๋“ฑ์œผ๋กœ ๋ถ€์ ์ ˆํ•˜๊ฒŒ ํ•ด์„ํ•˜๊ฒŒ ๋ฉ๋‹ˆ๋‹ค. ์ด๊ฒƒ์€ ํ™•์‹คํžˆ ๋ณดํญ ๊ธฐ๋ฐ˜ ๊ตฌํ˜„์˜ ๋‹จ์ ์ž…๋‹ˆ๋‹ค( @dzhulgakov , ์•„๋งˆ๋„ ์ด๊ฒƒ์„ ์ œ์•ˆ์— ์ถ”๊ฐ€ํ•ด์•ผ ํ•  ๊ฒƒ์ž…๋‹ˆ๋‹ค). ์ œ์•ˆ์„œ์—๋Š” ์ด ๊ฒฝ์šฐ์— ๋Œ€ํ•œ ๋ช‡ ๊ฐ€์ง€ ์กฐํ•ญ์ด ์žˆ์Šต๋‹ˆ๋‹ค.

์œ„์˜ ๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•ด ์ดˆ๊ธฐ ์ œ์•ˆ์€ ํ…์„œ์—์„œ ์ˆ˜ํ–‰๋œ ๋งˆ์ง€๋ง‰ to(memory_format) ํ˜ธ์ถœ์„ ๊ธฐ๋กํ•˜๋Š” "์†Œํ”„ํŠธ" ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ํƒœ๊ทธ๋ฅผ ํ…์„œ์— ๋„์ž…ํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ์šด์˜์ž๋Š” ์ด ์ฃผ์„์„ ์ถœ๋ ฅ์— ์ „ํŒŒํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ์ฃผ์„์€ "์†Œํ”„ํŠธ"์ด๋ฏ€๋กœ ๋ถˆ์ผ์น˜ ์ฃผ์„์— ๋Œ€ํ•œ ํ•˜๋“œ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ•˜์ง€ ์•Š๊ณ  ํ”„๋กœํŒŒ์ผ๋ง ๋ชจ๋“œ์—์„œ ๊ฒฝ๊ณ ๊ฐ€ ์ƒ์„ฑ๋ฉ๋‹ˆ๋‹ค.

์†Œํ”„ํŠธ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ํƒœ๊ทธ๋ฅผ ์‚ฌ์šฉํ•˜๋ฉด ์ˆœ์—ดํ•œ NCHW ํ…์„œ์™€ ์‹ค์ œ๋กœ ๋ฌผ๋ฆฌ์ ์œผ๋กœ NHWC์ธ ํ…์„œ๋ฅผ ๊ตฌ๋ณ„ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ ํ˜„์žฌ ํ˜•์‹์˜ ์†Œํ”„ํŠธ ํƒœ๊ทธ๋Š” ๊ตฌ์†๋ ฅ์ด ์—†์œผ๋ฏ€๋กœ ์‹ค์ œ๋กœ ์ด ๊ฒฝ์šฐ์— ์–ผ๋งˆ๋‚˜ ์œ ์šฉํ•œ์ง€ ์ž˜ ๋ชจ๋ฅด๊ฒ ์Šต๋‹ˆ๋‹ค.

๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐํ•˜๋Š” ๋˜ ๋‹ค๋ฅธ ๋ฐฉ๋ฒ•์€ ๋ช…๋ช…๋œ ํ…์„œ๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ๋ช…๋ช…๋œ ํ…์„œ๋ฅผ ์‚ฌ์šฉํ•˜๋ฉด (๋…ผ๋ฆฌ์ ) ์ฐจ์›์˜ ์ด๋ฆ„์„ ์‚ฌ์šฉํ•˜์—ฌ ํ…์„œ๋ฅผ NCHW(๊ฐ€์ • ๊ธฐ๋ณธ๊ฐ’)๋กœ ๋ณด๊ณ  ์žˆ๋Š”์ง€ ์•„๋‹ˆ๋ฉด ๋‹ค๋ฅธ ๊ฒƒ์œผ๋กœ ๋ณด๊ณ  ์žˆ๋Š”์ง€ ํŒŒ์•…ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

๊ทธ๋Ÿฌ๋‚˜ ์‹ค์ œ๋กœ ์ฐจ๋‹จ๋œ ํ˜•์‹์ด ์™„์ „ํžˆ ๋‹ค๋ฅธ ๊ฒƒ์ด๋ผ๊ณ  ์ƒ๊ฐํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ๊ทธ๊ฒƒ์€ ์ž์—ฐ์Šค๋Ÿฝ๊ฒŒ ํ‘œํ˜„๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค(์•ฝ๊ฐ„์˜ ์ข‹์€ ์ถ”์ƒํ™”๋กœ). ์ผ๋ฐ˜์ ์ธ ํ˜•์‹ ์„ค๋ช…์ด ์žˆ๋Š” ๊ฒฝ์šฐ ๋‹ค๋ฅธ ์‚ฌ๋žŒ๋“ค์€ ์ž„์˜์˜ ์ฐจ๋‹จ/๋ณดํญ์œผ๋กœ ์ƒˆ ํ˜•์‹์„ ๋“ฑ๋กํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

์—ฌ๊ธฐ์— ์ฃผ์ œ์— ๋Œ€ํ•œ ๋” ๋งŽ์€ ์„ค๋ช…์ด ์žˆ์Šต๋‹ˆ๋‹ค: https://github.com/pytorch/pytorch/issues/16038#issuecomment -454490374

@ezyang ๋‹ต๋ณ€ ๊ฐ์‚ฌํ•ฉ๋‹ˆ๋‹ค. ์˜ˆ, ์†Œํ”„ํŠธ ํ˜•์‹ ํƒœ๊ทธ๊ฐ€ ๋„์›€์ด ๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๋ฌธ์ œ๋Š” ์ฐจ์› ์ˆœ์„œ๊ฐ€ ์ž„์˜์ ์ผ ์ˆ˜ ์žˆ์œผ๋ฏ€๋กœ ์ถฉ๋ถ„ํžˆ ์œ ์—ฐํ•˜์ง€ ์•Š์„ ์ˆ˜ ์žˆ๋‹ค๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ๋˜ํ•œ ์ž์ฒด์ ์œผ๋กœ ๊ณ„์‚ฐํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. ๋ช…๋ช…๋œ ํ…์„œ๋Š” ๊ฐ ์ฐจ์›์— ๋Œ€ํ•œ ์˜๋ฏธ๋ก ์  ์˜๋ฏธ๋ฅผ ๊ฐ–์ง€๋งŒ ์ง€์›ํ•˜๊ธฐ ์œ„ํ•ด ๋” ๋งŽ์€ ๊ธฐ๋Šฅ์ด ํ•„์š”ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

๊ฐœ์ธ์ ์œผ๋กœ ๋‚˜๋Š” ์ด๊ฒƒ์ด ๋ณดํญ ์ˆœ์„œ(๋ฌผ๋ฆฌ์ )์—์„œ NCHW ํฌ๊ธฐ ์ˆœ์„œ(๋…ผ๋ฆฌ์ )๋กœ์˜ ๋งต์„ ๋„์ž…ํ•˜์—ฌ ํ•ด๊ฒฐํ•  ์ˆ˜ ์žˆ๋‹ค๊ณ  ์ƒ๊ฐํ•ฉ๋‹ˆ๋‹ค. ์œ„์—์„œ ์ œ์•ˆํ•œ ๊ฒƒ์ฒ˜๋Ÿผ NCHW์˜ ๊ฒฝ์šฐ ํ˜„์žฌ ๋””์ž์ธ๊ณผ ๊ฑฐ์˜ ๋™์ผํ•ฉ๋‹ˆ๋‹ค. NHWC์˜ ๊ฒฝ์šฐ sizes ๋Š” ์—ฌ์ „ํžˆ NCHW์ด๊ณ  strides ๋Š” (N, H, W, C) ์ˆœ์„œ์ž…๋‹ˆ๋‹ค. ๊ทธ๋ฆฌ๊ณ  stride_index = (0, 2, 3, 1)์„ ์‚ฌ์šฉํ•˜์—ฌ ๋ณดํญ์˜ ์ฐจ์› ์ธ๋ฑ์Šค๋ฅผ ์ง€์ •ํ•ฉ๋‹ˆ๋‹ค.

๋˜ํ•œ strides ๋ฐ stride_index ์˜ ์กฐํ•ฉ์„ ์‚ฌ์šฉํ•˜์—ฌ ๋ชจ๋“  ํ…์„œ ํ˜•์‹์„ ๋‚˜ํƒ€๋‚ผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๊ฒƒ์€ ๋‹ค๋ฅธ ์‚ฌ๋žŒ๋“ค์ด ์ƒˆ๋กœ์šด ๋ฐ์ดํ„ฐ ํ˜•์‹์„ ๋“ฑ๋กํ•  ์ˆ˜ ์žˆ๋Š” ์œ ์—ฐ์„ฑ์„ ์ œ๊ณตํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

@ezyang

์ž‘์—…์€ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ๋™์ž‘์„ ์œ ์ง€ํ•ฉ๋‹ˆ๋‹ค.

๋ฌผ๋ฆฌ์  NHWC ํ…์„œ๊ฐ€ ์ˆœ์ „ํžˆ ์ŠคํŠธ๋ผ์ด๋“œ๋ฅผ ํ†ตํ•ด ๋ฐœ์ƒํ•  ์ˆ˜ ์žˆ๋Š” ๊ฒฝ์šฐ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ํƒœ๊ทธ๊ฐ€ ์žˆ์„ ๋•Œ๋งŒ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์„ ๋ณด์กดํ•˜๋„๋ก ํ•˜์ง€ ์•Š๋Š” ํ•œ ์ด๊ฒƒ์€ ๊ธฐ์ˆ ์ ์œผ๋กœ BC ๊นจ๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ์ œ์•ˆ์ด ํ˜„์žฌ ๋ฌด์—‡์„ ์ œ์•ˆํ•˜๊ณ  ์žˆ๋Š”์ง€ ์ž˜ ๋ชจ๋ฅด๊ฒ ์Šต๋‹ˆ๋‹ค.) ์ด๊ฒƒ์ด ์‹ค์ œ๋กœ ์‹ค์ œ๋กœ ๋ˆ„๊ตฐ๊ฐ€์˜ ์ฝ”๋“œ๋ฅผ ์†์ƒ์‹œํ‚ค๋Š”์ง€ ํ™•์‹คํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.

์‚ฐ์ˆ  ์—ฐ์‚ฐ๊ณผ ์ž„๊ณ„๊ฐ’์ด TensorIterator๋กœ ์ด๋™ํ–ˆ์„ ๋•Œ ์ด๋Š” ๊ธฐ์ˆ ์ ์œผ๋กœ BC ํŒŒ๊ดด์˜€์Šต๋‹ˆ๋‹ค(ํ”ผ์—ฐ์‚ฐ์ž์˜ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์ด ๋ณด์กด๋˜์ง€ ์•Š๊ณ  TensorIterator๊ฐ€ ์ด๋ฅผ ๋ณด์กดํ•˜๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค). ํ˜„์žฌ ์ƒํƒœ๋Š” ๋งค์šฐ ์ผ๊ด€์„ฑ์ด ์—†์Šต๋‹ˆ๋‹ค. ์ž„๊ณ„๊ฐ’์€ ๋ ˆ์ด์•„์›ƒ์„ ์œ ์ง€ํ•˜์ง€๋งŒ, ๋‹ค๋ฅธ ๋ชจ๋“  ๋‹จํ•ญ ์—ฐ์‚ฐ์€ ๊ทธ๋ ‡์ง€ ์•Š์œผ๋ฉฐ, torch.where๋Š” ๊ทธ๋ ‡์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ๋‘ ํ”ผ์—ฐ์‚ฐ์ž์˜ ๋ ˆ์ด์•„์›ƒ์ด ๋™์ผํ•œ ๊ฒฝ์šฐ ์‚ฐ์ˆ  ์—ฐ์‚ฐ์€ ๋ ˆ์ด์•„์›ƒ์„ ์œ ์ง€ํ•˜์ง€๋งŒ ๊ธฐ๋ณธ๊ฐ’์€ "nchw" ๋˜๋Š” contiguous ํ…์„œ์ž…๋‹ˆ๋‹ค contiguous ๋ฏธ์Šค๋งค์นญ์ด ์žˆ์„ ๊ฒฝ์šฐ ๋ฐฉ์†ก์€ ์–ด๋–ป๊ฒŒ ๋˜๋Š”์ง€ ์ž˜ ๋ชจ๋ฅด๊ฒ ์Šต๋‹ˆ๋‹ค.
๋˜ํ•œ BC๊ฐ€ ์•„๋‹Œ ๋ ˆ์ด์•„์›ƒ์„ ์œ ์ง€ํ•˜๋Š” empty_like ๋Œ€ํ•ด ์ข‹์€ ์ง€์ ์„ ํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. ์•„๋งˆ๋„ ์ œ์•ˆ์„œ์˜ is_contiguous์™€ ๊ฐ™์€ ๋ ˆ์ด์•„์›ƒ ์ธ์ˆ˜๊ฐ€ ํ•„์š”ํ•  ๊ฒƒ์ž…๋‹ˆ๋‹ค.

x.is_contiguous(torch.memory_format.channels_first)

@ezyang @ngimel

empty_like์—๋Š” ํ•œ ๊ฐ€์ง€ ๋ฌธ์ œ๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค. ํ˜„์žฌ ์ •์˜๋œ ์˜๋ฏธ๋Š” ๋ชจ๋“  ๋ณดํญ ์ •๋ณด๋ฅผ ์‚ญ์ œํ•œ๋‹ค๋Š” ๊ฒƒ์ด๋ฏ€๋กœ ๋ ˆ์ด์•„์›ƒ์„ ์œ ์ง€ํ•˜๊ณ  BC๊ฐ€ ๋  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.

๋˜ํ•œ BC๊ฐ€ ์•„๋‹Œ ๋ ˆ์ด์•„์›ƒ์„ ์œ ์ง€ํ•˜๋Š” empty_like ๋“ฑ์— ๋Œ€ํ•ด ์ข‹์€ ์ง€์ ์„ ํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค.

์šฐ๋ฆฌ๊ฐ€ ๋ฌผ๋ฆฌ์  ์งˆ์„œ๋ฅผ ํ‘œํ˜„ํ•˜๊ธฐ ์œ„ํ•ด ๋ณดํญ์— ์˜์กดํ•˜์ง€ ์•Š๋Š”๋‹ค๋ฉด, empty_like ๋Š” BC๋ฅผ ๊นจ๋œจ๋ฆด ํ•„์š”๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค. ํ…์„œ์—๋Š” 3๊ฐ€์ง€ ์ฐจ์› ์ •๋ณด๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค.

  • ๋ชจ์–‘: ํฌ๊ธฐ
  • ๋…ผ๋ฆฌ ์ˆœ์„œ: ๋ณดํญ์œผ๋กœ ๊ธฐ๋ก๋œ ์ˆœ์„œ ์ •๋ณด(์ผ๋ฐ˜์ ์œผ๋กœ ์ „์น˜ ๋˜๋Š” ์ˆœ์—ด์„ ์ง€์›ํ•˜๋Š” ๋ฐ ์‚ฌ์šฉ๋จ)
  • ๋ฌผ๋ฆฌ์  ์ˆœ์„œ: NCHW ๋˜๋Š” NHWC(์ œ๊ฐ€ ์ œ์•ˆํ•œ ๋Œ€๋กœ stride_index๋กœ ์ฒ˜๋ฆฌํ•  ์ˆ˜ ์žˆ์Œ).

ํ˜„์žฌ ๋ฌผ๋ฆฌ์  ์ˆœ์„œ๋Š” ๋ชจ์–‘/ํฌ๊ธฐ์™€ ๋™์ผํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋ž˜์„œ ์šฐ๋ฆฌ๋Š” ๋…ผ๋ฆฌ ์ˆœ์„œ๋ฅผ ๋ณดํญ์œผ๋กœ ๋–จ์–ด๋œจ๋ฆฝ๋‹ˆ๋‹ค. ๋ชจ์–‘๊ณผ ๋ฌผ๋ฆฌ์  ์ˆœ์„œ๋ฅผ ๋ถ„๋ฆฌํ•œ๋‹ค๊ณ  ์ƒ๊ฐํ•˜๋ฉด ๋…ผ๋ฆฌ ์ˆœ์„œ๋ฅผ ์‚ญ์ œํ•  ์ˆ˜๋„ ์žˆ์ง€๋งŒ empty_like ๋Œ€ํ•œ ๋ชจ์–‘๊ณผ ๋ฌผ๋ฆฌ์  ์ˆœ์„œ๋Š” ๋ณด์กดํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ฆ‰, size() ๋ฐ stride_index() ๋Š” ๋ชจ๋‘ ๋ณด์กด๋˜์ง€๋งŒ stride() ๋Š” ์žฌ์„ค์ •๋ฉ๋‹ˆ๋‹ค. ํŠนํžˆ NHWC ํ…์„œ์˜ empty_like ๋Š” ๋™์ผํ•œ ๋ชจ์–‘ ์ •๋ณด๊ฐ€ ์ง€์ •๋œ NHWC ์—ฐ์† ํ…์„œ๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.

@uyongw empty_like ๋ณ€๊ฒฝํ•˜๋Š” ๊ฒƒ์ด ์ข‹์€ ์ƒ๊ฐ์ธ์ง€ ์ž˜ ๋ชจ๋ฅด๊ฒ ์Šต๋‹ˆ๋‹ค. ์ง€๊ธˆ ๊ทธ ์˜๋ฏธ๋Š” numpy์˜ empty_like ์™€ ์ผ์น˜ํ•ฉ๋‹ˆ๋‹ค.

ํ˜„์žฌ ์ƒํƒœ๋Š” ๋งค์šฐ ์ผ๊ด€์„ฑ์ด ์—†์Šต๋‹ˆ๋‹ค. ์ž„๊ณ„๊ฐ’์€ ๋ ˆ์ด์•„์›ƒ์„ ์œ ์ง€ํ•˜์ง€๋งŒ, ๋‹ค๋ฅธ ๋ชจ๋“  ๋‹จํ•ญ ์—ฐ์‚ฐ์€ ๊ทธ๋ ‡์ง€ ์•Š์œผ๋ฉฐ, torch.where๋Š” ๊ทธ๋ ‡์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ๋‘ ํ”ผ์—ฐ์‚ฐ์ž๊ฐ€ ๋™์ผํ•œ ๋ ˆ์ด์•„์›ƒ์„ ๊ฐ–๊ณ  ์žˆ๋Š” ๊ฒฝ์šฐ ์‚ฐ์ˆ  ์—ฐ์‚ฐ์€ ๋ ˆ์ด์•„์›ƒ์„ ์œ ์ง€ํ•˜์ง€๋งŒ ๊ธฐ๋ณธ๊ฐ’์€ "nchw" ๋˜๋Š” ์ธ์ ‘ํ•˜๋Š” ํ…์„œ์ž…๋‹ˆ๋‹ค. ๋ถˆ์ผ์น˜๊ฐ€ ์žˆ๋Š” ๊ฒฝ์šฐ ํ˜„์žฌ ์ดํ•ดํ•˜๊ณ  ์žˆ์ง€๋งŒ ๋ฐฉ์†ก์€ ์–ด๋–ป๊ฒŒ ๋˜๋Š”์ง€ ์ž˜ ๋ชจ๋ฅด๊ฒ ์Šต๋‹ˆ๋‹ค.

@ngimel , ์˜ˆ, ์ง€๊ธˆ์€ ์ผ๊ด€์„ฑ์ด ์—†์Šต๋‹ˆ๋‹ค. ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์„ ํ‘œํ˜„ํ•˜๋Š” ๋ฐฉ๋ฒ•์˜ ์ผ๋ถ€๋Š” ์—ฐ์‚ฐ์ž๋ฅผ ์ผ๊ด€๋œ ์ƒํƒœ๋กœ ๋งŒ๋“œ๋Š” ๊ฒƒ์ด๋ผ๊ณ  ์ƒ๊ฐํ•ฉ๋‹ˆ๋‹ค.

@zou3519 ๋งํฌํ•œ numpy์˜ empty_like์—๋Š” ๊ธฐ๋ณธ์ ์œผ๋กœ "ํ”„๋กœํ† ํƒ€์ž…์˜ ๋ ˆ์ด์•„์›ƒ๊ณผ ์ตœ๋Œ€ํ•œ ๊ฐ€๊น๊ฒŒ ์ผ์น˜"ํ•˜๋Š” order ์ธ์ˆ˜๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค. ๊ทธ๊ฒƒ์€ pytorch์˜ empty_like ๊ฐ€ ํ˜„์žฌ ์ˆ˜ํ–‰ํ•˜๋Š” ์ž‘์—…์ด ์•„๋‹™๋‹ˆ๋‹ค(ํ”„๋กœํ† ํƒ€์ž…์ด ์—ฐ์†์ ์ด์ง€ ์•Š์€ ๊ฒฝ์šฐ์—๋„ "nchw"- ์—ฐ์† ํ…์„œ๋ฅผ ๋ฐ˜ํ™˜ํ•จ)

์•„, ๋„ˆ๋ฌด ๋นจ๋ฆฌ ์ฝ์—ˆ๋„ค์š”. ๊ทธ ๊ฒฝ์šฐ์— ์šฐ๋ฆฌ์˜ empty_like ์ผ์น˜ numpy๋ฅผ ๊ฐ–๋Š” ๊ฒƒ์ด ์ข‹์„ ๊ฒƒ์ด๊ณ  ์—ฌ๊ธฐ์—์„œ ๋ฉ”๋ชจ๋ฆฌ ๋ ˆ์ด์•„์›ƒ๋„ ๊ฐ–๋Š” ๊ฒƒ์ด ์ข‹์„ ๊ฒƒ์ž…๋‹ˆ๋‹ค(์•„๋งˆ๋„?)

@zou3519 ๋„ค, ์ œ๊ฐ€ ๋งํ•˜๋ ค๋Š” ๊ฒƒ์€ ํ˜„์žฌ ์˜๋ฏธ๋ก ( @ezyang ๋ฐ @ngimel์ด ์–ธ๊ธ‰ํ•œ ๊ฒƒ์ฒ˜๋Ÿผ ๋…ผ๋ฆฌ์  ์ˆœ์„œ ์‚ญ์ œ)์„ ์œ ์ง€ํ•˜๊ณ  ๋™์‹œ์— numpy์˜ ๊ธฐ๋ณธ๊ฐ’๊ณผ ๊ฐ™์€ ๋ฌผ๋ฆฌ์  ๋ ˆ์ด์•„์›ƒ์„ ์œ ์ง€ํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ NCHW ํ”„๋กœํ† ํƒ€์ž…์˜ ๊ฒฝ์šฐ ๋™์ž‘์€ ์ด์ „๊ณผ ๋™์ผํ•ฉ๋‹ˆ๋‹ค. NHWC ํ”„๋กœํ† ํƒ€์ž…์˜ ๊ฒฝ์šฐ ๋™์ž‘์€ ์—ฌ์ „ํžˆ โ€‹โ€‹ํ˜ธํ™˜๋ฉ๋‹ˆ๋‹ค. ์ฆ‰, ํ˜„์žฌ ๊ตฌํ˜„์„ ๋ณ€๊ฒฝํ•˜์ง€ ์•Š์œผ๋ฉด ์ƒˆ ํ…์„œ๋Š” NCHW ์—ฐ์† ๋Œ€์‹  NHWC ์—ฐ์†์ด ๋ฉ๋‹ˆ๋‹ค.

๋‘ ๊ฐ€์ง€ ์งˆ๋ฌธ:

  • NHWC ํ…์„œ๋ฅผ NCHW ํ…์„œ์— ์ถ”๊ฐ€ํ•˜๋ฉด ์–ด๋–ป๊ฒŒ ๋ฉ๋‹ˆ๊นŒ?
  • ์ฐจ์›์˜ ๋ฌผ๋ฆฌ์  ์œ„์น˜๋ฅผ ๋‚˜ํƒ€๋‚ด๋Š” ์ •์ˆ˜ ๊ฐ’์„ ๋ฐ˜ํ™˜ํ•˜๋Š” ํ…์„œ์—์„œ t.channel_dim()๊ณผ ๊ฐ™์€ ๋ฉ”์„œ๋“œ๋ฅผ ๋งŒ๋“ค์–ด (B)์˜ ๋‹จ์ ์„ ํ•ด๊ฒฐํ•˜๋Š” ๊ฒƒ์€ ์–ด๋–ป์Šต๋‹ˆ๊นŒ? ์ด ์ ‘๊ทผ ๋ฐฉ์‹์€ ๋ธ”๋ก ํ˜•์‹๊ณผ ๊ฐ™์€ ๋‹ค๋ฅธ ํ˜•์‹์„ ๋„คํŠธ์›Œํฌ ๋ณ€๊ฒฝ ์—†์ด ์„ ํƒํ•  ์ˆ˜ ์žˆ๋„๋ก ํ•˜๋Š” ๋ฐ ํ•„์š”ํ•  ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค.

๋งˆ์ง€๋ง‰ ๊ธ€๋จธ๋ฆฌ ๊ธฐํ˜ธ๋กœ (B)์˜ ๋‹จ์ ์„ ํ•ด๊ฒฐํ•˜๋ฉด (B)๊ฐ€ ๋‚˜์—๊ฒŒ ๋” ๋‚˜์€ ๊ฒƒ ๊ฐ™์Šต๋‹ˆ๋‹ค. ์ง๊ด€์ ์œผ๋กœ ๋ช…ํ™•ํ•˜๊ณ  ๋…ผ๋ฆฌ์  ์˜ค๋ฅ˜๋ฅผ ๊ฐ์ง€ํ•˜๊ธฐ ์‰ฌ์›Œ์•ผ ํ•ฉ๋‹ˆ๋‹ค. ๊ธฐ์กด์˜ ๋ชจ๋“  ์—ฐ์‚ฐ์€ ๋‹ค๋ฅธ ์ธ์ ‘ ํ…์„œ์ฒ˜๋Ÿผ ๋ณด์ด๊ธฐ ๋•Œ๋ฌธ์— ํ…์„œ์—์„œ๋„ ์ž‘๋™ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์‹œ๋งจํ‹ฑ(๋ช…๋ช…๋œ ํ…์„œ ์ œ์•ˆ๊ณผ ์œ ์‚ฌ)์„ ์ดํ•ดํ•  ์ˆ˜ ์žˆ๋Š” ์ž‘์—…๋„ ์˜ˆ์ƒ๋Œ€๋กœ ์ˆ˜ํ–‰๋ฉ๋‹ˆ๋‹ค.

@zou3519 ๋งํฌํ•œ numpy์˜ empty_like์—๋Š” ๊ธฐ๋ณธ์ ์œผ๋กœ "ํ”„๋กœํ† ํƒ€์ž…์˜ ๋ ˆ์ด์•„์›ƒ๊ณผ ์ตœ๋Œ€ํ•œ ๊ฐ€๊น๊ฒŒ ์ผ์น˜"ํ•˜๋Š” order ์ธ์ˆ˜๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค. ๊ทธ๊ฒƒ์€ pytorch์˜ empty_like ๊ฐ€ ํ˜„์žฌ ์ˆ˜ํ–‰ํ•˜๋Š” ์ž‘์—…์ด ์•„๋‹™๋‹ˆ๋‹ค(ํ”„๋กœํ† ํƒ€์ž…์ด ์—ฐ์†์ ์ด์ง€ ์•Š์€ ๊ฒฝ์šฐ์—๋„ "nchw"- ์—ฐ์† ํ…์„œ๋ฅผ ๋ฐ˜ํ™˜ํ•จ)

์ด๋Ÿฌํ•œ ๊ฒฝ์šฐ ํ˜•์‹์„ ์œ ์ง€ํ•  ๊ณ„ํš์ž…๋‹ˆ๋‹ค(๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ํ…์„œ์˜ ๊ฒฝ์šฐ).

NHWC ํ…์„œ๋ฅผ NCHW ํ…์„œ์— ์ถ”๊ฐ€ํ•˜๋ฉด ์–ด๋–ป๊ฒŒ ๋ฉ๋‹ˆ๊นŒ?
๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ํ…์„œ๋ฅผ ์‚ฌ์šฉํ•œ ์ž‘์—…์€ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ํ…์„œ๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค. ๋‘ ํ…์„œ๊ฐ€ ๋ชจ๋‘ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์ธ ๊ฒฝ์šฐ ์ถœ๋ ฅ ํ˜•์‹์€ ์ฒซ ๋ฒˆ์งธ ํ…์„œ์— ์˜ํ•ด ๊ฒฐ์ •๋ฉ๋‹ˆ๋‹ค.

๋‚ด๊ฐ€ ์ถ”๊ฐ€ํ•  ๋‘ ๊ฐ€์ง€:

์ด๋Ÿฌํ•œ ๊ฒฝ์šฐ ํ˜•์‹์„ ์œ ์ง€ํ•  ๊ณ„ํš์ž…๋‹ˆ๋‹ค(๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ํ…์„œ์˜ ๊ฒฝ์šฐ).

์ข…์ข… ์šด์˜์ž๊ฐ€ empty_like ํ˜ธ์ถœํ•œ ๋‹ค์Œ NCHW ์—ฐ์†์ด๋ผ๊ณ  ๊ฐ€์ •ํ•˜๊ธฐ ๋•Œ๋ฌธ์— ๊ธฐ์กด ์‚ฌ์šฉ์„ ๊ฐ์‚ฌํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋ฆฌ๊ณ  ์ œ3์ž ์ฝ”๋“œ๋ฅผ ์–ด๋–ป๊ฒŒ ์ฒ˜๋ฆฌํ• ์ง€ ๋ชจ๋ฅด๊ฒ ์Šต๋‹ˆ๋‹ค. BC๋ฅผ ๋ณด์กดํ•˜๋ ค๋ฉด numpy์™€ ๋‹ค๋ฅธ ๊ธฐ๋ณธ๊ฐ’์ด ํ•„์š”ํ•œ ๊ฒƒ ๊ฐ™์Šต๋‹ˆ๋‹ค.

๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ํ…์„œ๋ฅผ ์‚ฌ์šฉํ•œ ์ž‘์—…์€ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ํ…์„œ๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค. ๋‘ ํ…์„œ๊ฐ€ ๋ชจ๋‘ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์ธ ๊ฒฝ์šฐ ์ถœ๋ ฅ ํ˜•์‹์€ ์ฒซ ๋ฒˆ์งธ ํ…์„œ์— ์˜ํ•ด ๊ฒฐ์ •๋ฉ๋‹ˆ๋‹ค.

๋˜ํ•œ ์ถœ๋ ฅ ํ˜•์‹์ด ๋ฌด์—‡์ธ์ง€ ์ •๋ง ์ค‘์š”ํ•˜๋‹ค๋ฉด ์ถœ๋ ฅ ํ…์„œ๋ฅผ ์ „๋‹ฌํ•ฉ๋‹ˆ๋‹ค.

empty_like์— ๋™์˜ํ•ฉ๋‹ˆ๋‹ค. empty_like/zeros_like ๋“ฑ์˜ ๊ฒฐ๊ณผ๊ฐ€ nchw-contiguous๋กœ ๊ฐ„์ฃผ๋˜๋Š” ๊ฒฝ์šฐ๊ฐ€ ๊ฝค ์žˆ์Šต๋‹ˆ๋‹ค(๋ฌผ๋ฆฌ์ ์œผ๋กœ ์—ฐ์†์ ์ด๋ผ๊ณ  ๋งํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ๋งŽ์€ ๊ฒฝ์šฐ ์ด๋ฏธ์ง€ ์ž‘์—…์ด ์•„๋‹˜).
out kwarg๊ฐ€ ์žˆ๋Š” ํ•จ์ˆ˜๋Š” ๋ฏธ๋ถ„ํ•  ์ˆ˜ ์—†๊ธฐ ๋•Œ๋ฌธ์— ์ถœ๋ ฅ ํ…์„œ๋ฅผ ์ „๋‹ฌํ•˜๋Š” ๊ฒƒ์€ ๋Œ€๋ถ€๋ถ„์˜ ๊ฒฝ์šฐ ์˜ต์…˜์ด ์•„๋‹™๋‹ˆ๋‹ค.

์šฐ๋ฆฌ์˜ ๋งŽ์€ ๋ฌธ์ œ๋Š” ์˜ˆ์ƒ๋˜๋Š” ์ถœ๋ ฅ ๋ ˆ์ด์•„์›ƒ์˜ ๋ถˆ์ผ์น˜์—์„œ ๋น„๋กฏ๋ฉ๋‹ˆ๋‹ค. ํ•œ ๋ฒˆ์— ๋ชจ๋“  ๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐํ•  ์ˆ˜๋Š” ์—†์ง€๋งŒ ํ˜„์žฌ ์ƒํƒœ๋ฅผ ์ž ๊ทธ๊ณ (์ ์–ด๋„ ๋ณดํญ์— ๋Œ€ํ•ด์„œ๋Š”) ํ•˜๋‚˜์”ฉ ํ•ด๊ฒฐํ•ด ๋ณผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์—ฌ๊ธฐ ์ œ์•ˆ์ด ์žˆ์Šต๋‹ˆ๋‹ค.

ํŒŒ์ด์ฌ API

์ƒˆ๋กœ์šด torch.memory_format์„ ์†Œ๊ฐœํ•ฉ๋‹ˆ๋‹ค.

torch_memory_format.any # default value
torch_memory_format.preserve
torch.memory_format.contiguous # what most of the functions now behave as default
torch.memory_format.nchw # requires 4D tensor, contiguous memory
torch.memory_format.nhwc # requires 4D tensor, restrided/permuted memory

ํ…์„œ๋Š” ๋ช…์‹œ์  ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ๋ณ€ํ™˜์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค.

x = torch.zeros((10,3,32,32)) # NCHW
x.permute(0,2,3,1).is_contiguous(memory_format=torch.memory_format.nhwc) == False # because memory still layed out as NCHW

ํŠน์ • ํ˜•์‹์œผ๋กœ 'ํƒœ๊ทธ'ํ•˜๋ ค๋ฉด:

y = x.to(memory_format=torch.memory_format.nhwc)
y.is_contiguous(memory_format=torch.memory_format.nhwc) == True # We got new tensor with proper memory layout
y.is_contiguous() == False # Required for back compatibility
y.stride() == (3072, 3, 1, 96)

์ด์ œ empty_like ๋ฐ ์œ ์‚ฌ์— ๋Œ€ํ•ด:

z = torch.empty_like(y) 
z.is_contiguous() == True # For BC

์‹ค์ œ๋กœ ๋‹ค์Œ๊ณผ ๊ฐ™๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค.

z = torch.empty_like(y, memory_format=torch.memory_format.any ) 

ํ˜•์‹์„ ์œ ์ง€ํ•˜๋ ค๋ฉด:

z = torch.empty_like(y, memory_format=torch_memory_format.preserve) 
z.is_contiguous() == False 
z.is_contiguous(memory_format=torch.memory_format.nhwc) == True

๋น„์Šทํ•˜๊ฒŒ:

z = torch.empty_like(y, memory_format=memory_format=torch.memory_format.nhwc) 
z.is_contiguous() == False 
z.is_contiguous(memory_format=torch.memory_format.nhwc) == True

์ฆ‰, ๊ฐ ํ•จ์ˆ˜ memory_format ๊ธฐ๋ณธ๊ฐ’์„ ์„ธ๊ณ„์˜ ํ˜„์žฌ ์ƒํƒœ๋กœ ์ฒœ์ฒœํžˆ ์ •์˜ํ•˜๊ณ  ๋ถ„๋ฅ˜ํ•˜๊ณ  ๋ฏธ๋ž˜์— ๋ณ€๊ฒฝํ•  ๋ฐฉ๋ฒ•์„ ์—ผ๋‘์— ๋‘˜ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

ํ…์„œ๋ฅผ ์ง€์ •ํ•˜๋ฉด TensorOptions๊ฐ€ ํ˜„์žฌ ๋ฌด์‹œ๋ฉ๋‹ˆ๋‹ค(๊ฐ€์žฅ ์ข‹์€ ๊ฒฝ์šฐ ์˜ˆ์™ธ๊ฐ€ ๋ฐœ์ƒํ•˜๋Š” ๊ฒƒ์€ ์˜ˆ๋ฅผ ๋“ค์–ด out ํ…์„œ ์žฅ์น˜์™€ ์žฅ์น˜ ์˜ต์…˜ ๋ถˆ์ผ์น˜๋ฅผ ์ „๋‹ฌํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค).

๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์€ ๊ฐ€๋ฒผ์›Œ์•ผ ํ•˜๋ฏ€๋กœ ์ˆœ์—ด์ด ์†์‹ค๋  ๊ฒƒ์ž…๋‹ˆ๋‹ค.

x.zeros((10,3,32,32), memory_format=torch.memory_format.nhwc)
x = x.permute(0,1,3,2).permute(0,1,3,2)
x.is_contiguous(memory_format=torch.memory_format.nhwc) == False (even if strides are similar)

ํŒจ๋”ฉ์ด ํ™•์‹คํ•˜์ง€ ์•Š์€ ๊ฒฝ์šฐ ์—ฌ๊ธฐ์—์„œ ๋„์›€์„ ์ฃผ์‹œ๋ฉด ๊ฐ์‚ฌํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค.

๊ทธ๋Ÿฌ๋‚˜ ์ ์ ˆํ•œ ํ˜•์‹์œผ๋กœ x.to(memory_format=torch.memory_format.nhwc) 'tag' ํ…์„œ๋ฅผ ๋งŒ๋“ค๊ณ  ์ž์ฒด๋ฅผ ๋ฐ˜ํ™˜ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

๋‹ค์ค‘ ์ฒ˜๋ฆฌ

๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ 'ํƒœ๊ทธ'๋ฅผ ์œ ์ง€ํ•ฉ๋‹ˆ๋‹ค.

๋ธ”๋ก ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹

์œ„์˜ API๋Š” ์ฐจ์›/๋ณดํญ/ํฌ๊ธฐ์— ์˜์กดํ•˜์ง€ ์•Š์œผ๋ฏ€๋กœ ํ–ฅํ›„ ๋™์ผํ•œ API๋ฅผ ์œ ์ง€ํ•˜๋ฉด์„œ ๊ธฐ๋Šฅ์„ ํ™•์žฅํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

๋‚ด๋ถ€ API

์—ฐ์‚ฐ์ž๋Š” ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์— ๋”ฐ๋ผ ๋ถ„๊ธฐํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

if (self.memory_format(nhwc)) {
 // fast path
} else
{
 // classic implementation
}

memory_format์„ TensorOptions๋กœ ํ•˜๋ฉด ๋””์ŠคํŒจ์น˜ ์ˆ˜์ค€์—์„œ ๋ถ„๊ธฐํ•˜๋Š” ๊ฒƒ์„ ์ƒ๊ฐํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค(๋””๋ฐ”์ด์Šค, ๋ ˆ์ด์•„์›ƒ๊ณผ ์œ ์‚ฌ).

@VitalyFedyunin ์˜ ์ œ์•ˆ์— ๋Œ€ํ•œ ์ž‘์€ ํ”ผ๋“œ๋ฐฑ - ์—ฌ๊ธฐ์— 4D ํ…์„œ๊ฐ€ ํ•„์š”ํ•˜๋‹ค๊ณ  ์ƒ๊ฐํ•ฉ๋‹ˆ๋‹ค.

torch.memory_format.nchw # requires 4D tensor, contiguous memory
torch.memory_format.nhwc # requires 4D tensor, restrided/permuted memory

๋„ˆ๋ฌด ์ œํ•œ์ ์ด๋ฉฐ(2D ์™ธ์— 1D ๋ฐ 3D๋„ ์ฒ˜๋ฆฌํ•ด์•ผ ํ•˜๊ธฐ ๋•Œ๋ฌธ์—) ์›๋ž˜ ์ œ์•ˆ์˜ channels_first/channels_last ๊ฐ€ ์ด ๋ชฉ์ ์— ๋” ์ ํ•ฉํ–ˆ์Šต๋‹ˆ๋‹ค.

๋™์˜ํ•ฉ๋‹ˆ๋‹ค. ๋” ๋‚˜์€ ์ด๋ฆ„ ์ง€์ •์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค. channels_first ๋Š” ์ผ๊ด„ ์ฒ˜๋ฆฌ๊ฐ€ ๋จผ์ € ์ง„ํ–‰๋œ๋‹ค๋Š” ์ ์„ ์ œ์™ธํ•˜๊ณ  ๊ฑฐ์˜ ์˜ณ๊ฒŒ ๋“ค๋ฆฝ๋‹ˆ๋‹ค =)

๋‚˜๋Š” ๋‹น์‹ ์˜ ์ตœ์‹  ์ œ์•ˆ์„ ์ข‹์•„ํ•ฉ๋‹ˆ๋‹ค. .contiguous() ์ฒ˜๋ฆฌ๊ฐ€ ๋ณ€๊ฒฝ๋ฉ๋‹ˆ๊นŒ? .contiguous(memory_format=<...>)๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๊นŒ? ๊ทธ๋ ‡๋‹ค๋ฉด ๋งŽ์€ ์ž‘์—…์ด ๋‹จ์ˆœํžˆ .contiguous()๋ฅผ ํ˜ธ์ถœํ•˜์ง€๋งŒ ์—ฌ์ „ํžˆ ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ๋ถ€์ ์ ˆํ•˜๊ฒŒ ํฌ๋งทํ•˜๊ณ  ์žˆ์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์˜ค๋Š˜๋‚  ๋งŽ์€ ์ž‘์—…์—์„œ๋„ ๋™์ผํ•œ ํšจ๊ณผ๋ฅผ ๋‚ผ ์ˆ˜ ์žˆ๋Š” empty_like()๋กœ ์ถœ๋ ฅ์„ ํ• ๋‹นํ•ฉ๋‹ˆ๋‹ค. ์ž…๋ ฅ์˜ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์„ ๊ฐ์ง€ํ•˜๊ณ  ์˜ฌ๋ฐ”๋ฅธ ์—ฐ์†์ ์ด๊ณ  empty_like ํ˜ธ์ถœ์„ ์ˆ˜ํ–‰ํ•˜๋„๋ก ์—…๋ฐ์ดํŠธํ•  ๊ณ„ํš์ž…๋‹ˆ๊นŒ?

์ง€๊ธˆ ๋‹น์žฅ์€ .contiguous() ๊ฐ€ ๋ฉ”๋ชจ๋ฆฌ ์—ฐ์† ํ…์„œ๋ฅผ ๋‚ด๋ฆผ์ฐจ์ˆœ์œผ๋กœ ๋ณดํญ์œผ๋กœ ๋ฐ˜ํ™˜ํ•  ๊ฒƒ์œผ๋กœ ๊ธฐ๋Œ€ํ•˜๋Š” ์‚ฌ์šฉ์ž(๋ฐ ๋ชจ๋“  ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ)์ž…๋‹ˆ๋‹ค.

์šฐ๋ฆฌ๋Š” ์ด ๊ณ„์•ฝ์„ ๊นฐ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ ์ข‹์€ ์†Œ์‹์€ memory_format ์˜ต์…˜์„ ์ง€์›ํ•˜๋Š” ์ฆ‰์‹œ JIT๊ฐ€ ํด๋ž˜์‹ ํ˜•์‹ ๋Œ€์‹  .contiguous(memory_format=...) ๋ฅผ ํ˜ธ์ถœํ•˜๋Š” ๊ฒƒ์ด ๋” ํšจ์œจ์ ์ธ ๋•Œ๋ฅผ ์ดํ•ดํ•  ์ˆ˜ ์žˆ๋‹ค๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.

@VitalyFedyunin ์•„๋ž˜์™€ ๊ฐ™์€ ์ž‘์—…์€ ํ—ˆ์šฉ๋˜์ง€ ์•Š๋Š”๋‹ค๊ณ  ๊ฐ€์ •ํ•ฉ๋‹ˆ๊นŒ?

x.zeros(10,3,32,32)
# x is in nchw (default)
# x.size() is [10,3,32,32]
# x.stride() is [3*32*32, 32*32, 32,1]
x = x.permute(0,2,3,1)
# At this point 
# x.size() is [10,32,32,3], size is not in nchw order
# x.stride() is [3*32*32, 32,1,32*32]

# How can this be supported?
y = x.to(memory_format=torch.memory_format.nhwc)

๋˜ ๋‹ค๋ฅธ ๋ณ€ํ˜•์€ ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค.

x.zeros(10,3,32,32)
# `x` is in nchw (default)
# x.size() is [10,3,32,32]
# x.stride() is [3*32*32, 32*32, 32,1]
x = x.permute(0,2,3,1)
x=x.contiguous()
# At this point 
# x.size() is [10,32,32,3], size is not in nchw order
# x.stride() is [32*32*3, 32*3,3,1]

# How can this be supported?
y = x.to(memory_format=torch.memory_format.nhwc)

@raghuramank100 - ์‚ฌ์šฉ์ž๊ฐ€ ์ฒ˜์Œ์— .permute(0,2,3,1) ๋ฅผ ํ˜ธ์ถœํ•˜๋Š” ์ด์œ ๋Š” ๋ฌด์—‡์ž…๋‹ˆ๊นŒ? ์ด ์ œ์•ˆ์˜ ๋ชจ๋“  ํ…์„œ๋Š” ์˜๋ฏธ๋ก ์  ํฌ๊ธฐ๊ฐ€ (n,c,h,w)์ด๋ฉฐ, ์ด๋Š” size(1)์ด ์ฑ„๋„์„ ๋ฐ˜ํ™˜ํ•จ์„ ์˜๋ฏธํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๊ฒƒ์ด ์˜ค๋Š˜๋‚  PT์˜ ํ‘œ์ค€ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๊ฐ€ ๊ฐ€์ •ํ•˜๊ณ  ์ด ์ œ์•ˆ์—์„œ๋„ ๊ฐ€์ •ํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ .permute๋ฅผ ์ „ํ˜€ ํ˜ธ์ถœํ•˜์ง€ ์•Š์„ ๊ฒƒ์ž…๋‹ˆ๋‹ค.

์ปจํ…์ŠคํŠธ ๊ด€๋ฆฌ์ž๊ฐ€ ์‚ฌ์šฉ์ž๊ฐ€ ๊ด€๋ฆฌ์ž ๋ฒ”์œ„ ๋‚ด์—์„œ ํ• ๋‹น๋œ ํ…์„œ์˜ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์„ ํŠน์ • ํ˜•์‹์œผ๋กœ ์žฌ์ •์˜ํ•  ์ˆ˜ ์žˆ๋„๋ก ํ•˜๋Š” ๋ฐ ์œ ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๊นŒ?

with torch.memory_format(torch.memory_format.nhwc):
    # a will be allocated with the context managed memory format   
    a = torch.randn(...)

# b will be allocated matching some assumed default format
b = torch.randn(...)

memory_format์˜ ์ œ์–ด๋ฅผ ๋Š์Šจํ•˜๊ฒŒ ํ•˜๊ธฐ ๋•Œ๋ฌธ์— ์ปจํ…์ŠคํŠธ ๊ด€๋ฆฌ์ž์˜ ์•„์ด๋””์–ด๊ฐ€ ๋งˆ์Œ์— ๋“ค์ง€ ์•Š์Šต๋‹ˆ๋‹ค.

์˜ˆ๋ฅผ ๋“ค์–ด:

with torch.memory_format(torch.channels_last):
  x = torch.randn(10,3,32,32) # this one is NHWC
  y = torch.randn(10,10) @ this one is not

๋ช…์‹œ์  memory_format์ด ๋ช…ํ™•ํ•˜๊ฒŒ ํ•˜๋Š” ๊ฒฝ์šฐ:

x = torch.randn(10,3,32,32).to(memory_format=torch.channels_last) # this one is NHWC
y = torch.randn(10,10).to(memory_format=torch.channels_last) # This is errors out as dim == 2

ํ•„์š”ํ•œ ๊ฒฝ์šฐ ๋‹ค์Œ์„ ํ—ˆ์šฉํ•˜๋Š” ๊ตฌ๋ฌธ์„ ์ถ”๊ฐ€ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

x = torch.randn(10,3,32,32, memory_format=torch.channels_last)

@raghuramank100 ์ˆœ์—ดํ•  ํ•„์š”๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค.

y = x.to(memory_format=torch.channels_last)

x์—์„œ์™€ ๊ฐ™์ด ํฌ๋ฏธํ•œ ์ˆœ์„œ๋ฅผ ์œ ์ง€ํ•˜๋ฉด์„œ ๋ชจ๋“  ๋”๋Ÿฌ์šด ์ž‘์—…์„ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค.

๊ทธ๋ž˜์„œ:

x = torch.randn(10, 3, 32, 32)
nhwc = x.to(memory_format=torch.channels_last)
self.assertFalse(nhwc.is_contiguous())
self.assertTrue(nhwc.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(nhwc, x)

๊ทธ๋ฆฌ๊ณ  ์ด ํ˜•์‹์œผ๋กœ nhwc๋ฅผ ๊ณ„์† ์ฒ˜๋ฆฌํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

nhwc[N][C][H][W]

@VitalyFedyunin ๊ทธ๊ฒƒ์€ ์˜๋ฏธ๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค.

์‚ฌ์šฉ์ž์˜ ๊ด€์ ์—์„œ ๋ณผ ๋•Œ ๋ฉ”์„œ๋“œ ์ด๋ฆ„(์ด๋Œ€๋กœ ์œ ์ง€๋˜๋Š” ๊ฒฝ์šฐ)์€ "to"๊ฐ€ ์ด๋ฏธ Tensor๋ฅผ ๋‹ค๋ฅธ ์žฅ์น˜๋กœ ์ „์†กํ•˜๋Š” ๋ฐ ๊ถŒ์žฅ๋˜๋Š” ๋ฐฉ๋ฒ•์ด๋ฏ€๋กœ ์˜คํ•ด์˜ ์†Œ์ง€๊ฐ€ ์žˆ๋Š” ๊ฒƒ ๊ฐ™์Šต๋‹ˆ๋‹ค.

๋˜ํ•œ C_ORDER ๋ฐ F_ORDER ๋ฐฐ์—ด์„ ๋ณ€ํ™˜ํ•˜๋Š” Numpy์˜ ๊ฒƒ๊ณผ ๊ฐ™์€ ๊ฒƒ์€ ์–ด๋–ป์Šต๋‹ˆ๊นŒ?

numpy.asfortranarray()
numpy.ascontiguousarray()

๋‹ค์Œ๊ณผ ๊ฐ™์€ ๊ฒƒ์„ ์‰ฝ๊ฒŒ ์ƒ์ƒํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

torch.randn(32, 3, 64, 64).to(device).as_nhwc()

@VitalyFedyunin : ๋‹ค๋ฅธ memory_format์œผ๋กœ ๋ณ€ํ™˜ํ•˜๋ฉด ์‚ฌ์šฉ์ž๊ฐ€ ์ˆ˜๋™์œผ๋กœ ๋ณ€๊ฒฝํ•  ํ•„์š”๊ฐ€ ์—†๋‹ค๋Š” ๊ฒƒ์„ ์ดํ•ดํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ ์ด ๊ธฐ๋Šฅ์„ ํ† ์น˜์—์„œ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๊ฒŒ ๋˜๋ฉด ์œ„์—์„œ ์„ค๋ช…ํ•œ ์ˆœ์„œ๋Œ€๋กœ ์‚ฌ์šฉ์ž๊ฐ€ ํ•จ์ˆ˜๋ฅผ ํ˜ธ์ถœํ•˜๋ฉด ์–ด๋–ป๊ฒŒ ๋ ๊นŒ์š”? ์ตœ์†Œํ•œ ๋ ˆ์ด์•„์›ƒ ๋ณ€ํ™˜์ด ์‹คํŒจํ–ˆ๋‹ค๋Š” ๊ฒฝ๊ณ /์˜ค๋ฅ˜ ๋ฉ”์‹œ์ง€๊ฐ€ ์žˆ์–ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

@VitalyFedyunin : ๋‹ค๋ฅธ memory_format์œผ๋กœ ๋ณ€ํ™˜ํ•˜๋ฉด ์‚ฌ์šฉ์ž๊ฐ€ ์ˆ˜๋™์œผ๋กœ ๋ณ€๊ฒฝํ•  ํ•„์š”๊ฐ€ ์—†๋‹ค๋Š” ๊ฒƒ์„ ์ดํ•ดํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ ์ด ๊ธฐ๋Šฅ์„ ํ† ์น˜์—์„œ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๊ฒŒ ๋˜๋ฉด ์œ„์—์„œ ์„ค๋ช…ํ•œ ์ˆœ์„œ๋Œ€๋กœ ์‚ฌ์šฉ์ž๊ฐ€ ํ•จ์ˆ˜๋ฅผ ํ˜ธ์ถœํ•˜๋ฉด ์–ด๋–ป๊ฒŒ ๋ ๊นŒ์š”? ์ตœ์†Œํ•œ ๋ ˆ์ด์•„์›ƒ ๋ณ€ํ™˜์ด ์‹คํŒจํ–ˆ๋‹ค๋Š” ๊ฒฝ๊ณ /์˜ค๋ฅ˜ ๋ฉ”์‹œ์ง€๊ฐ€ ์žˆ์–ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

์ด๊ฒƒ์€ ๋ช…๋ช…๋œ ํ…์„œ๋ฅผ ๊ตฌํ˜„ํ•  ๋•Œ๋งŒ ๊ฐ€๋Šฅํ•ฉ๋‹ˆ๋‹ค. ์ง€๊ธˆ ๋‹น์žฅ:

x.zeros(10,10,10,10)
x = x.permute(0,2,3,1)

๋‚ด๊ฐ€ ๋ฐฉ๊ธˆ nchw ๋˜๋Š” nhwc๋ฅผ ๋งŒ๋“ค์—ˆ๋Š”์ง€ ์•„๋ฌด๋„ ์•Œ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.

๋‚ด๊ฐ€ ์›๋ž˜ ์ œ์•ˆ์„ ์ž˜๋ชป ์ดํ•ดํ–ˆ์„ ์ˆ˜๋„ ์žˆ์ง€๋งŒ ๊ธฐ๋ก๋œ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ํƒœ๊ทธ๊ฐ€ ์ด ์ƒํ™ฉ์„ ๋ช…ํ™•ํ•˜๊ฒŒ ํ•ด์•ผ ํ•˜๋Š” ๊ฒƒ ์•„๋‹Œ๊ฐ€์š”?

@VitalyFedyunin ์ด API๊ฐ€ ์•ˆ์ •ํ™”๋˜๋ฉด ์ตœ์ข… ์‚ฌ์šฉ์ž์—๊ฒŒ ์ด๋ฅผ ์ „๋‹ฌํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

@dzhulgakov @VitalyFedyunin #19975๋ฅผ ๊ฒ€ํ† ํ•œ ํ›„ ํ…์„œ์— ๊ธฐ๋ก๋œ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ํƒœ๊ทธ์— ๋Œ€ํ•œ ๋ช‡ ๊ฐ€์ง€ ์ƒˆ๋กœ์šด ์šฐ๋ ค๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค. ๋‚ด ๊ธฐ๋ณธ์ ์ธ ๋ฌธ์ œ๋Š” ์ž‘์—…์ด ๋ฉ”๋ชจ๋ฆฌ ํƒœ๊ทธ๋ฅผ ๋ณด์กดํ•ด์•ผ ํ•˜๋Š”์ง€ ์—ฌ๋ถ€๋ฅผ ์–ด๋–ป๊ฒŒ ๊ฒฐ์ •ํ•ด์•ผ ํ•ฉ๋‹ˆ๊นŒ? ์›๋ž˜๋Š” "๋Œ€์ฒด ๋ ˆ์ด์•„์›ƒ ์ธ์‹" ์šด์˜์ž๋งŒ ์ด๋Ÿฌํ•œ ์˜๋ฆฌํ•จ์„ ๊ฐ–์ถ”์–ด์•ผ ํ•œ๋‹ค๊ณ  ์ƒ๊ฐํ–ˆ์Šต๋‹ˆ๋‹ค. ํ•˜์ง€๋งŒ Vitaly์˜ ํŒจ์น˜๋ฅผ ๋ณด๋ฉด ์ผ๋ถ€ ํ•ต์‹ฌ ์˜คํผ๋ ˆ์ดํ„ฐ๋„ ์กฐ์ •์ด ํ•„์š”ํ•˜๋‹ค๊ณ  ์ƒ๊ฐํ•ฉ๋‹ˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด x[0] ; x๊ฐ€ ์ด์ „์— NHWC ํ…์„œ๋ผ๋ฉด ์ด ์ž‘์—…์„ ์ˆ˜ํ–‰ํ•œ ํ›„ HWC ํ…์„œ๋ฅผ ๊ฐ€์ ธ์™€์•ผ ํ•ฉ๋‹ˆ๋‹ค. ๋‚˜๋Š” Vitaly์˜ ํŒจ์น˜๊ฐ€ ์ด๊ฒƒ์„ ์˜ฌ๋ฐ”๋ฅด๊ฒŒ ์ฒ˜๋ฆฌํ•˜์ง€ ๋ชปํ•œ๋‹ค๊ณ  ํ™•์‹ ํ•˜๋ฉฐ, ์ด๋Š” ์‚ฌ์šฉ์ž์—๊ฒŒ ๋งค์šฐ ํ˜ผ๋ž€์Šค๋Ÿฌ์šธ ๊ฒƒ์ž…๋‹ˆ๋‹ค. ์•„๋งˆ๋„ ์˜ํ–ฅ์„ ๋ฐ›๋Š” ์œ ์ผํ•œ ์—ฐ์‚ฐ์ž๋Š” ๋ณดํญ์„ ์ฒ˜๋ฆฌํ•˜๋Š” ์—ฐ์‚ฐ์ž(์ด ๊ฒฝ์šฐ ๋„ˆ๋ฌด ๋งŽ์ง€ ์•Š๊ณ  ์ˆ˜๋™์œผ๋กœ ๊ฐ์‚ฌํ•  ์ˆ˜ ์žˆ์Œ)์ด์ง€๋งŒ ์šฐ๋ฆฌ๊ฐ€ ํ•ด์•ผ ํ•  ์ผ์ธ ๊ฒƒ ๊ฐ™์Šต๋‹ˆ๋‹ค. ์–ด๋–ป๊ฒŒ ์ƒ๊ฐํ•˜๋‚˜์š”?

์ž ๊น, ํ…์„œ๋Š” ์—ฌ์ „ํžˆ ๋‹ค์Œ ์ˆœ์„œ๋กœ ์ธ๋ฑ์‹ฑ๋œ ์ƒํƒœ๋ฅผ ์œ ์ง€ํ•ฉ๋‹ˆ๋‹ค. 0-dim N; 1์ฐจ์› C; 2์ฐจ์› H; 3rd-dim W. ๋”ฐ๋ผ์„œ x[0]์€ 0-dim C๋ฅผ ๊ฐ€์ง„ ํ…์„œ๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค. 1์ฐจ์› H; 2nd-dim W. x๊ฐ€ channel_first ๋˜๋Š” channel_last ๋ฉ”๋ชจ๋ฆฌ ๋ ˆ์ด์•„์›ƒ์ธ์ง€ ์—ฌ๋ถ€์— ๊ด€๊ณ„์—†์ด.

๊ทธ๋ ‡์ง€ ์•Š์œผ๋ฉด memory_format์ด ์˜๋ฏธ๊ฐ€ ์—†์œผ๋ฉฐ ํ…์„œ๋ฅผ ์น˜ํ™˜ํ•˜๊ธฐ๋งŒ ํ•˜๋ฉด ๋ฉ๋‹ˆ๋‹ค.

๋‚ด ์š”์ ์€ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ํƒœ๊ทธ๊ฐ€ ๋ณด์กด๋˜์ง€ ์•Š๋Š”๋‹ค๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ์ž…๋ ฅ ํ…์„œ์— channels_last ํƒœ๊ทธ๊ฐ€ ์ง€์ •๋œ ๊ฒฝ์šฐ ์ƒˆ ํ…์„œ๋Š” any ํƒœ๊ทธ๊ฐ€ ์ง€์ •๋ฉ๋‹ˆ๋‹ค.

cc @zou3519 , ์—ฌ๊ธฐ ๋ ˆ์ด์•„์›ƒ ์ „ํŒŒ ๋…ผ๋ฆฌ๋Š” ๋ช…๋ช…๋œ ํ…์„œ ์ž‘์—…์—์„œ ๋ช…๋ช…๋œ ์ฐจ์› ์ „ํŒŒ๋ฅผ ๋งŽ์ด ์ƒ๊ฐ๋‚˜๊ฒŒ ํ•ฉ๋‹ˆ๋‹ค.

๋‚˜๋Š” ์—ฌ์ „ํžˆ ์ด ์ œ์•ˆ์„ ๋”ฐ๋ผ์žก๊ณ  ์žˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ @ezyang ์€ ์ฐจ์›๋ณ„ ํ”Œ๋ž˜๊ทธ(๋˜๋Š” ์ด๋ฆ„)๋ฅผ ์ „ํŒŒํ•˜์—ฌ ๋ ˆ์ด์•„์›ƒ ์ „ํŒŒ ๋…ผ๋ฆฌ๋ฅผ ์ถ”์ ํ•  ์ˆ˜ ์žˆ์œผ๋ฉฐ, ๊ทธ๋Ÿฌ๋ฉด ์ด๋ฆ„ ๊ทœ์น™์„ ์‚ฌ์šฉํ•˜์—ฌ ๋ช…๋ช…๋œ ํ…์„œ๋ฅผ ๊ฐ–๋Š” ๊ฒƒ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค.

๋ฉ”๋ชจ๋ฆฌ ํƒœ๊ทธ ๋กœ์ง๊ณผ ๋ช…๋ช…๋œ ํ…์„œ ๋กœ์ง์„ ์ •ํ™•ํ•˜๊ฒŒ ์ •๋ ฌํ•  ์ˆ˜ ์žˆ๋‹ค๋ฉด, ๋น„๋ก ์ฒ˜์Œ์— ๋‘ ๊ฐœ์˜ ๊ฐœ๋ณ„ ๊ตฌํ˜„ ๊ฒฝ๋กœ๊ฐ€ ์žˆ๋”๋ผ๋„ ๊น”๋”ํ•  ๊ฒƒ์ž…๋‹ˆ๋‹ค.

1๋‹จ๊ณ„

๋‘ ๊ฐœ์˜ ํ…์„œ ํ•จ์ˆ˜ .is_contiguous ๋ฐ .contiguous (python ๋ฐ C++ API ๋ชจ๋‘)์˜ ๊ธฐ๋Šฅ์„ ํ™•์žฅํ•ฉ๋‹ˆ๋‹ค.

์ฐธ๊ณ : .to(memory_format) ๊ธฐ๋Šฅ์— ๋Œ€ํ•œ ๋ช‡ ๊ฐ€์ง€ ๋ถˆ๋งŒ์ด ์žˆ์—ˆ๊ณ  ์ง€์›ํ•˜์ง€ ์•Š๊ธฐ๋กœ ๊ฒฐ์ •ํ–ˆ์Šต๋‹ˆ๋‹ค.

  1. .contiguous ์ด์ œ ์„ ํƒ์  ํ‚ค์›Œ๋“œ ์ „์šฉ ์ธ์ˆ˜์ธ memory_format ํ•ฉ๋‹ˆ๋‹ค. torch.contiguous_format ๋˜๋Š” torch.channels_last ์žˆ์Šต๋‹ˆ๋‹ค.

    • torch.contiguous_format ํ•˜๋ฉด ๊ธฐ์กด .contiguous() ๋™์ž‘์ด ์œ ์ง€๋ฉ๋‹ˆ๋‹ค.

    • x.contiguous(memory_format=torch.channels_last) ํ˜ธ์ถœํ•˜๋ฉด ๋™์ผํ•œ ์˜๋ฏธ์  ๋ ˆ์ด์•„์›ƒ(NCHW)์„ ์œ ์ง€ํ•˜์ง€๋งŒ ๋ฉ”๋ชจ๋ฆฌ ํ• ๋‹น ํŒจํ„ด์ด ๋‹ค๋ฅธ ์ƒˆ ํ…์„œ๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.

      x.contiguous(memory_format=torch.channels_last) ๋Š” ์ž…๋ ฅ ํ…์„œ๊ฐ€ 3d, 4d ๋˜๋Š” 5d์ผ ๊ฒƒ์œผ๋กœ ์˜ˆ์ƒํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋ ‡์ง€ ์•Š์œผ๋ฉด ์‹คํŒจํ•ฉ๋‹ˆ๋‹ค.

  2. .is_contiguous ์ด์ œ ์„ ํƒ์  ํ‚ค์›Œ๋“œ ์ „์šฉ ์ธ์ˆ˜์ธ memory_format ํ•ฉ๋‹ˆ๋‹ค. torch.contiguous_format ๋˜๋Š” torch.channels_last ์žˆ์Šต๋‹ˆ๋‹ค.

    • x.is_contiguous(memory_format=torch.contiguous_format) ๋Š” x.is_contiguous() ์™€ ๋™์ผํ•œ ๊ธฐ๋Šฅ์„ ์œ ์ง€ํ•˜๋ฉฐ ๋ณ€๊ฒฝ๋˜์ง€ ์•Š์€ ์ƒํƒœ๋กœ ์œ ์ง€๋ฉ๋‹ˆ๋‹ค.

    • x.is_contiguous(memory_format=torch.channels_last) ๋Š” A) ์ž…๋ ฅ ํ…์„œ๊ฐ€ ๋ฉ”๋ชจ๋ฆฌ์—์„œ ์—ฐ์†์ ์ด๊ณ  B) NWHC(๋˜๋Š” 3d,5d์˜ ๊ฒฝ์šฐ ์œ ์‚ฌ) ํ˜•์‹์œผ๋กœ ๋ฉ”๋ชจ๋ฆฌ์— ํ• ๋‹น๋œ ๊ฒฝ์šฐ true๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.

์ฐธ๊ณ : 1๋‹จ๊ณ„๊ฐ€ ๋๋‚  ๋•Œ x.is_contiguous(memory_format=torch.channels_last) ๋Š” ๋ชจ๋“  ํ˜ธ์ถœ์—์„œ Tensor์˜ ์ƒํƒœ๋ฅผ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค. ์ด ๊ธฐ๋Šฅ์€ ๋‚˜์ค‘์— ์—…๋ฐ์ดํŠธ๋  ์˜ˆ์ •์ž…๋‹ˆ๋‹ค.

2 ๋‹จ๊ณ„

ํŠน์ • ์ž‘์—…์— ๋Œ€ํ•œ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹ ์œ ์ง€:

  1. ๋‹จํ•ญ ์š”์†Œ๋ณ„ ์—ฐ์‚ฐ์ž๋Š” channel_last ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์„ ์œ ์ง€ํ•ฉ๋‹ˆ๋‹ค.

    a = torch.randn(N,C,H,W)
    b = a.contiguous(memory_format=torch.channels_last)
    c = b.sin()
    c.is_contiguous(memory_format=torch.channels_last) == True
    
  2. ์ด์ง„ ์š”์†Œ๋ณ„ ์—ฐ์‚ฐ์ž( add , sub , mul , div )๋Š” channels_last ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์„ ์œ ์ง€ํ•ฉ๋‹ˆ๋‹ค.

    a = torch.randn(N,C,H,W)
    b = a.contiguous(memory_format=torch.channels_last)
    c = b * torch.randn(H,W)
    c.is_contiguous(memory_format=torch.channels_last) == True
    
  3. ํฌ๊ธฐ, ๋ณดํญ ๋ฐ ํ๋ฆผ์— ๋Œ€ํ•œ ๋ชจ๋“  ์ž‘์—…์€ ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์„ ์žฌ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค.

    a = torch.randn(N,C,H,W)
    b = a.contiguous(memory_format=torch.channels_last)
    c = b.permute(0,2,3,1).permute(0,3,1,2)
    c.is_contiguous(memory_format=torch.channels_last) == False
    

๋ฏธ์ •

  1. ์ถœ๋ ฅ์ด ์ฝ์„ ์ˆ˜ ์žˆ๋Š” 'channels_last'์ธ ๊ฒฝ์šฐ ๋ชจ์–‘ ๋ณ€๊ฒฝ(๋ฐ ์œ ์‚ฌ) ์ž‘์—…์˜ ๊ฒฐ๊ณผ

    import torch
    a = torch.randn(N,C,H,W)
    b = a.contiguous(memory_format=torch.channels_last)
    c = b.reshape(N,C,-1)
    c.is_contiguous(memory_format=torch.channels_last) # ?
    

    ์ฐธ๊ณ : ํ˜„์žฌ memory_format์ด ๋ณด์กด๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค.

  2. NHWC + NCHW ์ž‘์—…์˜ ๊ฒฐ๊ณผ์ž…๋‹ˆ๋‹ค. NHWC์ธ๊ฐ€์š”?

    ์ฐธ๊ณ : ํ˜„์žฌ NHWC + NCHW -> NHWC ๋ฐ NCHW + NHWC -> NHWC

cat/split๊ณผ ๊ฐ™์€ ์ž‘์—…์€ ์–ด๋–ป์Šต๋‹ˆ๊นŒ? ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์„ ์œ ์ง€ํ•˜๋Š” ๋ฐ ์œ ์šฉํ•ฉ๋‹ˆ๋‹ค.

@ezyang - ์ธ๋ฑ์‹ฑ๊ณผ ๊ด€๋ จํ•˜์—ฌ ์–ด๋”˜๊ฐ€์—์„œ ๋ฉˆ์ถฐ์•ผ ํ•œ๋‹ค๊ณ  ์ƒ๊ฐํ•ฉ๋‹ˆ๋‹ค. ๋‹ค๋ฅธ ๋ฉ”๋ชจ๋ฆฌ ๋ ˆ์ด์•„์›ƒ์€ ์™„์ „ํžˆ ํˆฌ๋ช…ํ•˜์ง€ ์•Š์œผ๋ฉฐ ์ผ๋ถ€ ์ž‘์—…์—์„œ๋Š” ์ด๋ฅผ ๋ฌด์‹œํ•˜๋„๋ก ํ—ˆ์šฉํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. x[0] ๋Š” x[0].unsqueeze(0) ํฌํ•จํ•˜์—ฌ ํƒœ๊ทธ๋ฅผ ์ง€์šธ ์ˆ˜ ์žˆ์–ด์•ผ ํ•œ๋‹ค๊ณ  ์ฃผ์žฅํ•ฉ๋‹ˆ๋‹ค.

Raghu๊ฐ€ ์–ธ๊ธ‰ํ–ˆ๋“ฏ์ด cat/split์€ ๋งค์šฐ ์ผ๋ฐ˜์ ์ธ ์‚ฌ์šฉ๋ฒ•์ด์ง€๋งŒ ๊ฐ€๋Šฅํ•˜๋ฉด ํƒœ๊ทธ๋ฅผ ๋ณด์กดํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ์ผ๋ฐ˜์ ์ธ ๊ฒฝํ—˜ ๋ฒ•์น™์€ ์šด์˜์ด ์ˆœ์œ„๋ฅผ ๋ณ€๊ฒฝํ•˜๊ฑฐ๋‚˜ ์ถ•์„ ์ด์ƒํ•˜๊ฒŒ ์žฌ์ •๋ ฌํ•˜์ง€ ์•Š๋Š” ํ•œ ํƒœ๊ทธ๋ฅผ ์œ ์ง€ํ•ด์•ผ ํ•œ๋‹ค๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ์ˆœ์œ„๊ฐ€ ๋ณ€๊ฒฝ๋˜๋ฉด ๋ชจ๋“  ๋ฒ ํŒ…์ด ํ•ด์ œ๋ฉ๋‹ˆ๋‹ค.

์–ด๋–ค ๊ฒฝ์šฐ์—๋Š” ํƒœ๊ทธ๋ฅผ ์žƒ๊ฒŒ ๋œ๋‹ค๋Š” ๋ฐ ๋™์˜ํ•ฉ๋‹ˆ๋‹ค. ํ•˜์ง€๋งŒ x[0] ๋Œ€ํ•ด์„œ๋Š” ๋™์˜ํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ๋‚˜์—๊ฒŒ ๊ทธ๊ฒƒ์€ NCHW ์—์„œ CHW ๋กœ ๊ฐ€๋Š” ๋งค์šฐ ์ผ๋ฐ˜์ ์ธ ๋ฐฉ๋ฒ•์ธ ๊ฒƒ ๊ฐ™์Šต๋‹ˆ๋‹ค.

ํ…์„œ๊ฐ€ channel_last 'ํƒœ๊ทธ'๋ฅผ ์ „๋‹ฌ(๋˜๋Š” ํฌํ•จํ•˜์ง€ ์•Š์Œ)ํ•˜๋Š” ๊ฒƒ์ด ์–ผ๋งˆ๋‚˜ ํ˜ผ๋ž€์Šค๋Ÿฌ์šด์ง€์— ๋Œ€ํ•ด ๋ช‡ ์ฐจ๋ก€ ๋Œ€ํ™”๋ฅผ ๋‚˜๋ˆˆ ํ›„ ์šฐ๋ฆฌ๋Š” bc-๋ธŒ๋ ˆ์ดํ‚น ๋ณ€๊ฒฝ์„ ๋„์ž…ํ•˜๊ณ  ํ…์„œ๋ฅผ channel_last ํ˜•์‹์œผ๋กœ ์ž๋™ ์Šน๊ฒฉํ•˜๋Š” ์œ„ํ—˜์„ ๊ฐ์ˆ˜ํ•˜๊ธฐ๋กœ ๊ฒฐ์ •ํ–ˆ์Šต๋‹ˆ๋‹ค.

API์— ๋Œ€ํ•œ ์˜๋ฏธ:

N,1,H,[W,[D]]์™€ ๊ฐ™์€ ์ŠคํŠธ๋ผ์ด๋“œ๊ฐ€ ์žˆ๋Š” ๋ชจ๋“  3d,4d,5d ํ…์„œ๋Š” ์ž๋™์œผ๋กœ channel_last ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์„ ์–ป์Šต๋‹ˆ๋‹ค.

์ž‘๋™ํ•˜๋„๋ก ํ•˜๊ธฐ ์œ„ํ•ด channel_last ํ…์„œ๋ฅผ ์ถœ๋ ฅํ•˜๋Š” channel_last ํ…์„œ์˜ ์—ฐ์‚ฐ์ž๊ฐ€ ์ธ์ ‘ํ•œ ํ…์„œ์˜ ์—ฐ์‚ฐ์ž์™€ ์ตœ์†Œํ•œ ๋น„์Šทํ•œ ์„ฑ๋Šฅ์„ ๊ฐ–๋„๋ก ๋ณด์žฅํ•˜๊ธฐ ์œ„ํ•ด ํŠน๋ณ„ํ•œ ์˜ˆ๋ฐฉ ์กฐ์น˜๋ฅผ ์ทจํ•  ๊ฒƒ์ž…๋‹ˆ๋‹ค.

์ตœ์•…์˜ ์‹œ๋‚˜๋ฆฌ์˜ค์˜ ๊ฒฝ์šฐ:
1) ์‚ฌ์šฉ์ž๋Š” ์ถœ๋ ฅ์—์„œ โ€‹โ€‹.contiguous()๋ฅผ ํ˜ธ์ถœํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
2) ์šฐ๋ฆฌ๋Š” ์ด ๋™์ž‘์„ ๋ณ€๊ฒฝํ•˜๋Š” ๊ฒƒ์ด ๊ฑฐ์˜ ์‚ฌ์†Œํ•œ ๋ฐฉ์‹์œผ๋กœ ์ž๋™ ์Šน๊ฒฉ ์ฝ”๋“œ๋ฅผ ์ž‘์„ฑํ•  ๊ฒƒ์ž…๋‹ˆ๋‹ค.

์ด๋Ÿฌํ•œ ์ž๋™ ํ”„๋กœ๋ชจ์…˜์˜ ๋ถ€์ž‘์šฉ์€ ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค.

import torch
x = torch.randn(10,16,16,3).permute(0,3,1,2) 
x.is_contiguous(memory_format=torch.channels_last) == True

๋‹ค๋ฅธ ํ•œํŽธ์œผ๋กœ (๊ฐ€๋ฒผ์šด ์ˆ˜์ • ํ›„) ๊ฒฝ์šฐ๋ฅผ ํ•ด๊ฒฐํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

import torch
x = torch.randn(10,3,16,16).contiguous(memory_format=torch.channels_last)
x = x[0].unsqueeze(0)
x.is_contiguous(memory_format=torch.channels_last) == True

@ezyang ์˜ ์š”์ฒญ์— ๋”ฐ๋ผ slack ๋ณ€ํ™˜์—์„œ

๋‚˜ํƒˆ๋ฆฌ์•„ ๊ธฐ๋ฉœ์ƒค์ธ [2:19 PM]
๊ทธ๋ž˜์„œ ๋‚˜๋Š” ํƒœ๊ทธ์˜ ๊ฐœ๋…์ด ์—†์„ ๊ฒƒ์ด๋ผ๊ณ  ์ƒ๊ฐํ•ฉ๋‹ˆ๋‹ค.

import torch
#batch = 10, channels = 4, spatial dimensions = 16
x = torch.randn(10,16,16,4).permute(0,3,1,2)
x.is_contiguous(memory_format=torch.channels_last) == True
y = torch.randn(10,16,16,2).permute(0,3,1,2)
x1,x2 = x.chunk(2, dim=1) #chunk along channels dimension, no longer contiguous
x1.is_contiguous(memory_format=torch.channels_last) == False #right? So, if a tensor like this comes into e.g. convolution, what am I supposed to do with it? Did it want to be NHWC? Did it want to be nchw?
z=y+x1 #y is channels_last, x1 is something, what is the z layout?```

๋น„ํƒˆ๋ฆฌ ํŽ˜๋‘๋‹Œ [์˜ค์ „ 8์‹œ 23๋ถ„]
z๋Š” channel_last๊ฐ€ ๋  ๊ฒƒ์ž…๋‹ˆ๋‹ค.

๋น„ํƒˆ๋ฆฌ ํŽ˜๋‘๋‹Œ [์˜ค์ „ 8์‹œ 25๋ถ„]
x1์ด ์ œ์•ˆ๋œ ๋ณ€ํ˜•์—์„œ channel_last๊ฐ€ ์•„๋‹Œ ๊ฒฝ์šฐ(์ฒญํฌ ๊ธฐ๋Šฅ์„ ๋ณ€๊ฒฝํ•˜์—ฌ ๋ทฐ๋ฅผ ๋ฐ˜ํ™˜ํ•˜์ง€ ์•Š๋Š” ํ•œ), ์ปจ๋ณผ๋ฃจ์…˜์€ ์ด๋ฅผ ์—ฐ์†(channel_first) ํ˜•์‹์œผ๋กœ ๋ณ€ํ™˜ํ•˜๊ณ  ์—ฐ์†๋„ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.

๋น„ํƒˆ๋ฆฌ ํŽ˜๋‘๋‹Œ [์˜ค์ „ 9:12]
@ngimel ํ”ผ๋“œ๋ฐฑ ์ฃผ์…”์„œ ๊ฐ์‚ฌํ•ฉ๋‹ˆ๋‹ค. ๋ณด๊ธฐ์™€ ๊ฐ™์€ ์ž‘์—…์ด ๊ด€๋ จ๋œ ๋Œ€๋ถ€๋ถ„์˜ ๊ฒฝ์šฐ๋ฅผ ๋‹ค๋ฃจ๊ธฐ ์œ„ํ•ด ๋ณด๋‹ค ์˜๋ฏธ ์žˆ๋Š”

๋‚˜ํƒˆ๋ฆฌ์•„ ๊ธฐ๋ฉœ์ƒค์ธ [์˜ค์ „ 9์‹œ 36๋ถ„]
์Šค๋ ˆ๋“œ์— ๋‹ต๋ณ€:
๊ทธ๋ž˜์„œ ๋ฌธ์ œ์ธ ๊ฒƒ ๊ฐ™์ฃ ? ์ฑ„๋„ ์ฐจ์›์—์„œ ์ฒญํฌํ•˜๋Š” ๊ฒƒ์€ ์˜ˆ๋ฅผ ๋“ค์–ด ์‹œ์ž‘๊ณผ ๊ฐ™์€ ๋„คํŠธ์›Œํฌ์—์„œ ๋น„๊ต์  ์ผ๋ฐ˜์ ์ธ ๊ฒƒ์ž…๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ ํ…์„œ๊ฐ€ ์ฒญํฌ๋œ ์ฑ„๋„ ์ฒซ ๋ฒˆ์งธ ํ…์„œ์ธ ๊ฒฝ์šฐ ํšŒ์„  ์ถœ๋ ฅ์€ ์ฑ„๋„ ์šฐ์„ (์ง๊ด€์ ์ธ ๋™์ž‘์ด๋ฉฐ ์‚ฌ์šฉ์ž๊ฐ€ ์›ํ•˜๋Š” ๊ฒƒ์ผ ์ˆ˜ ์žˆ์Œ)์ด ๋˜๊ณ , ํ…์„œ๊ฐ€ ์ฒญํฌ๋œ ์ฑ„๋„ ๋งˆ์ง€๋ง‰์ธ ๊ฒฝ์šฐ ํšŒ์„  ์ถœ๋ ฅ์€ ๋‹ค์‹œ ํ•œ ๋ฒˆ ์ฑ„๋„ ์ฒซ ๋ฒˆ์งธ๊ฐ€ ๋ ๊นŒ์š”?

๋‚˜ํƒˆ๋ฆฌ์•„ ๊ธฐ๋ฉœ์ƒค์ธ [์˜ค์ „ 9:39]
์Šค๋ ˆ๋“œ์— ๋‹ต๋ณ€:
๊ทธ๋Ÿฌ๋‚˜ ๋น„ ๊ตํ™˜ ๋ง์…ˆ ๋™์ž‘๊ณผ y ๊ฐ€ ์ฒซ ๋ฒˆ์งธ ์ธ์ˆ˜์ด๊ณ  ์ฑ„๋„์ด ๋งˆ์ง€๋ง‰์ด๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค. ๋งž์Šต๋‹ˆ๊นŒ? x1+y ์˜ ๊ฒฐ๊ณผ๋Š” ๋ฌด์—‡์ž…๋‹ˆ๊นŒ? ์–ด๋”˜๊ฐ€์— ์ด์ง„ ์—ฐ์‚ฐ์— ๋Œ€ํ•œ ๋ ˆ์ด์•„์›ƒ ์ „ํŒŒ ๊ทœ์น™์ด ์žˆ์Šต๋‹ˆ๊นŒ?

๋น„ํƒˆ๋ฆฌ ํŽ˜๋‘๋‹Œ [์˜ค์ „ 10:44]
1) ๋„ค, ๋Œ€์•ˆ ์ œ์•ˆ์œผ๋กœ ํ•ด๊ฒฐํ•  ๋ฌธ์ œ์ž…๋‹ˆ๋‹ค. ๋‚˜๋Š” ์ง€๊ธˆ ๋ช‡ ๊ฐ€์ง€ ํ…Œ์ŠคํŠธ๋ฅผ ํ•˜๊ณ  ์žˆ๊ณ  ์ด๋ฒˆ ์ฃผ์— ๊ธฐ๋กํ•  ๊ฒƒ์ž…๋‹ˆ๋‹ค(ํ•˜๋ฃจ๋‚˜ ์ดํ‹€ ํ›„์—).
2) x1+y - ๋˜ํ•œ channel_last๋ฅผ ์ƒ์„ฑํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋ ‡์ง€ ์•Š์œผ๋ฉด ํ˜ผ๋ž€์Šค๋Ÿฝ์Šต๋‹ˆ๋‹ค. ์˜ˆ, ๋ ˆ์ด์•„์›ƒ ์ „ํŒŒ ๊ทœ์น™์ด ๊ธฐ๋ก๋ฉ๋‹ˆ๋‹ค.

์šฐ๋ฆฌ๊ฐ€ ์ด ๋Œ€๋ฉด์— ๋Œ€ํ•ด ์ด์•ผ๊ธฐํ•  ๋•Œ @VitalyFedyunin ์—๊ฒŒ ๊ด€์ฐฐํ•œ ๋‚ด์šฉ์€ (ํ•˜์ง€๋งŒ ์–ด๋””์—๋„ ์ด๊ฒƒ์„ ์ ์–ด๋‘๋Š” ๊ฒƒ์„ ๊ธฐ์–ตํ•˜์ง€ ๋ชปํ–ˆ๋‹ค๊ณ  ์ƒ๊ฐํ•ฉ๋‹ˆ๋‹ค) ์ปจ๋ณผ๋ฃจ์…˜์—๋Š” ์–ด๋Š ์ •๋„์˜ ์ž์œ ๊ฐ€ ์žˆ๋‹ค๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ๋ฉ”๋ชจ๋ฆฌ ๋ ˆ์ด์•„์›ƒ์ด ํšจ์œจ์ ์œผ๋กœ ๊ตฌํ˜„ํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ์•Œ๊ณ  ์žˆ๋Š” ์–ด๋–ค ๊ฒƒ๊ณผ๋„ ์ผ์น˜ํ•˜์ง€ ์•Š๋Š” ์ธ์ˆ˜, ์–ด๋–ค ๋ ˆ์ด์•„์›ƒ์— ์—ฐ๊ฒฐํ•ด์•ผ ํ•ฉ๋‹ˆ๊นŒ? BC ์ด์œ ๋กœ ์ฑ„๋„์„ ๋จผ์ € ์—ฐ๊ฒฐํ•ด์•ผ ํ•˜์ง€๋งŒ ์—ฌ๊ธฐ์„œ ์ž„์˜์˜ ๊ฒฐ์ •์„ ๋‚ด๋ ธ์Šต๋‹ˆ๋‹ค. ์•„๋งˆ๋„ ์ฑ„๋„์— ๋งˆ์ง€๋ง‰์œผ๋กœ ์—ฐ๊ฒฐํ•  ์ˆ˜๋„ ์žˆ์„ ๊ฒƒ์ž…๋‹ˆ๋‹ค. ๊ธฐ๋ณธ๊ฐ’์ด ๋ฌด์—‡์ธ์ง€ ์•Œ๋ ค์ฃผ๋Š” ์ผ์ข…์˜ ์Šค๋ ˆ๋“œ ๋กœ์ปฌ ํ† ๊ธ€์ด ์žˆ์–ด์•ผ ํ• ๊นŒ์š”?

๊ทธ๋Ÿฌ๋‚˜ ์—ฌ๊ธฐ์—๋Š” ๋งŽ์€ ์„ธ๋ถ€ ์‚ฌํ•ญ์ด ์žˆ๋Š” ๊ฒƒ ๊ฐ™์œผ๋ฉฐ ๊ฒฐ๊ตญ ์ž˜ ๋ ์ง€ ๋ชจ๋ฅด๊ฒ ์Šต๋‹ˆ๋‹ค.

๋”ฐ๋ผ์„œ ์ปจ๋ณผ๋ฃจ์…˜์˜ ํ๋ฆฟํ•จ(๋ฐ ๊ธฐํƒ€ ๋ ˆ์ด์•„์›ƒ ์ธ์‹ ์—ฐ์‚ฐ์ž, ์˜ˆ๋ฅผ ๋“ค์–ด ์ตœ๊ทผ์— ๋ณธ ์—…์ƒ˜ํ”Œ๋ง์€ ์ž…๋ ฅ์—์„œ .contiguous()์„ ํ˜ธ์ถœํ•˜์—ฌ ์‹œ์ž‘ํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋ž˜์„œ ๊ทธ๊ฒƒ์ด ์˜๋ฏธํ•˜๋Š” ๋ฐ”๊ฐ€ ๋ฌด์—‡์ž…๋‹ˆ๊นŒ?)์ด ์ฃผ๋œ ์ด์œ ์˜€์Šต๋‹ˆ๋‹ค. ํƒœ๊ทธ๋ฅผ ์†Œ๊ฐœํ•˜๊ธฐ ์œ„ํ•ด iirc.

๋„ค, ๊ทธ๋ž˜์„œ ํƒœ๊ทธ ๋””์ž์ธ์„ ๋‹ค์‹œ ์—ด์–ด๋„ ๊ดœ์ฐฎ์Šต๋‹ˆ๋‹ค. ํ•˜์ง€๋งŒ ์šฐ๋ฆฌ๋Š”
์ด๋Ÿฌํ•œ ํƒœ๊ทธ๋ฅผ ์ „ํŒŒํ•˜๋Š” ๋ฐฉ๋ฒ•์˜ ๋ฌธ์ œ๋ฅผ ์‹ฌ๊ฐํ•˜๊ฒŒ ํ•ด๊ฒฐํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.
๋ ˆ์ด์•„์›ƒ์„ ์žƒ์–ด๋„ (์ฒญํ‚น์˜ ๊ฒฝ์šฐ์™€ ๊ฐ™์ด
์ฑ„๋„). ๋‚˜๋Š” "ํ˜„์žฌ ๋ ˆ์ด์•„์›ƒ"์„ ๋งŒ๋“œ๋Š” ๊ฒƒ์„ ํ›จ์”ฌ ๋” ์ข‹์•„ํ•ฉ๋‹ˆ๋‹ค.
๋ฐ์ดํ„ฐ ์ข…์†์„ฑ์„ ๋งŒ๋“œ๋Š” ๊ฒƒ๋ณด๋‹ค ์ผ์ข…์˜ ์ปจํ…์ŠคํŠธ ๊ด€๋ฆฌ์ž์ž…๋‹ˆ๋‹ค.

2019-06-19 12:43:45 -0700์˜ ngimel ๋ฉ”์‹œ์ง€์—์„œ ๋ฐœ์ทŒ:

๋”ฐ๋ผ์„œ ์ปจ๋ณผ๋ฃจ์…˜์˜ ํ๋ฆฟํ•จ(๋ฐ ๊ธฐํƒ€ ๋ ˆ์ด์•„์›ƒ ์ธ์‹ ์—ฐ์‚ฐ์ž, ์˜ˆ๋ฅผ ๋“ค์–ด ์ตœ๊ทผ์— ๋ณธ ์—…์ƒ˜ํ”Œ๋ง์€ ์ž…๋ ฅ์—์„œ .contiguous()์„ ํ˜ธ์ถœํ•˜์—ฌ ์‹œ์ž‘ํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋ž˜์„œ ๊ทธ๊ฒƒ์ด ์˜๋ฏธํ•˜๋Š” ๋ฐ”๊ฐ€ ๋ฌด์—‡์ž…๋‹ˆ๊นŒ?)์ด ์ฃผ๋œ ์ด์œ ์˜€์Šต๋‹ˆ๋‹ค. ํƒœ๊ทธ๋ฅผ ์†Œ๊ฐœํ•˜๊ธฐ ์œ„ํ•ด iirc.

BTW ์™œ ์šฐ๋ฆฌ๋Š” layout ์ง‘์ฐฉํ•˜๋Š” ๋Œ€์‹  ์ƒˆ๋กœ์šด ๊ฐœ๋…์„ ๋งŒ๋“ค์–ด์•ผ ํ•ฉ๋‹ˆ๊นŒ? ํฌ์†Œ ํ‘œํ˜„์€ "channels_last"์™€ ๊ฐ™์€ ๋ ˆ์ด์•„์›ƒ ๊ฐœ๋…์ด ์ž˜ ์ •์˜๋˜์–ด ์žˆ์ง€ ์•Š๋‹ค๊ณ  ์ƒ๊ฐํ•˜๋ฏ€๋กœ memory_formats * layouts ์˜ ์ œํ’ˆ์„ ๋‚˜ํƒ€๋‚ผ ํ•„์š”๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค( layouts ๋Š” ํ˜„์žฌ ์‚ฌ์šฉ์„ ๋‚˜ํƒ€๋ƒ…๋‹ˆ๋‹ค. ), ํ•˜์ง€๋งŒ memory_format + layouts ์‚ฌ์šฉํ•˜๋ฉด ์ด์ „๊ณผ ๋™์ผํ•œ ์ธ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์ด ์ข‹์Šต๋‹ˆ๋‹ค. ๋‚˜์—๊ฒŒ ๊ทธ๊ฒƒ์€ ๋” ์งง๊ณ  ๋” ์ข‹์œผ๋ฉฐ ํŒฉํ† ๋ฆฌ ์„œ๋ช…์„ ์ˆ˜์ฒœ ๊ฐœ์˜ ์ธ์ˆ˜๋กœ ํ™•์žฅํ•˜๋Š” ๊ฒƒ์„ ํ”ผํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

๋ ˆ์ด์•„์›ƒ ์˜ต์…˜์ด ๊ณ ๋ ค๋˜์—ˆ์ง€๋งŒ(๋ถ€๋ก ํ™•์ธ) ๋งŽ์€ ์ฝ”๋“œ ์ค‘๋ณต์ด ๋ฐœ์ƒํ•˜๊ณ  ํ…์„œ๋ฅผ ์ฆ‰์‹œ ๋‹ค๋ฅธ memory_format์œผ๋กœ ์ž๋™ ๋ณ€ํ™˜ํ•˜๋Š” ๊ฒƒ์„ ํ—ˆ์šฉํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.

๊ฒฐ๊ตญ memory_format์€ stride ํ…์„œ๋ฅผ ์‚ฌ์šฉํ•˜๊ณ  ์™„์ „ํžˆ ๋‹ค๋ฅธ ํด๋ž˜์Šค๊ฐ€ ์•„๋‹ˆ๋ผ strided ํ…์„œ์˜ ์†์„ฑ์ธ ์ตœ์ ํ™”๋œ ์ปค๋„๊ณผ ์ถœ๋ ฅ์„ ์‰ฝ๊ฒŒ ์„ ํƒํ•˜๋Š” ๋ฐฉ๋ฒ•์ž…๋‹ˆ๋‹ค.

์–ด๋–ค ์˜๋ฏธ์—์„œ ํฌ์†Œ ๋ ˆ์ด์•„์›ƒ์€ ๋Œ€๋ถ€๋ถ„์ด 0์ธ ๋ฐฐ์—ด์— ๋Œ€ํ•ด ์ตœ์ ํ™”๋œ ์ปค๋„์„ ์‰ฝ๊ฒŒ ์„ ํƒํ•˜๋Š” ๋ฐฉ๋ฒ•์ด๊ธฐ๋„ ํ•ฉ๋‹ˆ๋‹ค.

์ด๊ฒƒ์€ ์ˆœ์ง„ํ•œ ์งˆ๋ฌธ์ผ ์ˆ˜ ์žˆ์ง€๋งŒ, PyTorch๊ฐ€ ์ด API๋ฅผ ๊ณ ๋ คํ•˜๋Š” ๊ฒƒ๊ณผ ์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ ๊ฒฝ์šฐ ๊ธฐ๋ณธ CuDNN ์ปค๋„์„ ์ง์ ‘ ํ˜ธ์ถœํ•˜๋Š” ์ž‘์—… ์ž์ฒด์—์„œ NHWC๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ์˜ต์…˜์„ ๋…ธ์ถœํ•˜๋Š” ์ด์œ ๋Š” ๋ฌด์—‡์ž…๋‹ˆ๊นŒ?

์ผ๋ฐ˜์ ์ธ ์‚ฌ์šฉ ์‚ฌ๋ก€(conv ๋ฐ LM ์•„ํ‚คํ…์ฒ˜์™€ ํ’€๋ง๊ณผ ๊ฐ™์€ ์ด๋ฏธ์ง€ ์ž‘์—… ํ˜ผํ•ฉ)์˜ ๊ฒฝ์šฐ ์ด๊ฒƒ์ด ์‰ฌ์šด ์†”๋ฃจ์…˜์ธ ๊ฒƒ ๊ฐ™์Šต๋‹ˆ๋‹ค. ๊ฐœ๋ฐœ์ž๋กœ์„œ ๋‚ด๊ฐ€ ์›ํ•˜๋Š” ๊ฒƒ์€ Conv2d(..., nhwc=True) . ์ด๊ฒŒ ๋ง์ด ์•ˆ ๋˜๋Š” ์ด์œ ๊ฐ€ ์žˆ๋‚˜์š”?

@rewonc ์šฐ๋ฆฌ๋Š” ์œ ์‚ฌํ•œ ์ ‘๊ทผ ๋ฐฉ์‹(

  • ์ด ์ ‘๊ทผ ๋ฐฉ์‹์€ ์ปค๋„์ด NHWC ์ปค๋„์„ ์ ์šฉํ•˜๊ธฐ ์œ„ํ•ด ์—ฐ์†์ ์ธ ํ…์„œ์˜ ์žฌ์ŠคํŠธ๋ผ์ด๋”ฉ์„ ์ˆ˜ํ–‰ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.
  • ๋‹ค์Œ ์—ฐ์‚ฐ์ž๋Š” nhwc=True ์˜ต์…˜์ด ์—†๋Š” ํ•œ ์ž…๋ ฅ์„ ๋‹ค์‹œ ์žฌ์ •์˜ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค(์—ฐ์†์ ์œผ๋กœ).
  • ๋„คํŠธ์›Œํฌ๋ฅผ ํ†ตํ•ด NHWC๋ฅผ ์‚ฌ์šฉํ•˜๋ ค๋ฉด ๋ชจ๋“  ๋‹จ์ผ ์šด์˜์ž์—๊ฒŒ nhwc=True ์˜ต์…˜์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค.

์ถ”์‹ . CudNN Ex ํ•จ์ˆ˜๊ฐ€ ๊ฑฑ์ •๋œ๋‹ค๋ฉด cudnn_batch_norm_nhwc ๋ฐ ์œ ์‚ฌํ•œ ์—ฐ์‚ฐ์ž๋ฅผ ๋…ธ์ถœํ•˜๋ ค๊ณ  ํ•ฉ๋‹ˆ๋‹ค.

์•ˆ๋…•ํ•˜์„ธ์š” @VitalyFedyunin๋‹˜ , ์šฐ๋ฆฌ๋Š” ๋ช…๋ช…๋œ ํ…์„œ๊ฐ€ PyTorch 1.3์—์„œ ์ง€์›๋˜๋Š” ๊ฒƒ์„ ๋ณด์•˜์Šต๋‹ˆ๋‹ค. NHWC(๋˜๋Š” ์ฐจ๋‹จ๋œ) ํ˜•์‹ ์ง€์›์— ๋Œ€ํ•œ ๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐ(๋˜๋Š” ๋ถ€๋ถ„์ ์œผ๋กœ ํ•ด๊ฒฐํ•  ์ˆ˜ ์žˆ์Œ)ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๊นŒ? ๋ช…๋ช…๋œ ํ…์„œ๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ NHWC ์ƒํƒœ๋กœ ๋‚˜์•„๊ฐˆ ๊ณ„ํš์ด ์žˆ์Šต๋‹ˆ๊นŒ?

์šฐ๋ฆฌ๋Š” ์ฑ„๋„์˜ ๋งˆ์ง€๋ง‰ ์ง€์›์„ ๊ณ„์† ์ง„ํ–‰ํ•˜๊ณ  ์žˆ์œผ๋ฉฐ, ์ด๋ฒˆ ์ฃผ์— ๋กœ๋“œ๋งต์„ ์—ฌ๊ธฐ์™€ slack ์ฑ„๋„์— ๊ฒŒ์‹œํ•  ์˜ˆ์ •์ž…๋‹ˆ๋‹ค. (๋ชจ๋“  ์—ฐ์‚ฐ์ž๋ฅผ ๋‹ค์‹œ ์ž‘์„ฑํ•ด์•ผ ํ•˜๋ฏ€๋กœ) ์ฐจ๋‹จ๋œ ํ˜•์‹์„ ๊ณง ์ถ”๊ฐ€ํ•˜๋Š” ๊ฒƒ์„ ๊ณ ๋ คํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.

๊ฐ์‚ฌ ํ•ด์š”. ์ž˜ ๋ ๊ฑฐ์•ผ!

https://github.com/pytorch/pytorch/issues/28619 ๋‚ด๋ถ€์˜ ํƒœํ‚น ์ž‘์—… ๋ฐ ์ง„ํ–‰ ์ƒํ™ฉ

์ด ํŽ˜์ด์ง€๊ฐ€ ๋„์›€์ด ๋˜์—ˆ๋‚˜์š”?
0 / 5 - 0 ๋“ฑ๊ธ‰