@ezyang ์ ์๋ก์ด ์ค๋ช :
์์ ์ https://github.com/Roger-luo/pytorch-complex ์์ ์งํ ์ค์ ๋๋ค.
๋ค์์ ์์ ์ ์ธ ์ํ์์ ์ํฌํ๋ก์ ๋ชจ์ต์ ๋๋ค.
PyTorch๋ ๊ธฐ๋ณธ์ ์ผ๋ก ๋ณต์กํ dtype์ ์ฐธ์กฐํ๊ธฐ ์ํ API๋ฅผ ํฌํจํ์ง๋ง ๊ธฐ๋ณธ์ ์ผ๋ก ์๋ฌด ์์ ๋ ์ํํ์ง ์์ต๋๋ค. PyTorch๋ ๋ณต์กํ ํ ์๋ฅผ ์ฐธ์กฐํ๋ torch.complex64 ๋ฐ torch.complex128์ ์ ์ํฉ๋๋ค. ๊ทธ๋ฌ๋ ์ด๋ฌํ ๋ฐฉ์์ผ๋ก ํ ์๋ฅผ ๊ตฌ์ฑํ๋ ค๊ณ ํ๋ฉด ๊ธฐ๋ณธ์ ์ผ๋ก PyTorch์์ ์ค๋ฅ๊ฐ ๋ฐ์ํฉ๋๋ค.
>>> torch.zeros({2,2}, dtype=torch.complex64)
RuntimeError: complex64 not supported by PyTorch
@ezyang ์ ์ด๋ฌํ dtypes๋ฅผ PyTorch์ ์ถ๊ฐํ๋ ํจ์น๋ฅผ ์ ๊ณตํ์ต๋๋ค. https://github.com/pytorch/pytorch/pull/11173
์ค๊ฐ์ PyTorch์์ ๊ธฐ๋ณธ์ ์ผ๋ก ์ง์ํ๋ ๊ธฐ๋ณธ ๊ธฐ๋ฅ(์: 0์ ํ ์ ํ ๋น)์ ๋ํ ์ง์์ ๋ณํฉํ ๊ฒ์ ๋๋ค. ์ด๋ค ์ง์์ด "๊ธฐ๋ณธ"์ธ์ง์ ๋ํ ํฉ๋ฆฌ์ ์ธ ํ๋ก์๋ CPU ํํ ํ ์(๊ทน๋๋ก ๋น๊ณคํ)์ ๋ํ PyTorch์ ๊ธฐ๋ณธ ์ง์์ ๋๋ค.
PyTorch๋ ๋ณต์กํ ํ ์ ๊ตฌํ์ ๋ฑ๋กํ๊ธฐ ์ํ ์ธํฐํ์ด์ค๋ฅผ ๊ฒ์ํฉ๋๋ค. ๊ตฌํ์ TypeDefault ํด๋์ค(https://github.com/pytorch/pytorch/pull/11013)์์ ์์๋๋ฉฐ ์ด ํด๋์ค์ ๋ฉ์๋๋ฅผ ์ฌ์ ์ํ์ฌ ๋ณต์กํ ๊ตฌํ์ด ์๋ ํจ์์ ๊ตฌํ์ ์ ์ํฉ๋๋ค. ๋ค์๊ณผ ๊ฐ์ด ๋ณด์ผ ๊ฒ์ ๋๋ค.
struct CPUComplexFloatType final : public TypeDefault {
virtual Tensor add(const Tensor & self, const Tensor & other, Scalar alpha=1) const override {
// Your implementation of add for complex tensors
}
// ...
}
์ด ํด๋์ค๋ ๋ณตํฉ์ ๋ํด ์ง์๋๋ ์ ํ์ ์ ํํ ์ฌ์ ์ํฉ๋๋ค. ๋ค๋ฅธ ๋ชจ๋ ๊ตฌํ์ TypeDefault์์ ์ ๊ณต๋๋ฉฐ ๊ธฐ๋ณธ์ ์ผ๋ก ์ค๋ฅ๊ฐ ๋ฐ์ํฉ๋๋ค.
PyTorch ์์ค ๋ฆฌํฌ์งํ ๋ฆฌ์ ์ฒดํฌ์ธ๋๋ ์๋ ์์ฑ ํ์ผ๋ก Type(์ ์ฒด ์ธํฐํ์ด์ค)์์ ์ง์๋๋ ๋ฉ์๋์ ์ ์ ๋ชฉ๋ก์ด ์์ต๋๋ค. ์ด ํ์ผ์ diff๋ฅผ ์ฌ์ฉํ์ฌ API ๋ณ๊ฒฝ ์ฌํญ์ ์ ๋ฌํฉ๋๋ค. ์ผ๋ฐ์ ์ผ๋ก ๋ฉ์๋๋ PyTorch ํ๋ฐํธ์๋์์ ํด๋น ์ด๋ฆ๊ณผ ์ผ๋์ผ ๋์ํฉ๋๋ค.
์ผ๋ฐ์ ์ผ๋ก ์์ง ๊ตฌํํ์ง ์์ ์์ ์ ์ฌ์ฉํ๋ ๊ฒฝ์ฐ
๊ฒฝ๊ณ : ์๋ก์ด ์์ ์ ๊ณต๊ฐ ๋ฑ๋ก๋ ์ง์ํ๋ ์ ์์คํ ์ผ๋ก Type์ ๋ฆฌํฉํ ๋งํ ์์ ์ ๋๋ค(์ง์ํ๋ ค๋ ๋ชจ๋ ๋ฉ์๋๋ฅผ ์ ์ํ๋ ๋จ์ผ ์ํผํด๋์ค๊ฐ ์๋ ๊ฒฝ์ฐ์๋ ๋ถ๋ช ํ ์๋ํ์ง ์์ต๋๋ค). ๋ฐ๋ผ์ Type์ ํ์ ํด๋์ค๋ก ์์ฑํ๋ ํน์ ๊ตฌํ ์ ๋ต์ ๋๋ฌด ์ฝ๋งค์ด์ง ๋ง์ญ์์ค.
์๋กญ๊ณ ๋ณต์กํ ์ ์ฉ ์์ ์ ๊ฒ์ํ๋ ค๋ฉด C++ ํ์ฅ API๋ฅผ ์ฌ์ฉํฉ๋๋ค. C++ ํ์ฅ API๋ https://pytorch.org/tutorials/advanced/cpp_extension.html ์ ์ค๋ช ๋์ด ์์ต๋๋ค. ๊ธฐ๋ณธ์ ์ผ๋ก ๋ค์๊ณผ ๊ฐ์ C++ ํจ์๋ฅผ ์์ฑํ ์ ์์ต๋๋ค.
at::Tensor imag(at::Tensor z) {
...
}
๊ทธ๋ฐ ๋ค์ C++ ํ์ฅ API๋ Python ๋ฐ์ธ๋ฉ์ ์์ฑํ๋ฏ๋ก Python์์ ์ด ํจ์๋ฅผ ํธ์ถํ ์ ์์ต๋๋ค.
์ผ๋ถ ์์ ์ ํ์ฌ ์กด์ฌํ๋ PyTorch์ "์ฝ๊ฒ" ํตํฉ๋ ๊ฒ์ ๋๋ค. ์๋ฅผ ๋ค์ด, ์ด์ง ์ฐ์ฐ์ ๊ตฌํํ๋ ค๋ฉด BinaryOpsKernel.cpp์์ add_kernel์ ํ์ฅํ์ฌ ๋ณตํฉ ์ ํ์ ๋ํด ๋์คํจ์นํ๋๋ก ํ๋ ๊ฒ์ด ๋ ํฉ๋ฆฌ์ ์ผ ์ ์์ต๋๋ค. ์ด๋ฌํ ํจ์น๊ฐ ์๊ณ ๋ ๋ฆฝ์ ์ธ ๊ฒฝ์ฐ ์ ์์ ๋ณํฉํ ๊ฒ์ ์ฝ์ํฉ๋๋ค.
๊ธฐ์กด ์ธํ๋ผ๋ฅผ ์ฌ์ฉํ๋ ๋์ Type์ ์ฌ์ ์๋ฅผ ์์ฑํ๊ณ ์์ ๋กญ๊ฒ ๋ณต์ฌํ์ฌ ๋ถ์ฌ๋ฃ๊ธฐ๋ฅผ ์ํํ์ฌ ํญ์ ์ฐจ๋จ์ ํด์ ํ ์ ์์ด์ผ ํฉ๋๋ค. ํ์ง๋ง ์ฌ์ธ ๋ ํผํ์!
์คํ ๊ทธ๋ผ๋. ์ด๋ฏธ ํ์ ๊ณต์์ด ์ ์๋์ด ์๋ ์์ ์ ๋ํด ์์ ํ๋ ํ Derivatives.yaml์์ ์ญ๋ฐฉํฅ ๊ตฌํ์์ ํธ์ถ๋๋ ๋ชจ๋ ๊ตฌ์ฑ ๊ธฐ๋ฅ์ ๋ํ ๋ณต์กํ ์ง์์ ๊ตฌํํ๋ ํ autograd ์ง์์ "์๋์ผ๋ก" ๋ฐ๊ฒ ๋ฉ๋๋ค. .
์ด๋ค ๊ฒฝ์ฐ์๋ ๋ณต์์์ ๋ํด ์๋ํ๋๋ก autograd ๊ณต์์ ์กฐ์ ํด์ผ ํ ์๋ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด, 'abs'์ ๊ธฐ์ธ๊ธฐ๋ 'grad'๊ฐ ์๋๋๋ค. self.sign()'. ์ด๋ฌํ ๊ฒฝ์ฐ ์ฐ๋ฆฌ๊ฐ ํด์ผ ํ ์ผ์ 'abs'์ autograd ๊ณต์์ ์ฌ์ ์ํ ์ ์๋ ํจ์์ธ 'abs_backward'๋ก ๋ณ๊ฒฝํ๋ ์ ์คํธ๋ฆผ ์์ ์ ๋๋ค.
์ผ๋ฐ์ ์ธ ๋ณต์์ ์ญ์ ํ์ ๊ฒฝ์ฐ ๋ค์๊ณผ ๊ฐ์ ์ฐธ์กฐ๊ฐ ์์ต๋๋ค.
์ผ๋ฐ์ ์ผ๋ก ๋๋ถ๋ถ์ ๊ฒฝ์ฐ ์ค์ ํจ์(์์ค)์ ๋ํจ์๋ง ๊ณ์ฐํ๋ฏ๋ก autograd๋ฅผ ์์ ํ ํ์๊ฐ ์์ต๋๋ค.
์ค๋๋ ํ์ํ ๋ง์ ๋ถ๋ถ์ด ์ ์๋ฆฌ์ ์์ง๋ง ์ข ๋จ ๊ฐ ๋ฐฉ์์ผ๋ก ์กฐ๋ฆฝ๋์ง๋ ์์ต๋๋ค. ๋ค์์ ์ํํด์ผ ํ ์์ ์ ๋๋ค.
๋จ๊ธฐ ํตํฉ ๊ณํ. ์ด๋ฌํ ์์ ์ "์ฝ๊ฒ" ๊ตฌํํ ์ ์์ผ๋ฏ๋ก ๊ฐ๋ฅํ ํ ๋นจ๋ฆฌ PyTorch์์ ๋ฉ์ธ๋ผ์ธ์ ๊ตฌ์ฑํด์ผ ํฉ๋๋ค.
์ปค๋ ๊ตฌํ:
TODO: https://github.com/Roger-luo/TH/blob/master/ChangeLog.md ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ๋ชฉ๋ก ์์ฑ
๊ธฐํ ๋ณต์กํ ๊ด๋ จ ์์ :
@PhilippPelz ์ ์๋ ๋๊ธ
๋ณต์กํ ํ
์๋ฅผ pytorch์ ํตํฉํ๋ ๋ฐ ๊ด์ฌ์ด ์๋์ง ๊ถ๊ธํฉ๋๋ค.
CPU ์ง์์ ์ํด ztorch๊ฐ ์์ผ๋ฉฐ ์ผ๋ง ์ ์ z-cutorch( https://github.com/PhilippPelz/z-cutorch )๋ฅผ ์์ฑํ์ต๋๋ค. CudaHalfTensor์ ๋ํ ๋ฆฌํฉํ ๋ง ์ด์ ์ ํฌํฌ ์คํ ์ปท์
๋๋ค(์์ง ํ๋์จ์ด๊ฐ ์์).
์ผ์ด ๋ง์ง ์๋ค๋ฉด ์ฒ์ฒํ pytorch์ ํตํฉํ๊ณ ์ถ์ต๋๋ค. ์ ๋ fb.ptyhon์ ํตํด ํ๋กํ
ํ๊ธฐ ์ํด matplotlib๋ฅผ ์ฌ์ฉํ๊ณ ์์ผ๋ฉฐ ์์คํ
์ ๋ค์ ์ค์นํ ๋๋ง๋ค(๋ชจ๋ ์ข
์์ฑ ์ปดํ์ผ) ์์ฒญ๋ ๊ณ ํต์ด ๋ฐ๋ฅด๋ฉฐ ๋ด ์คํ์ฉ PC ์ค ํ๋๊ฐ ์คํ๋๋ Windows์์ pytorch๊ฐ ๊ณง ์๋ํ ๊ฒ ๊ฐ์ต๋๋ค.
๋ณต์กํ ๊ทธ๋ผ๋์ธํธ๋ ํ์ํ๋ฏ๋ก ์กฐ๋ง๊ฐ autograd๋ ๋ง์ง ๊ฒ์
๋๋ค.
tf๋ ์์ฒด์ ์ผ๋ก ๋ณต์กํ ํ
์๋ฅผ ์ง์ํ์ง๋ง ๋ง์ ์์
์์ ์์ง ์ง์ํ์ง ์๋ ๊ฒ ๊ฐ์ต๋๋ค(https://github.com/tensorflow/tensorflow/issues/2255). ๊ฒ๋ค๊ฐ ์ ๋ชฉ์ ์๋ ์ฝ๊ฐ ๋ฌด๊ฑฐ์ด ๊ฒ ๊ฐ์ต๋๋ค.
ํ์ํ ๋งํ ์์ด๋์ด๋ผ๋ฉด ๋๊ตฐ๊ฐ๊ฐ ์ด๊ฒ์ ์์ํ๋ ๋ฐฉ๋ฒ๊ณผ ์์น์ ๋ํด ๋ช ๋ง๋ ๋งํ ์ ์์ต๋๋ค.
๋ณต์กํ ํ
์์ ๋ํ ์ ํ์ ์ง์์ ์ถ๊ฐํ๋ ๋ฐ ๊ด์ฌ์ด ์๋ค๊ณ ์๊ฐํฉ๋๋ค. ๊ฐ์ฅ ์ข์ ๋ฐฉ๋ฒ์ torch/lib
์ C ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ํฌํฌํ๊ณ ์์
ํ๋ ๊ฒ์
๋๋ค. ์ด๊ฒ์ ๋ง์คํฐ์ ์ถฉ๋์ด ์์ด์ผ ํ๋ฏ๋ก ์ค๋ซ๋์ ์ด ์์
์ ์ํํ ์ ์์ต๋๋ค. ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ฌ์ฉ ๊ฐ๋ฅํ ์ํ๋ก ๋ง๋ค๋ฉด ๋ฐ์ธ๋ฉ ์์ฑ์ ์์ํ ์ ์์ผ๋ฉฐ ์ฌ๊ธฐ์์ ์ถฉ๋์ ํผํ๋ ๋ฐฉ๋ฒ์ ๋ํ ์ง์นจ์ ์ ๊ณตํ ์ ์์ต๋๋ค.
๋ณต์กํ ์ ํ์ ์ปดํ์ผํ๋ TH๊ฐ ์์ต๋๋ค. Python ํตํฉ์ ์ํด ๋ฌด์์ ์ถ๊ฐํด์ผ ํฉ๋๊น?
@PhilippPelz ๋ https://github.com/facebook/ztorch/tree/master/lib/THZ ๋ฅผ ์๋ฏธํฉ๋๊น? ์๋๋ฉด ๋ณต์กํ ์ ํ์ ๊ฐ๋ฅํ๊ฒ ํ๋ ์์ฒด TH ํฌํฌ๋ฅผ ๊ตฌ์ถํ์ต๋๊น?
@killeent ๋ TH๊ฐ Python์ ๋ฐ์ธ๋ฉ๋๋ ๋ฐฉ๋ฒ์ ๋ํ ๋ช ๊ฐ์ง ๋ฉ๋ชจ๋ฅผ ๊ฐ์ง๊ณ ์์ผ๋ฉฐ ๊ณต์ ํ ์ ์์ต๋๋ค.
์ผ๋ฐ์ ์ผ๋ก ๋ณต์กํ Tensor๋ฅผ ๊ฐ์ ธ์ค๋ ค๋ฉด ํ ์คํธ ๋ฑ์ด ์์ผ๋ฏ๋ก THZ๋ฅผ ์ ํธํฉ๋๋ค.
๋ณต์กํ Tensor๋ฅผ ์ํ CUDA ๋ฐฑ์๋๋ฅผ ๊ตฌ์ถํ๋ ๊ฒ์ ์๋นํ ํ๋ ์ผ์ด์ง๋ง ์์ง ์์๋ ํ์ง ์์์ต๋๋ค.
๋๋ ์ผ๋ง ์ ์ z-cutorch( https://github.com/PhilippPelz/z-cutorch )๋ฅผ ์์ฑํ์ต๋๋ค. CudaHalfTensor์ ๋ํ ๋ฆฌํฉํ ๋ง ์ด์ ์ ํฌํฌ ์คํ ์ปท์ ๋๋ค(์์ง ํ๋์จ์ด๊ฐ ์์).
์ด๊ฒ์ ํ๋ฅญํฉ๋๋ค. ๊ทธ๋ฐ ๋ฐฉํฅ์ผ๋ก ์ด๋ฏธ ๋ง์ ๋ ธ๋ ฅ์ ํ๋ค๊ณ ์๊ฐํฉ๋๋ค :)
@soumith ๋๋ ๋ณต์กํ ์ ํ์ผ๋ก TH๋ฅผ ํฌํฌํ์ต๋๋ค. ๊ธฐ๋ณธ์ ์ผ๋ก THGenerateComplexTypes.h + ์ถ๊ฐ๋ BLAS + LAPACK ๋ฃจํด ๋๋จธ์ง๋ ๊ฑฐ์ ๋ฌด๋ฃ์์ต๋๋ค. THZ์ ์ด๋ค ๋ถ๋ถ์ด ํธํ๋๋์ง ํ์ธํ ๋ค์ ๋ณต์ฌํ์ฌ ๋ถ์ฌ๋ฃ๋ ๊ฒ๋ณด๋ค ํจ์ฌ ๋ ์์ํด ๋ณด์์ต๋๋ค.
์ ๋ ์ง๊ธ THPP๋ฅผ ์ปดํ์ผํ๋ ๋ฐ ์ด๋ ค์์ ๊ฒช๊ณ ์์ผ๋ฉฐ ๋ค์๊ณผ ๊ฐ์ ์ปดํ์ผ๋ฌ ๋ฉ์์ง๋ฅผ ํ์ ํ๊ณ ์์ต๋๋ค.
/home/philipp/projects/pytorch/torch/lib/tmp_install/include/TH/generic/THBlas.h:6:40: ์ค๋ฅ: '*' ํ ํฐ ์์ ',' ๋๋ '...'๊ฐ ์์ด์ผ ํฉ๋๋ค.
TH_API ๋ฌดํจ THBlas_(swap)(long n, real *, long incx, real *, long incy);
์กฐ๊ธ ๊น๋ค๋กญ์ต๋๋ค.
ํ์ด์ฌ ํตํฉ์ ํ์ฑํํ๋ ๋ฐฉ๋ฒ์ ๋ํ ๋์์ ์ฃผ์๋ฉด ๊ฐ์ฌํ๊ฒ ์ต๋๋ค. CUDA ๋ฐฑ์๋๋ ๋๋ถ๋ถ z-cutorch์์ ๋ณต์ฌ ๋ถ์ฌ๋ฃ๊ธฐ์ฌ์ผ ํฉ๋๋ค.
@PhilippPelz ๋ PyTorch ๋ฉ TH์ ๋ํ ๋ช ๊ฐ์ง ์ฐธ๊ณ ์ฌํญ์ ๋๋ค. https://gist.github.com/killeent/4675635b40b61a45cac2f95a285ce3c0
@killent ๊ฐ์ฌํฉ๋๋ค, ๋งค์ฐ ์ ์ฉํด ๋ณด์ ๋๋ค. lib/build_all.sh๊ฐ ์ด์ ์ปดํ์ผ ์ค์ ๋๋ค. csrc ๋๋ ํ ๋ฆฌ๋ฅผ ๋ณผ ์ ์์ ๊ฒ ๊ฐ์ต๋๋ค.
์ด์ ๋ค์์ด ์คํ๋ฉ๋๋ค.
ํ ์น๋ฅผ th๋ก ๊ฐ์ ธ์ค๊ธฐ
numpy๋ฅผ np๋ก ๊ฐ์ ธ์ค๊ธฐ
a = np.array([1+1j,2+2j])
b = np.array([3+3j,4+4j])
ath = th.from_numpy(a)
bth = th.from_numpy(b)
ath_cuda = ath.cuda()
ath_cuda += bth.cuda()
ath = ath_cuda.cpu()
์ธ์(ath.numpy())
์์: [ 4.+4.j 6.+6.j]
๋๋ถ๋ถ์ ์ํ ํจ์์ ํจ๊ป
๋๋ ๋ค์ ์ฃผ์ ํธ์ ๊ธฐ๋ฅ๊ณผ fft๋ฅผ ์ถ๊ฐํ ๊ฒ์
๋๋ค. ์ด๊ฒ์ ๋ณํฉํ๊ธฐ ์ ์ ๋ชจ๋ ๊ฒ์ ๋ํ ํ
์คํธ๊ฐ ํ์ํ๋ค๊ณ ์๊ฐํฉ๋๋ค. ๋ณต์กํ ํ
์์ ๊ด์ฌ์ด ์๊ณ ํ
์คํธ ์์ฑ์ ๊ธฐ๊บผ์ด ๊ธฐ์ฌํ ์ ์๋ ์ฌ๋์ ์๊ณ ์๋ค๋ฉด ์ ๋ง ์ข์ ๊ฒ์
๋๋ค. ์ด ๋
ผ๋ฌธ์ด ๋ ์ค๋ฆ
๋๋ค: Deep Complex Networks , ์๋ง๋ ๊ทธ ์ฌ๋๋ค์ด ๊ด์ฌ์ ๊ฐ์ง ๊ฒ์
๋๋ค.
ํผ์์ ๋ชจ๋ ํ
์คํธ๋ฅผ ์์ฑํ ์๊ฐ์ด ์์ต๋๋ค.
@PhilippPelz ๊ทํ์ ์๊ฒฌ์ ๊ฐ์ฌ๋๋ฆฝ๋๋ค. ๊ทํ์ ๊ตฌํ์ ํ์ธํ๊ณ ์์ต๋๋ค. ๊ทธ๋ฆฌ๊ณ ๋จผ์ ger
๊ตฌํ์ ๋ํด ์ ๋ชจ๋ฅด๊ฒ ์ต๋๋ค. ์์ฑ ํค๋์์ GER๋ฅผ zger_
๋ฐ cger_
๋ก ์ ์ํ ๊ฒ์ฒ๋ผ ์ผ๋ถ ๋ณต์กํ blas ํจ์๋ THBlas.c์ ํฌํจ๋์ด ์์ง ์์ง๋ง ์ผ๋ฐ/THBlas.c์๋ cger_๊ฐ ์๋ blas ํจ์๊ฐ ์์ต๋๋ค. . ํ์ง๋ง gemv ๋ฐ ๊ธฐํ ๊ธฐ๋ฅ์ ์ฌ์ฉํ ์ ์์ต๋๋ค. ๊ทธ๋ฆฌ๊ณ IMO๋ .gitignore์ .gch๋ฅผ ์ถ๊ฐํด์ผ ํ ๊น์? ๋ชจ๋ ํ์ฅ์ ํฌํฌ๋ก ํธ์ํ์ต๋๊น? ๋จผ์ ๊ตฌํ์ ๋ฐ๋ผ ๋ง์คํฐ์๊ฒ ํ ์์ฒญ์ ํ ์ ์์ต๋๋ค.
๊ทธ๋ฆฌ๊ณ DOT
์ ๊ฒฝ์ฐ ๋ณต์กํ ๋ฒกํฐ์ ๊ฒฝ์ฐ ์ ์ ๋ํ dotc
๋ฃจํด์ด ๋ ์ผ๋ฐ์ ์ด๋ผ๊ณ ์๊ฐํฉ๋๋ค.
๊ทธ๋ฆฌ๊ณ ์, real
๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ด ๊ตฌํ์ ๋ ์ฌ์ธ ๊ฒ์ด๋ผ๋ฉด real
๊ฐ ์ค์ ๋ก ๋ณต์กํ ๋ ์ด์ํ ๋๋์ด ๋ค์์ต๋๋ค.
๊ทธ๋ฆฌ๊ณ ํ ์คํธ์ ๊ฒฝ์ฐ TH์ ๋ํ ์ด์ ํ ์คํธ๋ฅผ ๋ณด์ง ๋ชปํ์ต๋๋ค. ๊ทธ ํ ์คํธ๋ ์ด๋์ ์์ฑํด์ผ ํฉ๋๊น? ๋๋ ์ฐ๋ฆฌ๋ ์ฝ๊ฐ์ ํ์ด์ฌ ํ ์คํธ๋ฅผ ์์ฑํฉ๋๋ค.
์, ์ฃ์กํฉ๋๋ค. ํ์ํ ๋ชจ๋ ๊ฒ์ ํธ์ํ์ง ๋ชปํ ๊ฒ ๊ฐ์ต๋๋ค. ์์์ผ์ ๋ค์ ํ์ธํ๊ฒ ์ต๋๋ค. ์ผ๋ถ ์ ์ธ์ด ๋๋ฝ๋์์ต๋๋ค. zger์ cger
DOT์ ๊ฒฝ์ฐ cdotc ๋ฐ zdotc๋ฅผ ์ฌ์ฉํ๊ณ ์๋๋ฐ ๋๋ฝ๋ ๊ฒ ๊ฐ์ต๋๋ค. ๋ค์ ์ฃผ์ ์ ๋ฐ์ดํธํ๊ฒ ์ต๋๋ค.
์ค์ ๋ก ์ ํธํ๋ ์ด๋ฆ์ pytorch ์ ์ง ๊ด๋ฆฌ์์๊ฒ ํ์ธํ์ญ์์ค. ๋๋ ๋น์ ์ ๋ฒ์ ์ ๋ ์ข์ํ์ง๋ง ์์ง ๋ ธ๋ ฅ์ ๊ธฐ์ธ์ด์ง ์์์ต๋๋ค.
์, ํ์ด์ฌ์ ์ํ ๋ฌธ์ ๋ฅผ ํ ์คํธํฉ๋๋ค. compelx ๋ฒํธ ๊ฒ์ฌ๋ ํฌํจํ๋๋ก ๋๋ถ๋ถ์ ๊ธฐ๋ฅ์ ๋ํด ์ฝ๊ฒ ๋ณ๊ฒฝ๋์ด์ผ ํฉ๋๋ค.
๋ฉ์ง ๋น์ ๋ ์ด๊ฒ์ ๋ํด ์ฐพ๊ณ ์์ต๋๋ค!
์๊ฒ ์ต๋๋ค. ๋ช ๊ฐ์ง ๋ณ๊ฒฝ ์ฌํญ์ ํธ์ํ์ต๋๋ค. TH blas ๋ฃจํด์ ์ด์ ๋ณต์กํ
@PhilippPelz ๋ฐฉ๊ธ ๊ทํ์ ๋ฆฌํฌ์งํ ๋ฆฌ์ pull ์์ฒญ์ ํ์ต๋๋ค. ๋ณต์กํ ์ ํ ๋ ์ด์ด ๋ฐ ๊ธฐํ ์ฐ์ฐ์์ ๊ฒฝ์ฐ. ๋ง์ hermitian ์ฐ์ฐ์ด ์์ ์ ์์ต๋๋ค(๋ณต์กํ ์ ํ ๋ ์ด์ด์ ๊ฒฝ์ฐ bp์ ๊ฐ์). ํ ์์ ๊ธฐ๋ฅ์ ์ถ๊ฐํ ๊น์? THNN ๋ถ๋ถ์ ํ์ธํ์ จ๋์?
์, hermitian์ด ์ ์ฉํฉ๋๋ค. cuda fft๊ฐ ์ง๊ธ ์๋ ์ค์ ๋๋ค. cpu fft๋ numpy์์ ๋ํ๋ ์ ์์ต๋๋ค. ๋๋ ์์ง THNN์ด๋ THCUNN์ ๋ง์ง์ง ์์๋ค.
@PhilippPelz PR์ ๊ฐ๋จํ ์๋์๋ฅผ ์ถ๊ฐํ์ต๋๋ค. ๊ทธ๋ฆฌ๊ณ ๊ทธ๊ฒ์ ๊ฒํ ํ ์ ์์ต๋๋ค. ๋ฐ๋ผ์ ์ด๋ฌํ ๋ณ๊ฒฝ ์ฌํญ์ด ์ ํฉํ์ง ํ์ธํ๊ณ ๋ค์ ๋จ๊ณ๋ก ์ด๋ํ ์ ์์ต๋๋ค. ๊ฐ์ฌ ํด์! ์ถ์ . ์ผ๋ถ ํค๋๋ฅผ ๋์น ๊ฒ ๊ฐ์ต๋๋ค. ํด๋น ํค๋์ ๊ธฐํ ๊ฒฝ๊ณ ๋ ์์ ํฉ๋๋ค. ์ค์ ์ถ๋ ฅ์ด ์๋ ๋ณต์กํ ํจ์์ ๊ฒฝ์ฐ ๋ณต์์ ํ ์๊ฐ ์๋ ์ค์ ํ ์๋ฅผ ๋ฐํํด์ผ ํฉ๋๊น? ๋๋ ๋ณตํฉํ๊ณผ ์ค์ ํ ์ฌ์ด์ ๋ณต์ฌ ๋ฐฉ๋ฒ์ ๊ตฌํํ์ผ๋ ๊ฐ๋ฅํ๋ค.
๊ฒํ ํ ๋ชจ๋ ์ปค๋ฐ์ ๋ฆฌ๋ฒ ์ด์คํ๊ฒ ์ต๋๋ค.
@PhilippPelz ์๋
ํ์ธ์, ๊ตฌํํ THPP
๋ถ๋ถ์ ๋ํด ์๋นํ ํผ๋์ค๋ฝ์ต๋๋ค. Traits.hpp
์ ์ถ๋ ฅ์ ์์กดํ๋ ์ด์ ๋ ๋ฌด์์
๋๊น? cuda ์์ด ์ปดํ์ผํ ๋ ์ค๋ฅ๊ฐ ๋ฐ์ํฉ๋๋ค. ์ฒ๋ผ๋ง ์ฌ์ฉ์ด ๊ฐ๋ฅํ๊ฐ์?Traits.hpp
? ๋๋ ๊ทธ๊ฒ์ ์์๋ด์ง ๋ชปํ๋ค. ๋ช ๊ฐ์ง ๋จ์๋ฅผ ์ ๊ณตํ ์ ์์ต๋๊น?
@Roger-luo ์, ๋ค๋ฅธ ๊ณณ์์๋ ๋ฌธ์ ๊ฐ ์์ต๋๋ค. ์ฐ๋ฆฌ๊ฐ ์ฌ์ฉํ๋ ๋ณตํฉ ์ ํ์ complex.h ๋๋ std::complex ์ค ํ๋์ฌ์ผ ํฉ๋๋ค. THPP๋ C++ ๋ํผ์ด๋ฏ๋ก std::complex๊ฐ ๋ ์ ์ ํ ์ ์์ต๋๋ค. ๋ฐ๊ฟ์ฃผ์ค ์ ์๋์?
Thrust๋ cffi ํ์ฅ์ ๊ตฌ์ถํ ๋์ ๋๊ฐ์ ์ด์ ๋ก ๋ฌธ์ ๋ฅผ ์ผ์ผํค๊ธฐ๋ ํฉ๋๋ค. ์ง๊ธ์ ํด๊ฒฐ ๋ฐฉ๋ฒ์ ์ํํ๊ณ ์์ง๋ง ์ ์ ํ ๋ฐฉ๋ฒ์ THC์์ ๋ณตํฉ ์ ํ์ cuFloatComplex/cuDoubleComplex๋ก ๋ณ๊ฒฝํ๋ ๊ฒ์ ๋๋ค. cffi ์ปดํ์ผ๋ฌ๊ฐ ๋ถํํ์ง ์๋๋ก. ์ง๊ธ ๋ฐ๋ก ์ฐ๊ตฌ๋ฅผ ์งํํ๊ณ ์ถ์๋ฐ ์๊ฐ์ด ๋๋ฌด ๋ง์ด ๊ฑธ๋ฆฌ๋ค์ :( . ์๊ฐ์ด ๋์๋ฉด ํด์ฃผ์ธ์.
๋ํ ์ฌ์ฉ์ ์ ์ ์ปค๋ ํธ์ถ๋ก cffi ํ์ฅ์ ๊ตฌ์ถํ๋ ๊ฒ์ ๋งค์ฐ ๋ฒ๊ฑฐ๋ก์ด ์ผ์ ๋๋ค. ์๋ํ๋ฉด ํญ์ nvcc๋ก ์ปดํ์ผ๋ ์ถ๊ฐ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์์ฑํด์ผ ํ๊ธฐ ๋๋ฌธ์ ๋๋ค. ์ด ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ cffi ๋ํผ์ ์ฐ๊ฒฐ๋ฉ๋๋ค. ๋ค๋ฅธ ๋ฐฉ๋ฒ์ ์๋ ๊ฒ ๊ฐ์์. ABI ๋ชจ๋์์ cffi๋ฅผ ์ฌ์ฉํ ์ ์์ง๋ง ์น ์ฌ์ดํธ์๋ "๋์ API ๋ชจ๋๋ ๋์ ํจ์๋ฅผ ์ง์ ํธ์ถํ๋ CPython C ๋ํผ๋ฅผ ์ปดํ์ผํฉ๋๋ค. ์๋์ ์ผ๋ก ํจ์ฌ ๋น ๋ฅด๋ฉฐ libffi๋ณด๋ค ๋ ์ ์๋ํฉ๋๋ค."๋ผ๊ณ ๋งํฉ๋๋ค.
@PhilippPelz ์๋ง๋ reinterpret_cast
์ด ํด๊ฒฐ์ฑ
์ด ๋ ์ ์์ต๋๊น? cuComplex
๋ก ๋ณ๊ฒฝํ๊ณ THPP์์ reinterpret_cast
๋ฅผ ์ฌ์ฉํด์ผ ํ๋ค๊ณ ์๊ฐํฉ๋๋ค. ๋ด๊ฐ ๋จผ์ ํด๋ณผ๊ฒ...
์, cuda๊ฐ ์ค์น๋์ง ์์ ์ํ์์๋ THPP๋ฅผ ๋น๋ํ๋ ค๋ฉด reinterpret_cast ์ธ์๋ ๋ค๋ฅธ ๋ฐฉ๋ฒ์ด ์๋ค๊ณ ์๊ฐํฉ๋๋ค.
@PhilippPelz ๋์๋๋ฆฌ๊ณ ์ถ์ต๋๋ค. ํ ์ผ ๋ชฉ๋ก์ด ์ด๋์ ์์ต๋๊น?
๋ณตํฉ ์ ํ์ ๋ํด THNN ๋ฐ THCUNN์ ํ์ฑํํด์ผ ํฉ๋๋ค. @roger-luo์ ํ๋ ฅํ ์ ์์ต๋๊น? ๋ํ ๋ง์คํฐ์์ ํตํฉ์ ๋ชฉํ๋ก ํ๋ค๋ฉด ๋ชจ๋ ๋ณต์กํ ๋ฉ์๋์ ๋ํด ๋จ์ ํ ์คํธ๋ฅผ ์์ฑํด์ผ ํฉ๋๋ค.
@elbamos THNN
์ ๋๋ถ๋ถ์ ์์
์ ์กด์ฌํ๋ ๊ฐ ๋ ์ด์ด์ ๋ํด ์๋กญ๊ณ ๋ณต์กํ ์ญ์ ํ ๋ฐฉ๋ฒ์ ๊ตฌํํ๋ ๊ฒ์
๋๋ค. Philipp์ ํฌํฌ์ WIP PR์ด ์์ต๋๋ค. ๋๋ ๋ช ๊ฐ์ง ์ฐธ์กฐ๋ฅผ ๋์ดํ์ต๋๋ค.
@apaszke @sumith @PhilippPelz ๊ทธ๋ฆฌ๊ณ ๋ ๊ฐ์ง ์ง๋ฌธ์ด ์์ต๋๋ค.
THS
์ ๋ค๋ฅธ GenerateXXXTypes.h
ํ์ผ์ด ์๋ ์ด์ ๋ฅผ ์๋ ์ฌ๋์ด ์์ต๋๊น? TH
์ ์๋ ๊ฒ๊ณผ ๋์ผํ๊ฒ ๋ณด์
๋๋ค.
byte_order.cpp
์ ๋ค์ ์ฝ๋๋ ๋ฌด์์
๋๊น?
void THP_decodeFloatBuffer(float* dst, const uint8_t* src, THPByteOrder order, size_t len)
{
for (size_t i = 0; i < len; i++) {
union { uint32_t x; float f; };
x = (order == THP_BIG_ENDIAN ? decodeUInt32BE(src) : decodeUInt32LE(src));
dst[i] = f;
src += sizeof(float);
}
}
void THP_decodeDoubleBuffer(double* dst, const uint8_t* src, THPByteOrder order, size_t len)
{
for (size_t i = 0; i < len; i++) {
union { uint64_t x; double d; };
x = (order == THP_BIG_ENDIAN ? decodeUInt64BE(src) : decodeUInt64LE(src));
dst[i] = d;
src += sizeof(double);
}
}
๊ด๋ จ ๋ณตํฉ ๋ฒ์ ๊ตฌํ์ ๋ํ ์ ์ ์ฌํญ์ด ์์ต๋๊น? ๋ค์ ๊ตฌํ์ด ์ฌ๋ฐ๋ฅธ์ง ํ์คํ์ง ์์ต๋๋ค...
void THP_decodeZFloatBuffer(std::complex<float>* dst, const uint8_t* src, THPByteOrder order, size_t len)
{
for (size_t i = 0; i < len; i++) {
union { uint64_t x; std::complex<float> cf;};
x = (order == THP_BIG_ENDIAN ? decodeUInt64BE(src) : decodeUInt64LE(src));
dst[i] = cf;
src += sizeof(std::complex<float>);
}
}
void THP_decodeDoubleBuffer(std::complex<double>* dst, const uint8_t* src, THPByteOrder order, size_t len)
{
for (size_t i = 0; i < len; i++) {
union { uint128_t x; std::complex<double> df;};
x = (order == THP_BIG_ENDIAN ? decodeUInt128BE(src) : decodeUInt128LE(src));
dst[i] = df;
src += sizeof(std::complex<double>);
}
}
์ด์ decodeUInt128XE
๋ ๋ค์๊ณผ ๊ฐ์ด ์ ์ธ๋ฉ๋๋ค.
static inline uint128_t decodeUInt128LE(const uint8_t *data) {
return (((uint128_t)data[ 0])<< 0) | (((uint128_t)data[ 1])<< 8)|
(((uint128_t)data[ 2])<< 16) | (((uint128_t)data[ 3])<< 24)|
(((uint128_t)data[ 4])<< 32) | (((uint128_t)data[ 5])<< 40)|
(((uint128_t)data[ 6])<< 48) | (((uint128_t)data[ 7])<< 56)|
(((uint128_t)data[ 8])<< 64) | (((uint128_t)data[ 9])<< 72)|
(((uint128_t)data[10])<< 80) | (((uint128_t)data[11])<< 88)|
(((uint128_t)data[12])<< 96) | (((uint128_t)data[13])<<104)|
(((uint128_t)data[14])<<112) | (((uint128_t)data[15])<<120);
}
static inline uint128_t decodeUInt128BE(const uint8_t *data) {
return (((uint128_t)data[15])<< 0) | (((uint128_t)data[14])<< 8)|
(((uint128_t)data[13])<< 16) | (((uint128_t)data[12])<< 24)|
(((uint128_t)data[11])<< 32) | (((uint128_t)data[10])<< 40)|
(((uint128_t)data[ 9])<< 48) | (((uint128_t)data[ 8])<< 56)|
(((uint128_t)data[ 7])<< 64) | (((uint128_t)data[ 6])<< 72)|
(((uint128_t)data[ 5])<< 80) | (((uint128_t)data[ 4])<< 88)|
(((uint128_t)data[ 3])<< 96) | (((uint128_t)data[ 2])<<104)|
(((uint128_t)data[ 1])<<112) | (((uint128_t)data[ 0])<<120);
}
์ ๋ ํ์ฌ THPP
T _Complex
std::complex<T>
๋ฅผ ์ฌ์ฉํ๊ณ ์์ต๋๋ค. ์ด๊ฒ์ด ์์ง Python์์ ์ฌ์ฉํ ์ ์๋์ง ํ์คํ์ง ์์ต๋๋ค. ๋๋ C ์ ํ T _Complex
python์ ์ฌ์ฉํ ์ ์์ต๋๋ค. ๋ฐ๋ผ์ ์ฌ๊ธฐ์ dst ์ ํ์ std::complex<T>
์
๋๋ค.
๊ทธ๋ฆฌ๊ณ ์ด ๊ตฌํ์ด ๋ง๋ค๋ฉด ์๋ง๋ https://github.com/calccrypto/uint128_t ์ ๊ฐ์ uint128_t
๊ตฌํ์ด ํ์ํ ๊ฒ์
๋๋ค. ๋ชจ๋ ์ปดํ์ผ๋ฌ๊ฐ 128๋นํธ ์ ์๋ฅผ ์ง์ํ๋ ๊ฒ์ ์๋ ๊ฒ ๊ฐ๊ธฐ ๋๋ฌธ์(gcc์๋ int128_t ๋ฐ uint128_t๊ฐ ์์ต๋๋ค).
@PhilippPelz ๊ทํ์ ํฌํฌ์ ๋ฌธ์ ๊ฐ ํ์ฑํ๋์ด ์์ง ์์ ๊ฒ์ผ๋ก ๋ํ๋ฌ์ต๋๋ค. ๊ทํ์ ํ๋ก์ ํธ ์ํ๋ ์ด๋ป์ต๋๊น? ๋ณต์กํ ํ ์๊ฐ pytorch์ ๋ก๋๋งต์ ์๋ค๋ ๊ฒ์ด ์กฐ๊ธ ์์ฝ์ต๋๋ค.
@el3ment ๋๋ CPU์ ๋ณต์กํ ๋ฐฑ์๋๋ฅผ ์ถ๊ฐํ์ต๋๋ค https://github.com/pytorch/pytorch/pull/4899 ํ์ง๋ง ์์ง ๊ฒํ ๋์ง ์์์ต๋๋ค... ๊ทธ๋ฆฌ๊ณ ๋ด PR์ ๋ํ ์๊ฒฌ์ ๋ฐ์ง ๋ชปํ๊ธฐ ๋๋ฌธ์ ์ฌ์ฉํ๊ธฐ ์์ํ์ต๋๋ค. Julia ํ๋ก๊ทธ๋๋ฐ ์ธ์ด๋ ์ต๊ทผ...
์ง๋๋ฒ์ @PhilippPelz ์๊ฒ ์ด๋ฉ์ผ์ ๋ณด๋์ต๋๋ค. ๊ทธ์ repo๋ ์์ง v0.1 ๋ฏธ๋ง์ด๊ณ ๊ทธ๋ 9์๊น์ง ๋ ผ๋ฌธ์ ์ํด ๋ฐ์ฉ๋๊น? ๊ทธ๋ฆฌ๊ณ v0.3์ ์๋ก์ด CUDA ๋ฐฑ์๋ ์์ ์ ํ๊ณ ์์์ง๋ง ์ด ๋ชจ๋ ๋ฐ์ธ๋ฉ์ ํผ์ ์๋ฃํ ์๊ฐ์ด ์์ต๋๋ค. map/reduce ํจ์๋ ์ผ๋ถ ์ต์ ํ๊ฐ ์๋ v0.1๊ณผ ๋ค๋ฅด์ง๋ง ๋ณต์์๋ฅผ ์ง์ํ๋๋ก ๊ฐ๋จํ๊ฒ ๋ณํํ ์ ์์ต๋๋ค. ๋์์ฃผ์ค ๋ถ์ด ๊ณ์๋ค๋ฉด ๊ธฐ์ ํ ๋ฐ...
๊ธฐ๊บผ์ด ๋์๋๋ฆฌ๊ฒ ์ต๋๋ค.
2018๋ 4์ 10์ผ ์คํ 10์ 52๋ถ์ Rogerluo [email protected] ์ด ๋ค์๊ณผ ๊ฐ์ด ์ผ์ต๋๋ค.
@el3ment CPU #4899์ ๋ณต์กํ ๋ฐฑ์๋๋ฅผ ์ถ๊ฐํ์ต๋๋ค.
์ง๋๋ฒ์ @PhilippPelz ์๊ฒ ์ด๋ฉ์ผ์ ๋ณด๋์ต๋๋ค. ๊ทธ์ repo๋ ์์ง v0.1 ๋ฏธ๋ง์ด๊ณ ๊ทธ๋ 9์๊น์ง ๋ ผ๋ฌธ์ ์ํด ๋ฐ์ฉ๋๊น? ๊ทธ๋ฆฌ๊ณ v0.3์ ์๋ก์ด CUDA ๋ฐฑ์๋ ์์ ์ ํ๊ณ ์์์ง๋ง ์ด ๋ชจ๋ ๋ฐ์ธ๋ฉ์ ํผ์ ์๋ฃํ ์๊ฐ์ด ์์ต๋๋ค. map/reduce ํจ์๋ ์ผ๋ถ ์ต์ ํ๊ฐ ์๋ v0.1๊ณผ ๋ค๋ฅด์ง๋ง ๋ณต์์๋ฅผ ์ง์ํ๋๋ก ๊ฐ๋จํ๊ฒ ๋ณํํ ์ ์์ต๋๋ค. ๋์์ฃผ์ค ๋ถ์ด ๊ณ์๋ค๋ฉด ๊ธฐ์ ํ ๋ฐ...
โ
๋น์ ์ด ์ธ๊ธ๋์๊ธฐ ๋๋ฌธ์ ์ด๊ฒ์ ๋ฐ๋ ๊ฒ์ ๋๋ค.
์ด ์ด๋ฉ์ผ์ ์ง์ ๋ต์ฅํ๊ฑฐ๋ GitHub์์ ๋ณด๊ฑฐ๋ ์ค๋ ๋๋ฅผ ์์๊ฑฐํ์ธ์.
@elbamos ์ฟจ, pytorch ํ์ ๋ณ๋์ ๊ตฌํ์ ์ ํธํ๋ ๊ฒ ๊ฐ์ต๋๋ค. ๋์ค์ ๋ค๋ฅธ ๋ถ๋ถ์ ๋ํด ํฌํฌ๋ฅผ ์ ๋ฐ์ดํธํ๊ฒ ์ต๋๋ค. ๊ทธ๋ฌ๋ ๋๋ ์ ๋ง๋ก ์ด๊ฒ์ ๋ํ ์๊ฐ์ด ์๊ณ , ์ด๊ฒ์ด pytorch์ ํฐ ํ์ฅ์ด ๋ ๊ฒ์ด๊ธฐ ๋๋ฌธ์ pytorch ํ์์ ๊ณํ์ด ์์ ๋ ์์ ์ ์์ํด์ผ ํ๋ค๊ณ ์๊ฐํฉ๋๋ค.
์๋ ํ์ธ์, ์ ์ฝ๋๋ v0.2 ์ดํ์ ์ปค๋ฐ๋์์ต๋๋ค.
๋๋ ๋ชจ๋ ํ ์ ์ฝ๋๋ฅผ Aten์ผ๋ก ์ฎ๊ธฐ๋ ๊ฝค ํฐ ๋ฆฌํฉํฐ๊ฐ ์๋ ๊ฒ์ ๋ณด์์ต๋๋ค. ์ด๊ฒ์ ๋ด ํฌํฌ๋ฅผ ํ์ฌ ๋ฒ์ ์ ์ฝ๊ฒ ๋ณํฉํ ์ ์์ผ๋ฉฐ ๋ ๋ง์ ์์ ์ด ํฌํจ๋ ์ ์์์ ์๋ฏธํฉ๋๋ค.
์์ง ๋ฐ์ฌ ๊ณผ์ ์ ์์ฑ ์ค์ด์ง๋ง Variable๊ณผ Tensor์ ๋ณํฉ์ด ๋ฆด๋ฆฌ์ค๋ ๋๊น์ง ์ด์จ๋ 0.4๋ฅผ ๊ธฐ๋ค๋ฆด ๊ณํ์ด์์ต๋๋ค. ๋ ์ผ์ฐ ๋ฆฌํฉํ ๋ง์ ํ๋ฉด ๋ฐ๋ผ์ก๊ธฐ ์ํด ๋๋ฌด ๋ง์ ๋ฆฌํฉํ ๋ง์ด ์งํ๋์ง ์์๊น ๊ฑฑ์ ๋ฉ๋๋ค.
@elbamos ์ํ๋ ๊ฒฝ์ฐ ๋ด ํฌํฌ์ ํญ๋ชฉ์ ์ถ๊ฐํ ์ ์์ต๋๋ค. ๋ณํฉํ๊ฒ ์ต๋๋ค. ๋์ ์ํ ์ค์ธ ํ๋ก์ ํธ์ ํ์ํ ๊ฒ์ ๊ตฌํํ๊ฒ ์ต๋๋ค. TH(CU)NN์ ๊ฝค ํฐ ์ธํฐํ์ด์ค์ด๋ฉฐ ์์ฒญ๋ ์ํฌ๋ก๋๊ฐ ๋ ๊ฒ์ ๋๋ค.
@el3ment ๋ค๋ฅธ ์ฌ๋์ ๋ฌธ์ ๋ฅผ ์ฒ๋ฆฌํ ์๊ฐ์ด ์์ต๋๋ค. ๊ทธ๋ฌ๋ ๊ฑฐ๊ธฐ์ ์๋ ๊ฒ์ ๊ตฌํํด์ผ ํ๋ ๊ฒฝ์ฐ ํญ๋ชฉ์ ๋ณํฉํ ๊ฒ์ ๋๋ค.
๊ธฐ๋ณธ์ ์ผ๋ก ๋ณต์์๋ก ์๋ํ๋ ๊ฒ์ ์ํ๋ค๋ฉด tensorflow๋ฅผ ์ ๊ทน ๊ถ์ฅํฉ๋๋ค.
์ปดํ์ผ ๋ฌธ์ ๊ฐ ์๋ ๊ฒฝ์ฐ์๋ ๋์์ ๋๋ฆฌ๊ฒ ์ต๋๋ค.
ํฌ์คํธ๋ฅ์ ๊ณ์ํ๋ฉด ์ด ๋ชจ๋ ๊ฒ์ ์ด๋ ์์ ์์ ํ์ฌ ๋ฒ์ ์ผ๋ก ์ด์ํ ๊ฒ์ ๋๋ค. ํ์ด์ค๋ถ์ด ์ด๋ฅผ ์ง์ํ์ง ์๋๋ค๋ ์ ์ ์ ๋ง ์ํ๊น์ต๋๋ค. :((
@PhilippPelz ๋์ํฉ๋๋ค. ์ ๋ง ์ฌํ๊ณ ์ค์ ๋ก tensorflow๊ฐ ์์ ๋ฌผ๋ฆฌํ์ ๋ชจ๋ ์ฐ์ฐ์๋ฅผ ์ง์ํ์ง ์์ต๋๋ค... Julia๋ฅผ ์ฌ์ฉํ๊ธฐ ์์ํ๊ณ ํ์ด์ฌ์ ํฌ๊ธฐํ์ต๋๋ค.
@Roger-luo ํฅ๋ฏธ๋กญ์ต๋๋ค. ํน์ julia ํจํค์ง๋ฅผ ์ฌ์ฉํ๊ณ ์์ต๋๊น ์๋๋ฉด ๋ชจ๋ ์์ฒด ์์ฑ ์ฝ๋์ ๋๊น?
@PhilippPelz ์ ๋ Julia์์ ์์ ๋ค์ฒด ํดํท์ ๊ฐ๋ฐ ์ค์ ๋๋ค(PyTorch PR ์ดํ). ์ฌ๊ธฐ์๋ ๋ณต์กํ ์ ๊ฒฝ๋ง์ ๋ํ ์ด์ ๋ ผ๋ฌธ์ ๊ธฐ๋ฐ์ผ๋ก ํ ๋ณต์กํ/์ค์ ์ ๊ฒฝ๋ง ๊ตฌํ์ด ํฌํจ๋์ด ์์ผ๋ฉฐ Julia์ ๋ฉํํ๋ก๊ทธ๋๋ฐ. ๋๋ ํ์ฌ ๊ทธ๊ฒ์ QMTK.jl ์ ๋ฃ์๊ณ , ์ฌ์ ํ ์งํ ์ค์ด๋ฉฐ ๋ด๊ฐ ์ํ๋ ๋ชจ๋ ๊ฒ์ ๋๋ด์ง ๋ชปํ์ต๋๋ค. PyTorch๋ ์ ๋ง ๋ง์ ์๊ฐ์ ์ฃผ์ง๋ง ๋ณต์กํ ์ง์์ ๋ํด ์ ๋ง ์ฃ์กํฉ๋๋ค...
ํ์ง๋ง ์์ผ๋ก ์ ๊ฒฝ๋ง ๋จ์ผ ํจํค์ง๋ก ๋ถ๋ฆฌํ ๊ณํ์ด ์์ต๋๋ค(์ง๊ธ์ ์ฌ๋ฌ ์ ์ฅ์๋ฅผ ์ ์งํ๊ณ ์ถ์ง ์์ต๋๋ค). ๊ทธ๋ฆฌ๊ณ CAS(Institute of Physics, CAS)์ ๊ฐ๋ฐ์ ๋ ๋ง์ ์ฌ๋๋ค์ด ์ฐธ์ฌํ ๊ฒ์ ๋๋ค. ์ฒซ ๋ฒ์งธ ํ๊ทธ๊ฐ ์ง์ ๋ ๋ฒ์ (๋ช ์ฃผ ๋ด) ์ดํ์ PR์ ์๋ฝํ๊ฒ ์ต๋๋ค.
๊ฐ๋ฐ์ ๊ด์ฌ์ด ์๋ค๋ฉด ๋ณผ ์ ์์ต๋๋ค.
PyTorch ํ์ด ๋ฏธ๋์ ๋ณต์กํ ์ง์์ ๋ํ ๊ณํ์ ๊ฐ์ง๊ณ ์๋ค๋ฉด ๊ธฐ๊บผ์ด ๋์ธ ๊ฒ์ ๋๋ค.
์ฟจ, ์ง์ผ๋ณผ๊ฒ!
์๋ ํ์ธ์ ์ฌ๋ฌ๋ถ, ์ด ๋ฌธ์ ๊ฐ ์์๋ ์ดํ๋ก ์ด์ ๋ํด ๋ต๋ณ์ ๋๋ฆฌ์ง ๋ชปํด ์ฃ์กํฉ๋๋ค.
๋ค์์ ๋ ๊ฐ์ง ์ฌ์ค์ ๋๋ค.
์ด ๋ฌธ์ ๊ฐ 2017๋ ์ ์ด๋ฆฐ ์ดํ๋ก ๋ณต์กํ ์ง์ ๊ตฌํ์ ์ข ๋ ๊ฐ๋จํ๊ฒ ๋ง๋ค ์ ์๋ ๋ช ๊ฐ์ง ์ค์ํ ์ฌํญ์ด ๋ณ๊ฒฝ๋์์ต๋๋ค. ์ฒซ ๋ฒ์งธ๋ ์ด์ ํ ์๋ฅผ ์กฐ์ํ๊ธฐ ์ํ ์ธ์ฒด๊ณตํ์ C++ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ธ ATen์ด ์๋ค๋ ๊ฒ์ ๋๋ค. ์ด๊ฒ์ TH/THC ์ฝ๋์ ๊ฑฐ๋ํ ๋ถ๋ถ์ ๋ณต์ฌํ์ฌ ๋ถ์ฌ๋ฃ์ ํ์ ๊ฐ ์๊ณ ๋ชจ๋ ์๋ refcounting์ ์ฌ๋ฐ๋ฅด๊ฒ ํ์ผ๋ฉด ํ๋ ๊ฒ์ ์๋ฏธํฉ๋๋ค. C++ ์ฝ๋๋ฅผ ๋ง์น Python์ธ ๊ฒ์ฒ๋ผ ์์ฑํ ์ ์๊ณ ๋น ๋ฅด๊ฒ ์คํํ ์ ์์ต๋๋ค. ๋ ๋ฒ์งธ๋ ์ฐ๋ฆฌ๊ฐ C10์ด๋ผ๊ณ ํ๋ ์ ๋ฒ์ ์ Aten์ ๊ฐ๋ฐ ์ค์ด๋ผ๋ ๊ฒ์ ๋๋ค. ์ด ๋ฒ์ ์ Aten(๋ซํ ๊ฒ)๋ณด๋ค ๊ฐ๋ฐฉํ ๋ฐฑ์๋๋ฅผ ๊ฐ๋ ๊ฒ์ ๋ํด ํจ์ฌ ๋ ์ง์งํฉ๋๋ค. t๋ ์ค์ ๋ก PyTorch๋ฅผ ๋ถ๊ธฐํ์ฌ ์ฝ๋์ ์ ๋๋ ํ ๋ฆฌ๋ฅผ ์ถ๊ฐํ๋ ๊ฒ์ ์๋ฐํฉ๋๋ค.
๋ฐ๋ผ์ @Roger-luo ๋ฐ @PhilippPelz , ์ฐ๋ฆฌ๋ ๋ณต์กํ ๋ฐฑ์๋๋ฅผ ํ์ค๋ก ๋ง๋๋ ๋ฐ ์ฌ๋ฌ๋ถ์ ๋์์ ๋ฐ๊ณ ์ถ์ง๋ง ๋ฏธ๋์๋ ์ด๋ฅผ ์ง์ ๊ฐ๋ฅํ๊ฒ ์ ์งํ๋ ๋ฐ ๋์์ด ๋๋ ๋ฐฉ๋ฒ์ ์ฐพ๊ณ ์ถ์ต๋๋ค. ๋น์ ์ ์๊ฐ์ ์๋ ค์ฃผ์ธ์.
@ezyang ์ธ๋ ฅ์ด ๋ถ์กฑํ๋ฉด ์์ผ๋ก ๋ณต์กํ ํ ์ ๋ถ๋ถ์ ์ ์งํ๋ ค๊ณ ๋ ธ๋ ฅํ ์ ์์ต๋๋ค. ๋๋ ๋ฐฉ๊ธ ๋ฐ์ฌ ํ์๋ฅผ ์์ํ์ต๋๋ค. ์ ์ด๋ ์ต๊ทผ ๋ช ๋ . ํ์ง๋ง pytorch ํ์ ํผ๋๋ฐฑ ์์ด๋ ๊ณ์ ๊ธฐ์ฌํ ์ ์์ต๋๋ค. ์ด ํฐ ํ์ฅ์ ๋ํ ๋ก๋๋งต ์ด ์์ด์ผ ํ๋ค๊ณ ์๊ฐํฉ๋๋ค. ๊ทธ๋ฆฌ๊ณ ๋ณต์กํ ์ง์์ ์ํํ๊ฒ ์ถ๊ฐํ ์ ์์ผ๋ฏ๋ก ๊ทํ์ ์ฌ๋๋ค์ด ๋๊ท๋ชจ PR์ ๊ฒํ ํ ํ์๊ฐ ์๊ณ ๋ง์คํฐ ๋ธ๋์น๋ฅผ ์ถ์ ํ๋ ๊ฐ๋ฐ์์ ๋ ธ๋ ฅ์ ์ฝ๊ฒ ํ ์ ์์ต๋๋ค.
์ฒซ์งธ, ๋ณต์กํ ์ง์์ ๋ํ ์ฃผ์ ๋ฌธ์ ๋ CUDA ๋ถ๋ถ์ด ๋ ๊ฒ์ด๋ผ๊ณ ์๊ฐํฉ๋๋ค. Aten์ด๋ ๋ค๋ฅธ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ก CPU ๋ถ๋ถ์ ์ง์ํ๋ ๊ฒ์ ๋งค์ฐ ์ฝ์ต๋๋ค. ํผ๋๋ฐฑ์ด ์์ผ๋ฉด ๋ฉฐ์น ๋ง์ CPU ๋ถ๋ถ์ ๋ค์ ์์ฑํ ์ ์์ต๋๋ค. CUDA ๋ถ๋ถ์ ๋ํด ์ฐ๋ คํ ์ ์๋ ๋ช ๊ฐ์ง ๋ฌธ์ ๊ฐ ์์ผ๋ฉฐ ์ด๊ฒ์ด ๋ ๊ฐ์ง ๋ค๋ฅธ ์ ๊ทผ ๋ฐฉ์์ผ๋ก ์ด์ด์ง ์ ์๋ค๊ณ ์๊ฐํฉ๋๋ค.
float2
๋ฑ์ ์ฌ์ฉํ์ฌ CUDA ๋ถ๋ถ์์ cuComplex
์ ๊ฐ์ ๋จ์ผ ๋ณตํฉ ๊ฐ์ ์๋ฎฌ๋ ์ด์
ํฉ๋๋ค.FloatTensor
๋ฐ DoubleTensor
๋ฅผ ์ฌ์ฉํ์ฌ Aten์ C++ ๋ถ๋ถ์์ ๋ณต์กํ ํ
์๋ฅผ ์๋ฎฌ๋ ์ด์
ํฉ๋๋ค.๋ ๋ฒ์งธ ์ ๊ทผ ๋ฐฉ์์ ์ด์ ๋ THC
์์ pytorch๊ฐ ๋ช ๊ฐ์ง ํธ๋ฆญ์ ์ฌ์ฉํ์ฌ ๋งต/๋ฆฌ๋์ค ์์
์ ๊ฐ์ํํ๊ณ cuComplex
๊ฐ ์ค์ ๋ก float2
์ด๊ธฐ ๋๋ฌธ์ ์ฌ์ํ cuComplex
์ ์ ํฉํ์ง ์๊ธฐ ๋๋ฌธ์
๋๋ค. float2
ํ์ง๋ง __shfl_xxx
ํจ์๋ ๊ธฐ๋ณธ์ ์ผ๋ก float2
๋ฅผ ์ง์ํ์ง ์์ต๋๋ค. ํ์ฌ๋ก์๋ float2
์ ๋ํด ์ด๋ฌํ ๊ธฐ๋ฅ์ ํจ์จ์ ์ผ๋ก ์๋ฎฌ๋ ์ด์
ํ๋ ๋ฐฉ๋ฒ์ ๋ชจ๋ฅด๊ฒ ์ต๋๋ค.
๋ ๋ฒ์งธ ์ ๊ทผ ๋ฐฉ์์ ์ด์ ํ๋์จ์ด์ ๋ํด ์ ๊ฒฝ ์ธ ํ์๊ฐ ์๊ธฐ ๋๋ฌธ์ ๋ ์ฌ์ธ ๊ฒ์ด๋ฉฐ ์ด์ ์ฅ์น์์ ์๋ก์ด ๋ณต์กํ ํ์ฅ ์์ ์ ํจ์ฌ ์ฝ๊ฒ ๋ง๋ค ์ ์์ต๋๋ค. ๊ทธ๋ฌ๋ ์ธ์ ๋ฉ๋ชจ๋ฆฌ ์ฃผ์๋ก ์ธํด ์ฝ๊ฐ์ ์ค๋ฒํค๋๊ฐ ๋ฐ์ํ ์ ์์ต๋๋ค.
๊ฒ๋ค๊ฐ, ๋๋ ๋ณต์์๋ฅผ Aten์ ํตํฉํ๋ ค๋ฉด ํ๋์จ์ด์์ ์ค์ ๋ก ๋์ผํ ๋ค ๊ฐ์ง ์ ํ์ ์ฒ๋ฆฌํด์ผ ํ ์๋ ์๋ค๋ ๊ฒ์ ๋ฐ๊ฒฌํ์ต๋๋ค. std::complex
, thrust::complex
, cuComplex
, float2
๋๋๋ก ์ํํ ์ ์์ต๋๋ค. (์ฌ์ค ์ ๋ ์๋
์ ์ด ๋ฌธ์ ๋ฅผ ๋ง๋ฌ๊ณ reinterpreter_cast
๊ฐ ํด๊ฒฐ์ฑ
์ด์์ต๋๋ค).
๋๋ ๊ฐ์ธ์ ์ผ๋ก ๋ชจ๋ ๊ฒ์ ๋ ๋ค์ดํฐ๋ธ๋ก ์ฐ๋ ๊ฒ์ ์ ํธํฉ๋๋ค.
๊ทธ๋ฆฌ๊ณ ์ฐ๋ฆฌ๋ ์๋ง๋ ์ผ์ ์ด๋ ๋ก๋๋งต์ด ํ์ํ ๊ฒ์ด๋ผ๊ณ ์๊ฐํฉ๋๋ค. ๊ทธ๋ฆฌ๊ณ ์ฐ๋ฆฌ๋ ๊ฐ๊ฐ์ ์์ ๋ถ๋ถ์ ์ ํํ๊ณ ํจ๊ป ์์ ํ ์ ์์ผ๋ฏ๋ก ์์ ํ ๋ถ๊ฐ๋ฅํ ๋ง์คํฐ๋ฅผ ์ง์ ์ถ์ ํ ํ์๊ฐ ์์ต๋๋ค...
CPU ๋ฐฑ์๋๋ฅผ ๊ตฌํํ๋ ค๊ณ ํ ๋ ChangeLog ๊ฐ ์์๊ณ ๋ก๊ทธ์ ๋ณต์์์ ๋ํด ํจ์๋ฅผ ์์ ํด์ผ ํ๋ค๊ณ ๋ถ๋ฅํ์ต๋๋ค. ์ด ๋ก๊ทธ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ๋ก๋๋งต์ ์์ฑํ ์ ์์ต๋๋ค.
๊ฒ๋ค๊ฐ (ํธ์ฃผ์์) ๋ด ๋น์๊ฐ ๋ฐฉ๊ธ ๊ฑฐ๋ถ๋์์ผ๋ฏ๋ก ๊ฐญ ์ด์ด๋ฅผ ์์ํด์ผํฉ๋๋ค. ๊ณ์ ์ผํ ์ฌ๋์ด ํ์ํ๋ฉด ์ธํด์ฝ์ ์ ์ฒญํ ์ ์์ต๋๋ค.
๋๋ ์ง๋ ํ๋ฃจ ๋์ ์ด๊ฒ์ ๋ํด ๋ง์ ์๊ฐ์ ํ๋ค. ๋ก์ ์ ๋ ธ๋ ฅ์ ๊ทธ๋๋ก ๋ด์๋ด์ง ๋ชปํด์ ์กฐ๊ธ ์์ฝ์ง๋ง ์์ผ๋ก ์๊ฐํ๋ค
"๋ฎ์ ์ ์ง ๊ด๋ฆฌ ์ค๋ฒํค๋๋ฅผ ์ ์งํ๋ฉด์ ์ด๋ป๊ฒ ๋ณต์กํ Tensor ์ง์์ ๊ตฌ์ถํ ์ ์์ต๋๊น?"
์ด๊ฒ์ ์์ ๋ชฉํ์์ ํจ๊ณผ์ ์ธ ๊ณํ์ผ๋ก ์ ์ํ๋ ๊ฒ์ ๋๋ค.
sparse
ํ
์์ ๊ฐ์ ๊ทผ๋ณธ์ ์ด๊ณ ์๋ก์ด ํ
์ ์ ํ์ด ๋์ด์๋ ์ ๋ฉ๋๋ค. ๊ธฐ๋ณธ ์ ํ์ ์ถ๊ฐํ๋ฉด ๋ง์ ์ ์ง ๊ด๋ฆฌ ์ค๋ฒํค๋์ ๊ต์ฐจ ๋ณ๊ฒฝ์ด ๋ฐ์ํฉ๋๋ค. ์ ์ง ๊ด๋ฆฌ ์ค๋ฒํค๋๋ "๋ณต์กํ ๋นํธ๋ฅผ ๋๊ฐ ์ ์ง ๊ด๋ฆฌํฉ๋๊น?"์ ๊ดํ ๊ฒ์ด ์๋๋ผ "์ด์ ๋ชจ๋ ํต์ฌ ๊ฐ๋ฐ์๋ ๊ทผ๋ณธ์ ์ธ ๋ณ๊ฒฝ, Aten ๋ณ๊ฒฝ ๋ฑ์ ์ํํ ๋ ์ด ๋ณต์กํ ์ ํ์ ์ธ์ํด์ผ ํฉ๋๋ค."torch.stack([real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2], dim = -1)
๋ก ๊ตฌํ๋์ด์ผ ํฉ๋๋ค. ์ฌ๊ธฐ์ real1 = input1[:, :, :, ..., 0]
BLAS, cublas ๋ฐ MAGMA๋ ๋ชจ๋ float2์ ๋ฐ์ดํธ ํธํ๋๋ ์์ฒด ๋ณตํฉ ์ ํ์ ๊ธฐ๋ํ๊ธฐ ๋๋ฌธ์ [Tensor Shape x 2]์ฌ์ผ ํฉ๋๋ค. ๋ํ blas, cublas ๋ฐ magma ํธ์ถ์ Python ์์ค์์ ์ฒ๋ฆฌํ ์ ์์ต๋๋ค.
๋ณต์กํ ๊ณฑ์
์ ๊ฒฝ์ฐ 20% ๋ฐ์ ๋์ง ์์ ๊ฒ์ด๋ผ๊ณ ์๊ฐํฉ๋๋ค. ์ค์ ๋ถ๋ถ๊ณผ ์ด๋ฏธ์ง ๋ถ๋ถ์ ๋ํ ๊ณ์ฐ ์์ 4๊ฐ์ ์ ์ฒด ๋ณต์ฌ ์์
์ด ์์ง ์์ต๋๊น?
์ด์จ๋ ๊ณ์ํด์ ๋ง์คํฐ์ ๋ณ๊ฒฝ ์ฌํญ์ ๋ณํฉํ์ง ์์๋ ๋๋ค๋ฉด ๋คํ์
๋๋ค.
@PhilippPelz ์ ๋์ํฉ๋๋ค. BLAS, cublas ๋ฐ MAGMA์ ๋ณต์กํ ์ง์์ ์๊ฒ ๋๋ฏ๋ก ๋ง์ ์ฑ๋ฅ์ ์์ ์ ์์ต๋๋ค. ํ์ง๋ง ์ ๋ชจ๋ฅด๊ฒ ์ต๋๋ค. ๊ทธ๋ฌ๋ ๋ถ๋ช
ํ ๋ณต์กํ Tensor๋ ํฌ์ ํ
์์ ์์ ํ ๋ค๋ฆ
๋๋ค scipy.sparse
์ ๊ฐ์ ๋๋ถ๋ถ์ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ Julia์ SparseArrays
๋ ํฌ์ ๋ฐฐ์ด์ ๊ธฐ๋ณธ์ ์ธ ๋ค์ฐจ์ ๋ฐฐ์ด์ ๊ตฌ์ฑ์ผ๋ก ์ทจ๊ธํฉ๋๋ค. ๊ทธ๋ฌ๋ ์๋ฌด๋ ๋ ๊ฐ์ ์ค์ ๋ฐฐ์ด์ ํฉ์ฑํ์ฌ ๋ณตํฉ ์ ํ์ ๋ค์ฐจ์ ๋ฐฐ์ด์ ์ทจ๊ธํ์ง ์์ต๋๋ค... (์ฌ๊ธฐ์ ์๋ฌด๋ tensorflow, arrayfire, numpy ๋ฐ Julia๋ฅผ ์๋ฏธํ์ง ์์ต๋๋ค). MXNet์์ FFT๋ ์ค์ ๋ก ๋ ๊ฐ์ ์ค์ ํ
์์ ๊ตฌ์ฑ์ผ๋ก ์ด๋ฃจ์ด์ง์ง๋ง ๋ณต์กํ ๊ฒ์ ์ง์ํ์ง ์์ต๋๋ค... tensorflow๋ complex64
๋ฐ complex128
๋ฅผ ํฌํจํ ๋ค๋ฅธ ๋ท ์ ํ์ ๋ํ ๋ํผ๋ก DataType์ ๊ตฌํํ ๊ฒ ๊ฐ์ต๋๋ค. types.proto ์ฐธ์กฐ
์ฒซ์งธ, ์์๋ณ ํจ์(ํจ์๊ฐ map/reduce๋ฅผ ํธ์ถํจ)๋ ์ฑ๋ฅ ์์ค์ด ํฌ์ง ์์ต๋๋ค(์ ์ด๋ ์ด๋ฌํ ์์
์ ์ํ ๋ฉ๋ชจ๋ฆฌ๋ ์ฐ์์ ์). ํ์ง๋ง ๋จผ์ ๋ช ๊ฐ์ง BLAS ๊ธฐ๋ฅ์ ๋ฒค์น๋งํนํ์ฌ FloatTensor
์ ๊ตฌ์ฑ์ด GPU์์ Complex64Tensor
์ ์ ์ฌํ ์ฑ๋ฅ์ ๋ณด์ด๋์ง, ๋ค์๊ณผ ๊ฐ์ ์ด์ ๊ตฌํ:
gemm
gemv
๋ณตํฉ ๋ณตํฉ ํ
์๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค(๋๋ shared_ptr
์ฌ์ฉ).
class ComplexTensor {
FloatTensor *real;
FloatTensor *imag;
};
๊ทธ๋ฌ๋ ์ฒซ ๋ฒ์งธ ์ ๊ทผ ๋ฐฉ์์ ๋จ์ ์์ ์ธ๊ธํ๋ฏ์ด __shfl_xxx
์ ๊ฐ์ ํจ์๋ ์ด ์์
์ ๋ ๊ธฐ๋ณธ์ ์ผ๋ก ์ํํ๋ ค๋ ๊ฒฝ์ฐ ์ฅ์ ๋ฌผ์ฒ๋ผ ๋ณด์
๋๋ค.
ํ์ฌ torch.fft
๋ [dim1, ..., dimN, 2]
๋ชจ์์ ๋จ์ผ float ํ
์๋ฅผ ๋ฐํํฉ๋๋ค.
@ezyang C10 ์ถ์ ์ผ์ ์ ์ด๋ป๊ฒ ๋๋์? ๋ง์คํฐ ๋ธ๋์น์์ ์ปดํ๋ ์ค ์ง์์ ์์ํ๊ธฐ์ ๋งค์ฐ ํฉ๋ฆฌ์ ์ธ ์์ ์ธ ๊ฒ ๊ฐ์ต๋๋ค.
@PhilippPelz ํ์คํ 0.4์ฉ์ด ์๋๋๋ค. ์ฐ๋ฆฌ๋ ๋ด๋ถ์ ์ผ๋ก 6์์ ๋ชฉํ๋ก ํ๊ณ ์์ต๋๋ค. ๋๋ฌด ์ค๋ ๊ธฐ๋ค๋ฆฌ์ง ์๊ธฐ๋ฅผ ๋ฐ๋๋๋ค.
6์์ ์ธ๊ธํ @ezyang , PyTorch์ ๋ณต์์ ์ง์์ ์ถ๊ฐํ ์ ์์๋์?
๋๋ ๊ทธ๊ฐ ๋ณต์กํ ์ง์์ด ์๋๋ผ C10์ ์๋ฏธํ๋ค๊ณ ์๊ฐํฉ๋๋ค. C10์ ๋ณต์กํ ์ถ๊ฐ๋ฅผ ๋ ์ฝ๊ฒ ๋ง๋ค ๊ฒ์ ๋๋ค. ๊ทธ๋ ๊ฒ ์ดํดํ์ต๋๋ค.
์, C10์ Tensor ์ ํ๊ณผ ๊ธฐ๋ฅ์ ๋ชจ๋ ๊ณต๊ฐ ๋ฑ๋กํฉ๋๋ค. ๋ฐ๋ผ์ ๋ณต์กํ ์ ํ์ ๋ณ๋์ ํจํค์ง๋ก ์ถ๊ฐํ๋ ๊ฒ์ด ํจ์ฌ ์ฌ์ธ ๊ฒ์ ๋๋ค.
๋ณต์์์ ETA๊ฐ ์์ต๋๊น? "ํจ์ฌ ๋ ์ฝ๋ค"๋ ๊ฒ์ "์๋ง๋ ๋นจ๋ฆฌ ๋๋ ๊ฒ์ด๋ค"๋ฅผ ์๋ฏธํฉ๋๊น?
@themightyoarfish ํจ์ฌ ์ฝ๊ฒ, ์ฐ๋ฆฌ๊ฐ pytorch ๋ง์คํฐ์ ํธ์ํ ์ ์๋ ๊ฒ์ ๋ํด ์ฐจ๋จ๋์ง ์์ ๊ฒ์์ ์๋ฏธํฉ๋๋ค. ETA๋ฅผ ์ค์ ํ์ง ์์์ต๋๋ค. PyTorch์ ๊ณต๊ฐ ๋ฑ๋ก๋๋ฉด ์์ ๋ฒ์๋ฅผ ์ง์ ํ๊ฒ ์ต๋๋ค.
@sumith ์ฌ์ ํ ์ด ์์ ์ ์ํํ ์ฌ๋์ด ํ์ํฉ๋๊น(๋ณต์์)? PyTorch ํ์ด ๋ณต์์๋ฅผ ์ง์ํฉ๋๊น? QuCumber ๋ฅผ ์ ์ง ๊ด๋ฆฌํ ๊ฒ์ด๊ธฐ ๋๋ฌธ์ 9์์ ์ํ๋ ๊ฒฝ์ฐ ์ด ์์ ์ ์๊ฐ์ ํ ์ ํ ์ ์์ต๋๋ค(๋ณต์์๋ฅผ ๋ง์ด ์ฌ์ฉํจ)
@Roger-luo ๋ค. PyTorch ๋ฐฑ์๋์์ ๊ณต๊ฐ ๋ฑ๋ก์ด ๊ฐ๋ฅํด์ง๋ฉด ์ฐ๋ฝ์ ๋๋ฆฌ๊ณ ์ธ๋ถ ์ฌํญ์ ํด๊ฒฐํ ์ ์์ต๋๋ค.
@ezyang 9์๊น์ง ์คํํ ๋ฑ๋ก ๋๋์?
@sumith Cool, ๊ทํ์ ์๋น์ค์.
์ฐ๋ฆฌ๋ ๊ทธ๊ฒ์ ์คํํ ์ ์์ต๋๋ค. (์ฐ๋ฆฌ๋ "์์ ํ" ์๋ก์ด ์์คํ ์ ๊ฐ์ถ์ง๋ ์๊ฒ ์ง๋ง, ๋ฆฌํฉํ ๋ง์ด ๊ฐ๋ฅํ๋๋ก ์ค์ ํ๋ ํ ์๋ก์ด ๊ฐ๋ฐ์ด ๋ฐ์ํ๋ฉด ๊ณ์ ์งํํ ์ ์์ต๋๋ค. ์ด๋ ์๋ก์ด ์คํ์ ๋ํ ์ข์ ํ ์คํธ ์ฌ๋ก๊ฐ ๋ ๊ฒ์ ๋๋ค. ๋ฑ๋กํ ์ ์์ต๋๋ค.
@ezyang ์ง๊ธ๊น์ง ๋ฉ๋ชจ๊ฐ ์์ต๋๊น? ์์ ํ๊ธฐ ์ ์ ์ฝ์ ์ ์์์ต๋๋ค. ์ง๋๋ฒ๊ณผ ๋ง์ด ๋ฌ๋ผ์ง ๊ฒ ๊ฐ์ต๋๋ค.
@Roger-luo @PhilippPelz ๋ํ ๋ณต์กํ ํ ์์ ๊ตฌํ์ ๋์๋๋ฆฌ๊ณ ์ถ์ต๋๋ค. ๋ฐ์ฌ๊ณผ์ ์ฐ๊ตฌ์๋ ํ์ํฉ๋๋ค..
@alexgomezalanis ์ฌ์ ์๊ฐ์ ๋ํด ํ ๋ก ํ ์ ์๋ ์ฑ๋์ด ์์ ์ ์์ต๋๋ค. ๋ฐฉ๊ธ #complex-numbers
์ฑ๋ ํธ์ถ์ ๋ง๋ค์์ต๋๋ค. ํ์ง๋ง 9์๊น์ง๋ ์์
์ ์์ํ์ง ์์ ๊ฒ์
๋๋ค(์ฌ์ ํ Julia ์ฝ๋ ์ค ์ผ๋ถ๋ฅผ ์์
ํด์ผ ํฉ๋๋ค...)
BTW, ์ง๋๋ฒ์ ๋นํด ๋ง์ด ๋ณํ ๊ฒ ๊ฐ์ต๋๋ค. ์์ ๋ฟ๊ธฐ ์ ์ ์๊ฐ์ ๋ด์ ๋ฐ๋ผ์ก๊ฒ ์ต๋๋ค.
@alexgomezalanis ํ ์ ์์ต๋๋ค. ๋จผ์ slack์์ pytorch์ ์์ ๊ณต๊ฐ์ ๊ฐ์ ํด์ผ ํฉ๋๋ค. ๋๋ ๋๋ฅผ ์ฐพ์ ์ ์์ด. ์ด๋์ฅ์ ๋ฐ์ผ๋ ค๋ฉด [email protected] ๋ก ์ด๋ฉ์ผ์ ๋ณด๋ด์ฃผ์ญ์์ค.
@Roger-luo @alexgomezalanis ๋ณต์กํ ํ ์ ๋ฌธ์ ์ ๋ํ ์ถ์ ๋ค์ ๋ณด๊ฒ ๋์ด ๊ธฐ์ฉ๋๋ค. ์ ๋ ์ฐธ์ฌํ๊ฒ ๋ค๊ณ ์ ์ํ ์ ์์ง๋ง ํ์ค์ ์ผ๋ก 9์ ๋ง/10์ ์ด๊น์ง๋ ์ด๋ฐ ์ผ์ด ์ผ์ด๋์ง ์์ ๊ฒ์ ๋๋ค. ์ด ๋ฌธ์ ์ ๋ํ ๊ฝค ๋ง์ ๋ ผํ์๋ค์ ๋ณต์กํ ํ ์ ์ง์์ด ์ ๋ฐ์ฌ ํ๋ก์ ํธ์ ๋งค์ฐ ๋์์ด ๋ ๊ฒ์ ๋๋ค.
๋๋ ๋ํ ์๋ ์ ๋ด ์ฐ๊ตฌ๋ฅผ ์ ์ฅํ๋ ค๊ณ ๋ ธ๋ ฅํ์ต๋๋ค ๐ ... ๊ทธ๋ฌ๋ ์ด์ ๋ ์ด์ 1w + loc ์ฝ๋๋ฅผ ๋ค์ ๋๋๋ฆฌ๊ณ ์ถ์ต๋๋ค. ๐คฃ ์ฌ๋์์ ์ฑํ ํ์!
:) ์, ์ฌ๋์์ ์ฑํ ํฉ์๋ค. ๋ฉ์ผ ํด๋์์ ์ด๋์ฅ์ ์ฐพ์์ต๋๋ค.
์์ ์งํ ์ค์ธ ํ๋ฌ๊ทธ์ธ(๋จ๊ธฐ CPU ์ ์ฉ)์ https://github.com/Roger-luo/pytorch-complex ์ ๋๋ค.
์ ์๊ฒ ์ด์์ PR์ ์ฃผ์๊ธฐ ๋ฐ๋๋๋ค.
์ด ๋ฌธ์ ์ ๋งจ ์์ ๋ณต์กํ ๊ตฌํ์ด ์ํ๋๋ ๋ฐฉ๋ฒ์ ๋ํ ๋ฉ๋ชจ๋ฅผ ๊ฒ์ํ์ต๋๋ค.
์ ๋ ์ต๊ทผ์ PyTorch๋ฅผ ์ฌ์ฉํ๊ธฐ ์์ํ๊ณ ์ ๋ง ์ข์ํฉ๋๋ค. TensorFlow๋ณด๋ค ์ฌ์ฉํ๊ธฐ๊ฐ ํจ์ฌ ์ข์ต๋๋ค. ๊ทธ๋ฌ๋ ๋ณต์กํ ํ ์ ์ง์์ ๋ด ์ฐ๊ตฌ(๊ด ์ ๊ฒฝ๋ง)์ ๋งค์ฐ ์ค์ํฉ๋๋ค. ์์ง ํ๋ฐํ ์งํ๋๊ณ ์๋์? ๊ทธ๋ ๋ค๋ฉด ๋ณต์กํ ํ ์ ์ง์์ ๋ํ (๋์จํ) ๊ธฐ๊ฐ์ ์๋ ์ฌ๋์ด ์์ต๋๊น?
์ ๊ฐ ํ ์ ์๋ ๊ณณ์์ ์ด ์์ ์ ๋๊ฒ ๋์ด ๊ธฐ์ฉ๋๋ค. ํ์ง๋ง ์ ๋ PyTorch๋ฅผ ์ฒ์ ์ ํ๊ธฐ ๋๋ฌธ์ ์ด ๊ธฐ๋ฅ์ด ์ผ๋ง๋ ํฐ ์ผ์ธ์ง ์์ง ์ ๋ชจ๋ฆ ๋๋ค. ๋ด ์ฐ๊ตฌ์ค ๋๋ฃ ์ค ์ผ๋ถ๋ ๋ณต์กํ ํ ์ ์ง์์ ๊น์ ๊ด์ฌ์ ํ๋ช ํ์ผ๋ฉฐ(๋ฌผ๋ฆฌํ์์ ์ด๋ฅผ ์ถ๊ฐํ๋ฉด Torch๊ฐ ๊ฑฐ์ NumPy์ ๋ํ GPU ๊ฐ์ ๋์ฒดํ์ด ๋ ์ ์์) ๊ฐ๊น์ด ๋ฏธ๋.
์๋ ํ์ธ์ @bencbartlett
๋๋ ์ฌ์ ํ ์ฒ์ฒํ ์์ ํ๋ ค๊ณ ๋ ธ๋ ฅํ๊ณ ์์ต๋๋ค.... ํ์ง๋ง ์ ๋ ํ์ฌ ํ์์ด๊ธฐ๋ ํฉ๋๋ค(๋งค์ฐ ๋ถ์์ ํ ์ํฉ์ ์์). ์ฆ, ์ด ํํ์ ์์ ์ ํ ์ ์๊ณ ์ฌ๊ฐ ์๊ฐ์๋ง ์์ ํ ์ ์์ต๋๋ค. (์ ๋ ์๋ ๋ถํฐ Julia์์ ์ฐ๊ตฌ ๊ด๋ จ ์ฝ๋๋ฅผ ๊ตฌํํ์ต๋๋ค. ์ด๋ ๊ธฐ์กด ํจํค์ง์๋ง ํ ์น์ ๋ ๋์ ๋ณต์์ ์ง์์ด ํ์ํ๋ค๋ ๊ฒ์ ์๋ฏธํฉ๋๋ค.)
๋ณต์์๊ฐ ์ค์ํ๊ณ ํ๋ถ์ ์ผ๋ ๊ฒ์ด ๊ธด๊ธํ ๊ฒฝ์ฐ ๋ค์์ ์๋ํ๋ ๊ฒ์ด ์ข์ต๋๋ค.
https://github.com/PIQuIL/QuCumber/blob/master/qucumber/utils/cplx.py
๊ทธ๊ฒ์ ๋งค์ฐ ๋๋ฆฌ์ง๋ง ... ์ ์ด๋ ์๋ํฉ๋๋ค. ๋๋ ์ด์ TH ์คํ์ผ์ C ๋ฒ์ ์ด ์์์ต๋๋ค.
์ด๊ฒ์ ๋ฉฐ์น ๋ง์ ๋๋ผ ์ ์๋ ์์ ํ๋ก์ ํธ๊ฐ ์๋๋๋ค. ๋ฐ๋ผ์ CPU ๋๋ CUDA์ ๋ํ ๋ณต์กํ ๊ฐ์ผ๋ก ์์ ํ ๊ธฐ๋ฅ ์ง์์ ์ํ ํน์ ์๊ฐ ํ๋ ์์ ๋ณด์ฅํ ์ ์์ต๋๋ค.
๊ทธ๋ฌ๋ ์ด ๋ฌธ์ ์ ๋ํด ์ ์ ํจ๊ป ์์ ํ๋ ๋ฐ ๋์์ด ๋์์ผ๋ฉด ํฉ๋๋ค. ํ์ฅ ๋ฆฌํฌ์งํ ๋ฆฌ์ ๊ฒ์ํ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๋ ๊ฒ์ผ๋ก ์์ํ๋ ๊ฒ์ด ์ข์ต๋๋ค. ๊ทธ๋ฆฌ๊ณ ์ง๋ฌธ์ด ์๋ ๊ฒฝ์ฐ slack์ด๋ ์ด๋ฉ์ผ ๋๋ ์ด์๋ฅผ ํตํด ์์ ๋กญ๊ฒ ์ง๋ฌธํ์ธ์(์์ง ๋ฌธ์๊ฐ ๋ง์ง ์๊ธฐ ๋๋ฌธ์).
๋ถํํ๋ ์์ง PyTorch Slack์ ์ก์ธ์คํ ์ ์์ต๋๋ค. (์ด๋๋ฅผ ์์ฒญํ๋ ์ด๋ฉ์ผ์ ๋ ๋ฒ์ด๋ ๋ณด๋์ง๋ง ๋ต์ฅ์ ๋ฐ์ง ๋ชปํ์ต๋๋ค.) ๋๊ตฐ๊ฐ ์ ๋ฅผ ์ด๋ํ ์ ์์ต๋๊น? ([email protected])
@Roger-luo ํ์คํ ๊ทํ์ ํฌํฌ๋ฅผ ์ดํด๋ณด๊ฒ ์ง๋ง ๋ง์ ๋์์ด ๋ ๊ฒ์ด๋ผ๊ณ ์ฅ๋ดํ ์๋ ์์ต๋๋ค. ์ C++๊ฐ ๋ น์ฌ๊ณ ์ง์ ํ์ ๋๋ก ํ์. QuCumber ์ ํธ๋ฆฌํฐ๋ ํ๋ฅญํ์ง๋ง ๋ถํํ๋ ๋์๊ฒ๋ ๋ณ๋ก ๋์์ด ๋์ง ์์ต๋๋ค. ๋ณต์กํ ํ ์๊ฐ GPU ์ง์๋๊ฑฐ๋ autograd ๋ฐ torch.nn์์ ์ง์๋ ๋๊น์ง๋ NumPy๊ฐ ์ ๊ณตํ ์ ์๋ ๊ฒ๋ณด๋ค ๋ง์ ์ ํธ๋ฆฌํฐ๋ฅผ ์ ๊ณตํ์ง ์์ต๋๋ค.
@soumith @ezyang PyTorch ํ์์ ์ด์ ๋ํด ๋ ๋ง์ ๊ด์ฌ์ ๊ฐ์ ธ์ฃผ์๋ฉด ์ข๊ฒ ์ต๋๋ค! ๋ณตํฉ ์ง์์ ์ผ๋ฐ ํ ์ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ๊ฐ์ถ์ด์ผ ํ ์ค์ํ ๊ธฐ๋ฅ์ธ ๊ฒ ๊ฐ์ต๋๋ค. ์ด๋ ๋ฌผ๋ฆฌํ์์ ์ฌ์ค์ ํ์์ ์ด๋ฉฐ ํนํ ์ง๋ ๋ช ๋ ๋์ ML ๋ด์์ ๋ณต์์ ๊ฐ ๋ชจ๋ธ์ ๋ํ ๊ด์ฌ์ด ๊ธ๊ฒฉํ ์ฆ๊ฐํ์ต๋๋ค.
@bencbartlett QuCumber์ ์ ๊ทผ ๋ฐฉ์์ AD๊ฐ ์๋ GPU์์ ์ฌ์ฉํ ์ ์์ต๋๋ค... ๊ทธ๊ฒ์ ๋จ์ง ๋งค์ฐ ๋๋ฆฝ๋๋ค... ์ ๋ง์ ๋น์ ์ด ๊ทธ AD๋ฅผ ์ํ๋ฉด ์ฌ์ฉํ ์ ์์ ๊ฒ์ ๋๋ค.
์, ์์งํ ๋งํด์ https://github.com/FluxML/Flux.jl ์ ์ฝ๊ฐ ์์ ๋ ๋ฒ์ ๊ณผ ์ฐ๊ตฌ๋ฅผ ์ํด Julia์์ ์์ฒด ํจํค์ง๋ฅผ ์ฌ์ฉํ๊ณ ์์ต๋๋ค(์ด๋ค ์ํฉ์์๋ ํ ์๊ฐ ์๋ GPU์ ๋ณต์กํ AD๋ ํ์ํฉ๋๋ค. ). source2source AD โโํจํค์ง Zygote.jl์ ๋ณต์กํ ํ ์์์ AD๋ฅผ ์ํํ ์ ์์ง๋ง ์ธ๊ทธ๋จผํธ ์ค๋ฅ๊ฐ ์์ ์ ์๋ ์ด๊ธฐ ๋จ๊ณ์ ๋๋ค. ์ํ๊ณ๋ ์์ง ํ ์น์ ๋นํด ์์ ์ ์ด์ง ์๊ณ ์์ฒด ์ฌ์ฉ์ ์ํด ์ด๋ฌํ ๊ตฌํ์ ์ฝ๊ฐ ํดํนํด์ผ ํ๋ ๊ฒฝ์ฐ๊ฐ ์์ต๋๋ค. ํ์ง๋ง ๊ธฐ๋ณธ์ ์ผ๋ก๋ ์์ ๋ฌผ๋ฆฌํ ์ฐ๊ตฌ์ ํ์ํ ๊ฒ์ ๋๋ค. GPU์์๋ ๋ณต์กํ ํ ์๋ฅผ ๊ฐ์ง ์ ์์ต๋๋ค.
torch.nn
์ ๋ํ ๋ณต์กํ ๊ฐ ์ง์์ด ํ์ํ์ง ์๋ค๊ณ ์๊ฐํฉ๋๋ค. ๋ณต์กํ ํ
์๊ฐ ์๋ํ๋ฉด autograd
์ ๋ํ ๋ช ๊ฐ์ง ์ ์๋ฅผ ์ถ๊ฐํด์ผ ํ ์๋ ์์ต๋๋ค. ์ ํ ๋ ์ด์ด์ ๊ฐ์ ๊ฒ์ด ๋์ผํ๊ฒ ์ ์ง๋ ์ ์๊ธฐ ๋๋ฌธ์
๋๋ค. . ๊ทธ๋ฆฌ๊ณ ์ผ๋ถ ํ์ฑํ ๊ธฐ๋ฅ์ Hilbert ๊ณต๊ฐ์์ ํ์ค ํ์ฅ์ด ์์ ์ ์์ต๋๋ค... (๋ด ํ๋ ฅ์ @GiggleLiu ์ ๋ธ๋ก๊ทธ ๊ฒ์๋ฌผ์ ํ์ธํ ์ ์์ต๋๋ค)
pytorch-complex ํ์ฅ์ ๊ฒฝ์ฐ GPU์์ AD์ ๋ํ ์์ ํ ์ง์์ ์ธ์ ๋ฐ์ ์ ์๋์ง ์ ๋ชจ๋ฅด๊ฒ ์ต๋๋ค... ์ด๊ฒ์ ์ฌ์ ํ โโ๋์๊ฒ ๊ฝค ๋จผ ๊ฒ ๊ฐ์ต๋๋ค. CPU ๊ตฌํ์ ๋ฉ์ธ ํธ๋ฆฌ์์ ํจ์น๊ฐ ํ์ํ ์ผ์ ๊ธฐ๊ฐ(์: ์ ํ ์น๊ฒฉ, simd ์ง์ ๋ฑ)์ ๊ฑฐ์ณ์ผ ํ๋ค๊ณ ๋งํ๊ณ ์ถ์ต๋๋ค. ์ด๊ฒ์ C++์ ๋ค๊ฐ์ค๋ Aten ๊ตฌํ๊ณผ ๊ด๋ จ์ด ์๊ณ TH ๋ฑ์ ์ ๊ฑฐํ๋ ๊ฒ๊ณผ๋ ๊ด๋ จ์ด ์์ ์ ์์ต๋๋ค. ๊ทธ๋ฌ๋ฉด ๋ณต์กํ ํ ์์ ๋ํ ์ฐ์ฐ์๋ฅผ ๋ ๋น ๋ฅด๊ฒ ์ถ๊ฐํ ์ ์์ต๋๋ค.
๋๋ ๋ด์ ์ธํด์ฝ์ ์ ์ฒญํ ์ ์์ต๋๋ค(๋ฐฉ๊ธ @ezyang ์๊ฒ ๋ฌผ์ด๋ดค์ต๋๋ค). ๊ทธ๋์ ๋ฐ์ฌ ๊ณผ์ ์ ์์ํ๊ธฐ ๋ช ๋ฌ ์ ์ ํํ์์ผ๋ก ์ด ์ผ์ ํ ์ ์์์ง๋ ๋ชจ๋ฆ ๋๋ค. ๋ด ์๋ค.
๊ทธ ๋์ ๋๋ ๋ณต์กํ ๊ณฑ์ ์ ๋ด ์์ ์ ๋ฒ์ ์ ๊ตฌํํ์ต๋๋ค. ๊ทธ๋ฌ๋ ํ๋กํ์ผ๋งํ ๋ ์๋นํ ์๊ฐ์ด ์์๋ฉ๋๋ค. torch._C_._cuda_isDriverSufficient
์ด์ ๊ฐ ๋ญ์ง ์์ธ์? ๋ณต์กํ ๊ณฑ์ ์ ๋ ๋์ ๊ตฌํ์ ์๊ณ ์๋ค๋ฉด ์๋ ค์ฃผ์ญ์์ค. ์ฌํํผ, ๋ด ๋ฒ์ (๊ณฑ์ ์์ ์ต์ ํ๋์ด ์์: 4 ๋์ 3)์ ์๋์ ์ผ๋ก ๋๋ฆฐ ๊ฒ ๊ฐ์ต๋๋ค. ์๋ฅผ ๋ค์ด out ํ ์์ irfft๋ ์์๋ณ ๊ณฑ์ ๋ณด๋ค 10๋ฐฐ ๋น ๋ฆ ๋๋ค. PyTorch์ C++ ์์ค์์ ๋ณต์์ ๊ณฑ์ ์ด ์ง์๋ฉ๋๊น?
def complex_mul(x, y, out):
uavc = x[..., 0] * (y[..., 0] + y[..., 1])
out[..., 0] = uavc - (x[..., 0] + x[..., 1]) * y[..., 1]
out[..., 1] = (x[..., 1] - x[..., 0]) * y[..., 0] + uavc
def test_complex_mul_out_tensor(self):
N, C, H, W, I = 128, 3, 32, 32, 2
K = 16 # number of filter banks
repetitions = 1000
dtype = torch.float
if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
x = torch.randn(N, 1, C, H, W, I, dtype=dtype, device=device)
y = torch.randn(K, C, H, W, I, dtype=dtype, device=device)
start_mul_time = time.time()
out = torch.empty(N, K, C, H, W, I, dtype=dtype, device=device)
for _ in range(repetitions):
complex_mul(x, y, out)
print("multiplication time: ", time.time() - start_mul_time)
์ฐ๋ฆฌ๋ C++์์ ๊ทธ๊ฒ์ ์ง์ํ๋ ค๊ณ ๋ ธ๋ ฅํ๊ณ ์์ต๋๋ค. ์๋จ์ ๊ฒ์๋ฌผ์ ์ฐธ์กฐํ์ญ์์ค. ํ์ฅ์ ์ปดํ์ผํ ์ ์๋ค๋ฉด ์ ์ด๋ ํ์ฌ๋ก์๋ ์ค์นผ๋ผ ๊ณฑ์ ์์ ์๋ํด์ผ ํฉ๋๋ค....
๊ตฌํ์ QuCumber์ ์๋ ๊ฒ๊ณผ ์ ์ฌํฉ๋๋ค. ๋ณต์์์ ๋ํด ์ฌ๋ฐ๋ฅธ cuda ์ปค๋์ ํธ์ถํ์ง ์์ผ๋ฉด ๋ ๋ง์ GPU ์ค๋ ๋๋ฅผ ํธ์ถํ ์ ์์ต๋๋ค. Python์์ ์ง์ํ๋ C++ ๋ฐฑ์๋๊ฐ ์์ผ๋ฉด SIMD๋ฅผ ์์ ์ ์์ต๋๋ค.
์์ธํ ๋ด์ฉ์ ๋ณด๋ ค๋ฉด nvprof
๋ฅผ ์คํํ๋ ๊ฒ์ด ์ข์ต๋๋ค.
@Roger-luo @apaszke @soumith ์ด ์ค๋ ๋์ ๊ฐ์ฌ๋๋ฆฝ๋๋ค. btw. ๋๋ ํ ์น.ํ ์๋ฅผ ์๋ธํด๋์ฑํ์ฌ ํจ๊ป ํดํน๋ ๊ธฐ๋ณธ ๋ณตํฉ ํ ์๋ฅผ ๊ตฌํํ์ต๋๋ค.
๋๋ ์ ๋ฐ๋ถ๋ฅผ ์ค์ ๋ก, ํ๋ฐ๋ถ๋ฅผ ํ์์ฌ๋ก ์ทจ๊ธํ๊ณ ๋ด ์์ ์ ๊ธฐ๋ณธ ์ฐ์ ์ฐ์ฐ๊ณผ ๋ด ์ฐ๊ตฌ์ ํ์ํ ์ผ๋ถ ๋ค๋ฅธ ์ฐ์ฐ์ ๊ตฌํํฉ๋๋ค.
Tensorflow ๋ฐ numpy์ ๋ํด ํ์ธํ์ต๋๋ค. ๋ด๊ฐ ๊ตฌํํ ๊ทธ๋ผ๋์ธํธ์ ๋ชจ๋ ์์ ์ ์ถ๋ ฅ๊ณผ ์ผ์นํฉ๋๋ค!
PT๊ฐ ๋ณต์กํ ํ ์๋ฅผ ์์ ํ ์ง์ํ ๋๊น์ง ๋ณด๋ฅ๋ก ์ฌ์ฉํ๊ธฐ ์ํ ๊ฒ์ ๋๋ค.
ํน์ง:
pip install pytorch-complex-tensor
@williamFalcon ๊ฐ์ฌํฉ๋๋ค!
์ ๋ฐ์ดํธ ํ๋? ๋ณต์กํ ์ ํ ์ง์์ pytorch์ ํตํฉํ ๊ณํ์ด ์๋์ง ๊ถ๊ธํฉ๋๋ค.
์๋ ํ์ธ์, @whmrtm
@ezyang ์ https://github.com/Roger-luo/pytorch-complex/issues/4 ์์ ์์ ์ค์ ๋๋ค. ๋๋ ์ด์ ๊ด์ฌ์ด ์๋ ์ฌ๋์ ๋๊ตฌ๋ ์คํํ ์ ์๋๋ก ๋์์ค ์ ์์ต๋๋ค. ์ด ๋ฌธ์ ๋ ์ผ๋ถ ๊ธฐ๋ณธ์ ์ธ ๋ฐฉ์ก ๋ฌธ์ ๋ฅผ ํด๊ฒฐํฉ๋๋ค(์ด ๋ฌธ์ ๊ฐ ํด๊ฒฐ๋ ํ ๋ง์ ๊ธฐ๋ฅ์ ์ฌ์ฉํ ์ ์์). ์์ ๋กญ๊ฒ PR์ ํ๊ฑฐ๋ ์ ์๊ฒ ๋น์ ์ ํ๋ ฅ์๋ก ์ถ๊ฐํด๋ฌ๋ผ๊ณ ์์ฒญํ์ญ์์ค.
๋๋ ์ฌ๋ฆ๊น์ง ์๋ฌด๊ฒ๋ ํ ์ ์์ ๊ฒ์ด๊ณ , ์ฐ๋ฆฌ ์์ ์ ํจํค์ง๋ฅผ ์ํ ์๋ก์ด ๋ฆด๋ฆฌ์ค๋ฅผ ๋๋ด์ผ ํฉ๋๋ค.
์๋ ํ์ธ์, @whmrtm
@ezyang ์ Roger-luo/pytorch-complex#4 ์ ๋ํด ์์ ์ค์ ๋๋ค. ๋๋ ์ด์ ๊ด์ฌ์ด ์๋ ์ฌ๋์ ๋๊ตฌ๋ ์คํํ ์ ์๋๋ก ๋์์ค ์ ์์ต๋๋ค. ์ด ๋ฌธ์ ๋ ์ผ๋ถ ๊ธฐ๋ณธ์ ์ธ ๋ฐฉ์ก ๋ฌธ์ ๋ฅผ ํด๊ฒฐํฉ๋๋ค(์ด ๋ฌธ์ ๊ฐ ํด๊ฒฐ๋ ํ ๋ง์ ๊ธฐ๋ฅ์ ์ฌ์ฉํ ์ ์์). ์์ ๋กญ๊ฒ PR์ ํ๊ฑฐ๋ ์ ์๊ฒ ๋น์ ์ ํ๋ ฅ์๋ก ์ถ๊ฐํด๋ฌ๋ผ๊ณ ์์ฒญํ์ญ์์ค.
๋๋ ์ฌ๋ฆ๊น์ง ์๋ฌด๊ฒ๋ ํ ์ ์์ ๊ฒ์ด๊ณ , ์ฐ๋ฆฌ ์์ ์ ํจํค์ง๋ฅผ ์ํ ์๋ก์ด ๋ฆด๋ฆฌ์ค๋ฅผ ๋๋ด์ผ ํฉ๋๋ค.
์ ๋ฐ์ดํธ ๊ฐ์ฌํฉ๋๋ค. ์ ๊ฐ ํ ์ ์๋ ์ผ์ ์์๋ณด๊ฒ ์ต๋๋ค.
์๋ ํ์ธ์ @Roger-luo
๋ณต์กํ ํ ์ ์ง์ ์ฃผ์ ([email protected])์ ๊ด๋ จ๋ slack ์ฑ๋์ ์ก์ธ์คํ ์ ์๋์? ์ด๋์ฅ์ ์ด๋ฉ์ผ๋ก ๋ณด๋์ง๋ง ์์ง ์๋ฌด ์ผ๋ ์ผ์ด๋์ง ์์์ต๋๋ค. ์ง๊ธ ์ ๋ ์ด ๋ฌธ์ ์ ๊ธฐ์ฌํ๊ธฐ ์์ํ ์ง์ ์ ํ์ ํ๋ ค๊ณ ๋ ธ๋ ฅํ๊ณ ์์ต๋๋ค. https://github.com/Roger-luo/pytorch-complex/issues/4 ๊ฐ ํ์ฌ ์ง์ ์ ์ธ ๊ฒ ๊ฐ์๋ฐ์?
@beconstant ์, ๊ทธ๊ฒ์ด ์์์ ์ ๋๋ค. ์ด๊ฒ์ ์ผ๋ถ ๋ธ๋ก๋์บ์คํธ ๊ธฐ๋ฅ์ ์๋ํ๊ฒ ํด์ผ ํ์ง๋ง cuda์์ ์ ํ ์น๊ฒฉ ์ค๋ฅ๊ฐ ๋ฐ์ํ๋ ์ด์ ๋ฅผ ๋ชจ๋ฅด๊ฒ ์ต๋๋ค. CPU์์ ์๋ํ๊ณ ์์์ต๋๋ค. (์ฒ์๋ถํฐ cuda๋ฅผ ์ง์ํ ์๊ฐ์ ์์ง๋ง ๋น๋ ์คํจ์ ์์ธ์ด ๋ฉ๋๋ค.)
์ด๋ ์ด๋ฉ์ผ์ ๋ณด๋ผ ์ ์์ต๋๋ค(์ก์ธ์ค ๊ถํ์ด ์์ต๋๋ค). slack์ ๊ฐ์ ํ๋ ค๋ฉด pytorch ๊ณต์ ๊ฐ์ด๋๋ฅผ ๋ฐ๋ผ์ผ ํ๋ค๊ณ ์๊ฐํฉ๋๋ค. ๊ทธ๋ฌ๋ ์ฐ๋ฆฌ๋ ํญ์ ์ด์/PR์์ ๋ ผ์ํ ์ ์์ต๋๋ค.
@Roger-luo ์๊ฒ ์ต๋๋ค. :)
๋์์ด ํ์ํ๋ฉด ์๋ ค์ฃผ์ธ์. ์ง์ ๋ pytorch ๋ฒ์ ์ ๋น๋ํ๋ ๊ฒ์ผ๋ก ์์ํ๊ฒ ์ต๋๋ค. pytorch-complex/issues/4 ์ ๋ํ ์งํ ์ํฉ์ด ์์ต๋๊น?
๋์์ด ํ์ํ๋ฉด ์๋ ค์ฃผ์ธ์. ์ง์ ๋ pytorch ๋ฒ์ ์ ๋น๋ํ๋ ๊ฒ์ผ๋ก ์์ํ๊ฒ ์ต๋๋ค. pytorch-complex/issues/4 ์ ๋ํ ์งํ ์ํฉ์ด ์์ต๋๊น?
@dylanbespalko ์๋
ํ์ธ์, Complex-valued ๋ฒ์ ์ผ๋ก ๊ตฌํ๋ pytorch๊ฐ ์๊ธํ ํ์ํฉ๋๋ค.
๊ทํ์ ๊ธฐ์ฌ์ ์ง์ฌ์ผ๋ก ๊ฐ์ฌ๋๋ฆฝ๋๋ค.
์น์ ํ๋,
์ ค๋ผ๋ฅด209
์๋ ํ์ธ์ @Zellar209 ,
@ezyang ์ด ๋ ํฐ ๋ฌธ์ ( pytorch-complex/issues/4 ) ์ค ํ๋์์ ์ด์ฌํ ์ผํ๊ณ ์๋ค๋ ๋๋์ ๋ฐ๊ณ ์์ต๋๋ค. ํ์ฌ AMD ์์คํ ์ด ์๊ณ 3์ฃผ ํ์ GPU ์ง์์ ๊ฐํํ๋ ๋ฐ ์ฌ์ฉํ ์ ์๋ Nvidia ์์คํ ์ด ์์ต๋๋ค.
๋ฌธ์ ๋ ์๋ ์ ํ ์น๊ฒฉ ๋ณ๊ฒฝ์ด CUDA๋ฅผ ์ค๋จํ๋ ๊ฒ๋ฟ์ธ ๊ฒ ๊ฐ์ต๋๋ค. ํด๋น PR์ด ํด๊ฒฐ๋๋ ํ ์ต์ํ ์ผ๋ถ ์ด์์๊ฐ CPU์์ ์๋ํ๋๋ก ํ๊ณ ์์ง CUDA๋ฅผ ์ง์ํ์ง ์์ต๋๋ค...
IMHO ๋๋ ์ฐ๋ฆฌ๊ฐ CPU์ ์ง์คํ๊ณ ์ผ์ ๋จผ์ ๋ง๋ค๊ณ GPU๋ฅผ ๋์ค์ ๊ณ ๋ คํด์ผ ํ๋ค๊ณ ์๊ฐํฉ๋๋ค.
CPU๋ง ์ง์ํ๋ฉด ๊ด์ฐฎ์ต๋๋ค. ์ด ์ ํ ์น๊ฒฉ ๋ฌธ์ ( pytorch-complex/issues/4 ๊ฐ fb์์ ๋ด๋ถ์ ์ผ๋ก ์ฒ๋ฆฌ๋ฉ๋๊น? ์ธ๋ถ์์ ์์ ํด๋ ๊ด์ฐฎ์ต๋๊น?
์๋ ํ์ธ์ @dylanbespalko; ๋๋ @Roger-luo์๊ฒ ๋ด๊ฐ ๊ทธ๊ฒ์ ์กฐ์ฌํ ๊ฒ์ด๋ผ๊ณ ๋งํ์ง๋ง(์๋ํ๋ฉด ๋ด๊ฐ ์๋ง๋ ๋ฌธ์ ๊ฐ ๋ฌด์์ธ์ง ํ์ ํ๋ ๊ฒ์ด ๊ฐ์ฅ ์ข์ ์์น์ ์๊ธฐ ๋๋ฌธ์ ๋๋ค), ์์ง ๊ทธ๊ฒ์ ๋ณผ ์๊ฐ์ด ์์์ต๋๋ค. ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๋ ๋ฐฉ๋ฒ์ ์ฐพ์๋ณด๊ณ ์ถ๋ค๋ฉด ๊ธฐ๊บผ์ด ์กฐ์ธํด ๋๋ฆฌ๊ฒ ์ต๋๋ค.
์๋ ํ์ธ์ @Zellar209 ,
@ezyang ์ด ๋ ํฐ ๋ฌธ์ ( pytorch-complex/issues/4 ) ์ค ํ๋์์ ์ด์ฌํ ์ผํ๊ณ ์๋ค๋ ๋๋์ ๋ฐ๊ณ ์์ต๋๋ค. ํ์ฌ AMD ์์คํ ์ด ์๊ณ 3์ฃผ ํ์ GPU ์ง์์ ๊ฐํํ๋ ๋ฐ ์ฌ์ฉํ ์ ์๋ Nvidia ์์คํ ์ด ์์ต๋๋ค.
์, ์ง๊ธ์ GPU๊ฐ ํ์ํ์ง ์์ต๋๋ค. ์ ๋ MAC ์์คํ ์ ์ฌ์ฉํ๊ณ ์์ต๋๋ค. ํ์ง๋ง ์ด ํ๋ก์ ํธ๋ฅผ ๋น๋ํ ๋ ๋ช ๊ฐ์ง ์ค๋ฅ๊ฐ ์์์ต๋๋ค.
์๋ ํ์ธ์ @Zellar209๋ , pytorch-complex์ ๋ฌธ์ ์์ ์ป์ ๋ด์ฉ์ ๊ฒ์ํ ์ ์์ต๋๊น? Mac์ ์๋ก์ด Xcode์ ๋ฌธ์ ๊ฐ ์๋ ๊ฒ ๊ฐ์์ ๋น๋ํ๊ธฐ๊ฐ ์ด๋ ต์ต๋๋ค. ๊ทธ๋ฌ๋ ์ฌ๋๋ค์ ๊ทธ ์ด์ ๋ฅผ ์์๋ด๊ธฐ ์ํด ๋ ๋ง์ ์ค๋ฅ ๋ฉ์์ง๊ฐ ํ์ํ ๊ฒ์ ๋๋ค.
OS๋ ์๋ฌ๋ฉ์ธ์ง ๋ฌผ์ด๋ดค๋๋ฐ ๋ต์ด์๋ค...
์๋ ํ์ธ์ @dylanbespalko; ๋๋ @Roger-luo์๊ฒ ๋ด๊ฐ ๊ทธ๊ฒ์ ์กฐ์ฌํ ๊ฒ์ด๋ผ๊ณ ๋งํ์ง๋ง(์๋ํ๋ฉด ๋ด๊ฐ ์๋ง๋ ๋ฌธ์ ๊ฐ ๋ฌด์์ธ์ง ํ์ ํ๋ ๊ฒ์ด ๊ฐ์ฅ ์ข์ ์์น์ ์๊ธฐ ๋๋ฌธ์ ๋๋ค), ์์ง ๊ทธ๊ฒ์ ๋ณผ ์๊ฐ์ด ์์์ต๋๋ค. ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๋ ๋ฐฉ๋ฒ์ ์ฐพ์๋ณด๊ณ ์ถ๋ค๋ฉด ๊ธฐ๊บผ์ด ์กฐ์ธํด ๋๋ฆฌ๊ฒ ์ต๋๋ค.
๋น ๋ฅธ ๋ต๋ณ ๊ฐ์ฌํฉ๋๋ค.
๊ฑด๋ฌผ 'torch_complex.cpp' ํ์ฅ
gcc -Wno-unused-result -Wsign-compare -Wunreachable-code -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -I/anaconda3/include -arch x86_64 -I/anaconda3/include -arch x86_64 -I/ anaconda3/lib/python3.6/site-packages/torch/include -I/anaconda3/lib/python3.6/site-packages/torch/include/torch/csrc/api/include -I/anaconda3/lib/python3. 6/site-packages/torch/include/TH -I/anaconda3/lib/python3.6/site-packages/torch/include/THC -I/anaconda3/include/python3.6m -c src/module.cpp -o ๋น๋/temp.macosx-10.7-x86_64-3.6/src/module.o -g -stdlib=libc++ -std=c++11 -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=cpp
gcc: ์ค๋ฅ: ์ธ์ํ ์ ์๋ ๋ช
๋ น์ค ์ต์
'-stdlib=libc++'
src/module์ ํฌํจ๋ ํ์ผ์ ์์ต๋๋ค. cpp:2 :
src/CPUComplexType.h:60์ ํฌํจ๋ ํ์ผ:
src/CPUComplexTypeImpl.h:102:105: ๊ฒฝ๊ณ : 'IntList'๋ ๋ ์ด์ ์ฌ์ฉ๋์ง ์์ต๋๋ค. [-Wdeprecated-declarations]
Tensor & CPUComplexType::set_(Tensor & self, Storage ์์ค, int64_t storage_offset, IntList ํฌ๊ธฐ, IntList strides) const {
^^
/anaconda3/lib/python3.6/site-packages/torch/include/c10/util/ArrayRef.h:273:7: ์ฐธ๊ณ : 'IntList'๋ ์ฌ๊ธฐ์ ๋ ์ด์ ์ฌ์ฉ๋์ง ์๋ ๊ฒ์ผ๋ก ๋ช
์์ ์ผ๋ก ํ์๋์์ต๋๋ค.
IntList ์ฌ์ฉ C10_DEPRECATED_USING = ArrayRef
^^
src/module์ ํฌํจ๋ ํ์ผ์ ์์ต๋๋ค. cpp:2 :
src/CPUComplexType.h:60์ ํฌํจ๋ ํ์ผ:
src/CPUComplexTypeImpl.h:105:76: ์ค๋ฅ: 'at' ๋ค์์คํ์ด์ค์ 'scalarTypeToDataType'์ด๋ผ๋ ๋ฉค๋ฒ๊ฐ ์์ต๋๋ค.
์๋ ์์ค_ = checked_storage(์์ค,"์์ค",2, DeviceType::CPU, at::scalarTypeToDataType(CPUComplexTypeInfo::scalar_type));
~~~~^
7๊ฐ์ ๊ฒฝ๊ณ ์ 2๊ฐ์ ์ค๋ฅ๊ฐ ์์ฑ๋์์ต๋๋ค.
๋๋ ๊ทธ๊ฒ์ ๊ณ ์น ์ ์์ต๋๋ค. ๋๋ ๋น์ ์ด ๋๋ฅผ ๋์ธ ์ ์๊ธฐ๋ฅผ ์ ๋ง๋ก ๋ฐ๋๋๋ค!
์๋ค ์,
์๊ฒฌ์ ๋ณด๋ด์ฃผ์ ์ ๊ฐ์ฌํฉ๋๋ค. ๋๋ ์ด๊ฒ์ ์กฐ์ฌํ๋ ๋ฐ ์ผ์ฃผ์ผ์ ๋ณด๋ผ ์ ์๋ค๊ณ ์๊ฐํฉ๋๋ค. ์ง๊ธ๊น์ง ๋ค์๊ณผ ๊ฐ์ด @Roger-luo์ pytorch-complex๋ฅผ ์ปดํ์ผํ์ต๋๋ค.
@Zellar209 : macOS 10.13์์ ์คํ๋๋ ํ๊ฒฝ ๋ณ์๋ฅผ ์ฒจ๋ถํ์ต๋๋ค.
๋ค์๊ณผ ๊ฐ์ด ๊ธฐ์กด pytorch ๋ฐฐํฌ๋ฅผ ์ญ์ ํฉ๋๋ค.
์ฝ๋ค ์ ๊ฑฐ ํ์ดํ ์น
ํ ์ ๊ฑฐ ํ ์น
pip uninstall torch # ์ด ๋ช
๋ น์ ๋ ๋ฒ ์คํ
ํ์ด์ฌ setup.py ์ฒญ์
python site-packages ํด๋๊ฐ ์๋ ๊ฒฝ์ฐ ํ ์น ํด๋๋ฅผ ์ญ์ ํฉ๋๋ค.
์ด์ pytorch ์์ค ํด๋์ ์ด๋ฆ์ ๋ณ๊ฒฝ(๋๋ ์ญ์ )ํฉ๋๋ค.
PyTorch ๊ฐ์ ํ 6cb593b88cb0c411690b4957850058329526d87b๋ฅผ ์ค์นํฉ๋๋ค.
git clone [email protected]:pytorch/pytorch.git
git checkout 6cb593b88cb0c411690b4957850058329526d87b
git submodule update --init โrecursive
export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../โ}
MACOSX_DEPLOYMENT_TARGET=10.13 CC=clang CXX=clang++ python setup.py develop
python
>>> import torch
python setup.py install
python setup.py build
python setup.py test
# ERROR: test (unittest.loader._FailedTest)
# ERROR: test_scalar_binary_op (tests.test_tensor.TestComplexTensor)
from torch_complex import torch
a = torch.ones(3, dtype=torch.complex128)
a*a
RuntimeError: promoteTypes with complex numbers is not handled yet; figure out what the correct rules should be
@ezyang , @Roger-luo:
ํ
์ ์์
์ ์ํ ์ ํ ์น๊ฒฉ์ ์ํ ๋ชจ๋ ๊ฒ์ c10/core/ScalarType.h ์์ ์ํ๋๋ ๊ฒ ๊ฐ์ต๋๋ค.
AT_ERROR("promoteTypes with complex numbers is not handled yet; figure out what the correct rules should beโ);
์ค๋ฅ๋ฅผ ์ฐพ์์ต๋๋ค.
์ด ํ
์ด๋ธ ์์ c8 ๋ฐ c16์ ๋ํ ํญ๋ชฉ์ ์ถ๊ฐํด์ผ ํ๋ ๊ฒ ๊ฐ์ต๋๋ค.
์ด๊ฒ์ด 9515 ์ ๊ด๋ จ์ด ์์ต๋๊น? ๋๋ ์ด๊ฒ์ด numpy ํจ์๋ฅผ ํธ์ถํ๊ธฐ์ํ ๊ฒ์ด๋ผ๊ณ ์๊ฐํฉ๋๋ค.
์์ํ๊ธฐ์ ์ข์ ๊ณณ์ธ๊ฐ์?
9515๋ ๊ด๋ จ์ด ์์ต๋๋ค. ๊ทธ๋ฌ๋ ScalarType.h์์ ์ด ์ฝ๋ ๊ฒฝ๋ก๋ฅผ ์์ ํ๋ ๊ฒ๋ถํฐ ์์ํ๋ ๊ฒ์ด ์ข์ต๋๋ค.
ScalarType.h์์ ์ฝ๋ ๊ฒฝ๋ก๋ฅผ ์์ ํ์ต๋๋ค.
BinaryOps(add, sub, mul, div)๊ฐ ์๋ํ์ง๋ง ๋ ์ธ์๊ฐ ๋ชจ๋ Tensor์ธ ๊ฒฝ์ฐ์๋ง ๊ฐ๋ฅํฉ๋๋ค.
๋ค๋ฅธ ์ด์ํ ๋ฌธ์ ๋ ์์ง๋ง ์ข ๋ ์ดํด๋ด์ผ ํฉ๋๋ค.
@dylanbespalko ์ฌ๊ธฐ์ ์ ํ ํ๋ก๋ชจ์ ์ ์ถ๊ฐํ์ต๋๋ค: https://github.com/pytorch/pytorch/pull/11641
๊ทธ๋ฅ ๋ณต์ฌํ ์๋ ์์ง๋ง ๋ฌธ์ ๋ ์ด๊ฒ์ด ์ด๋ป๊ฒ๋ CUDA๋ฅผ ์์์ํจ๋ค๋ ๊ฒ์ ๋๋ค.
IIRC, gcc ๋ฒ์ ์ผ๋ก ์ธํด ์์ด์ด ๋ฒ๊ทธ๊ฐ ์์์ต๋๋ค. ๊ฑฐ๊ธฐ์ ๋ช ๊ฐ์ง ํด๊ฒฐ ๋ฐฉ๋ฒ์ด ์์์ต๋๋ค.
์, ๊ฐ์ฌํฉ๋๋ค @Roger-luo. #11641 ์ ๋๊ธ์ ๋ณด๊ณ ์์์ต๋๋ค. ๋ด์ผ ์ฝ๋ ๋ณต์ฌ๋ฅผ ๋ ์ํ๊ฒ ์ต๋๋ค.
CUDA ์ฅ์น๊ฐ ์์ ๋ CUDA๊ฐ ๊ณ ์ฅ๋ฌ๋์ง ์ด๋ป๊ฒ ์ ์ ์์ต๋๊น? CI๊ฐ ์๋ ค์ค ๊ฑฐ๋ผ๊ณ ์๊ฐํฉ๋๊น?
์, PR์ ์ ์ถํ๋ฉด ์ด๋ ๊ฒ์ด ๊ณ ์ฅ๋ฌ๋์ง ์๋ ค์ค๋๋ค. ๋ชจ๋ ๊ฒ์ด ํต๊ณผํ๋ฉด ์ด๋ฅผ ๋ณํฉํ๊ณ ์์ ์ ์ํํ ์ ์์ต๋๋ค.
์๊ฒ ์ต๋๋ค. ๊ทธ๋ฌ๋ฉด PR ์ ์ถ์ ์์ํ์ฌ ์ธ์ ๋ฐ์ํ๋์ง ์ ์ ์์ต๋๋ค.
@dylanbespalko ์๋
ํ์ธ์, ์ฌ์ ํ ํ๊ฒฝ์ ์ค๋ฅ๊ฐ ์๋ ๊ฒ ๊ฐ์ต๋๊น?
๋น์ ์ด ๊ทธ๊ฒ์ ๊ณ ์น๋ค๋ฉด, ์ฐ๋ฆฌ์ ๊ณต์ ํ์ญ์์ค. ์ ๋ง ๊ฐ์ฌํฉ๋๋ค.
์๋ค ์,
@Roger-luo์ ์ปค๋ฐ ๋ช ๊ฐ๋ฅผ ๋ณต์ฌํ ํ ์ฌ๋ฌ PR์ ์๋ํ์ต๋๋ค. ๋ถํํ๋ ์ง๊ธ์ CUDA GPU๊ฐ ์๊ณ CUDA๊ฐ ์๋ CI ์์คํ ์ด ์ด๊ธฐํ๋์ง ์์ต๋๋ค. ์ง๊ธ์ CUDA ํ ์คํธ ์คํจ๋ฅผ ์ฌํํ ์ ์์ผ๋ฏ๋ก ๋ช ์ฃผ ํ์ ํด๋น GPU์์ ๋ก์ปฌ๋ก ์คํํ ์ ์๊ฒ ๋๋ฉด ๋ค์ ์ธ๊ธํ๊ฒ ์ต๋๋ค. ์ ์ด๋ ์ ๋งํด ๋ณด์ ๋๋ค.
@ezyang , @Roger-luo
๋๋ Roger์ PR #11641 ์ ์ดํด๋ณด์๋ค.
๋ํ ์ต๊ทผ PyTorch ๊ฐ๋ฐ ์ค ์ผ๋ถ๋ฅผ ์ดํด๋ณด์์ต๋๋ค.
์๋ก์ด "out-of-tree" ํ์ฅ ๊ธฐ๋ฅ์ด ๊ฐ๋ฐ๋๊ณ ์์ด ๋๋จธ์ง pytorch๋ฅผ ์์์ํค์ง ์๊ณ ๋ณต์์ ์ง์์ ์กฐ์ฌํ ์ ์์ต๋๋ค. ๋ด ๋ชฉํ๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
@ezyang
๋น์ ์ด ์ ์ํ ์ด out-of-tree device/layout/dtype ํ์ฅ์ ๋ํ ์์ ํ์๋ผ์ธ์ ์ ๊ณตํ ์ ์์ต๋๊น? ์์ผ๋ก 3๊ฐ์ ์ด๋ด์ ์ด ๊ธฐ๋ฅ์ ๊ธฐ๋ํ ์ ์์ต๋๊น?
@ezyang
AVX/SSE ์ง์ ์์ด CPU์์ ๋ณต์์ ์ง์์ ๋ณํฉํ ์ ์์ต๋๊น? ๋ณ๋์ ๋ณํฉ ์์ฒญ์ผ๋ก ๋ค์์ ์ ์ถํ ๊ณํ์ ๋๋ค.
์์ผ๋ก ๋ฉฐ์น ์์ Intel/arm CPU ์ ์ฒด์์ ์ด๊ฒ์ ํ ์คํธํ ๊ณํ์ ๋๋ค.
@ezyang ,
์ ๋ fft()
๋ฐ var()
$ ์ ๊ฐ์ ์ฐ์ฐ์ ์ฐพ๊ณ ์์ต๋๋ค. ์ฌ๊ธฐ์ ๋ณต์์ ๊ตฌํ์ ํ
์ ๋ฐ์ดํฐ๋ฅผ (complex_shape, 2)
๋ชจ์์ ์ด์ค ํ
์๋ก ๋ณํํด์ผ ํฉ๋๋ค. ์ด๊ฒ์ ๊ธฐ์กด ํ
์ ๋ฉ์๋์์๋ ์๋ํ์ง ์์ต๋๋ค.
๋ถ๋ช ํ ๋๋ โโ๋ค์๊ณผ ๊ฐ์ด ๋นํจ์จ์ ์ธ ์ผ์ ํ ์ ์์ต๋๋ค.
def to_float(tensor):
return th.stack((tensor.real().type(th.float64), tensor.imag().type(th.float64)), -1)
def to_complex(tensor):
tensor = tensor.type(th.complex128)
return tensor[..., 0] + 1j*tensor[..., 1]
๋ถ๋ช
ํ ๊ทธ๊ฒ์ ๋ณต์ฌ๋ณธ์ ๋ง๋๋ ๊ฒ์
๋๋ค. ํ์ํ ๊ฒ์ static_cast<double>
์ด๊ณ ํ
์์ ๋ชจ์์ (old_shape, 2)
๋ก ๋ณ๊ฒฝํ๋ ๊ฒ์
๋๋ค. ์ด ์์
์ ์ํํ๋ ๋ฐฉ๋ฒ์ ๋ํ ์ ์ ์ฌํญ์ด ์์ต๋๊น?
๋ํ numpy์๋ ๋ค์๊ณผ ๊ฐ์ ์์ ์ ์ํํ ์ ์๋ ํดํน์ด ์์ต๋๋ค.
a = np.array([1 + 1j], dtype=np.complex128)
a.dtype = np.float64 ## This works
a = torch.tensor([1 + 1j], dtype=torch.complex128)
a.dtype = torch.float64 ## This does not work
dtype์ ์ค์ ํ๋ ๊ธฐ๋ฅ์ ์ด ์ํฉ์์ ์ค์ ๋ก ์๋ํ์ง๋ง ์์ธกํ ์ ์์ต๋๋ค.
๋ณต์์๋ฅผ ์ค์์ ๊ธธ์ด 2 ๋ฐฐ์ด๋ก ํด์ํ๋ ๊ฒ๊ณผ ๊ด๋ จ๋ ๋ช ๊ฐ์ง ์ถ๊ฐ ์ ๋ณด์ ๋๋ค. ๋ค์์ C++11์์ ์ ํจํฉ๋๋ค.
๋ณต์์ p์ ๋ฐฐ์ด ์์์ ๋ํ ํฌ์ธํฐ์ ์ ํจํ ๋ฐฐ์ด ์ธ๋ฑ์ค i์ ๊ฒฝ์ฐ reinterpret_cast
(p)[2 i]๋ ๋ณต์์ p[i]์ ์ค์๋ถ์ด๊ณ reinterpret_cast (p)[2 i + 1]์ ๋ณต์์ p[i]์ ํ์๋ถ์ ๋๋ค. (C++11๋ถํฐ)
์ด๊ฒ์ complex_tensor๋ฅผ ๋ชจ์(complex_shape, 2)์ ๊ฐ์ง real_tensor๋ก ๋ณํํ ๋ค์ ์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ํ ๋นํ real()
๋ฐ imag()
๋ฅผ ํธ์ถํ์ง ์๊ณ ์์
์ ์ํํ ์ ์์์ ์๋ฏธํ๋ค๊ณ ์๊ฐํฉ๋๋ค.
@dylanbespalko ๋๋ ๋น์ ์ด ์ด๊ฒ์ ๋ํด ๋ฌผ์ ๋๋ฅผ ๋๋ ค์ํ์ต๋๋ค :) std::complex
๋ณด์ฆ์ ๋ฐ์ดํฐ ํฌ์ธํฐ std::complex<float>*
๊ฐ ์์ผ๋ฉด float*
๋ก ์์ ํ๊ฒ ์บ์คํธํ ์ ์์์ ์๋ฏธํฉ๋๋ค (์๊ฒฉํ ์จ๋ฆฌ์ด์ฑ ์ค์ผ๊ฑฐ๋ฆผ) ๊ทธ๋ฐ ๋ค์ ์ฌ์ฉ ์ค์ธ fft ํญ๋ชฉ์ ์ ๋ฌํฉ๋๋ค. ์ด ๋ฎ์ ์์ค์ ๋ด๋น์๋ฅผ ์ ๋ฌํ ์ ์๋ fft/var๋ง ๊ตฌํํด์ผ ํ๋ ๊ฒฝ์ฐ ๊ฐ์ฅ ์ฝ์ต๋๋ค.
๊ทธ๋ฌ๋ ๋ณต์กํ ํ ์๋ฅผ float ํ ์๋ก ๋ฌธ์ ๊ทธ๋๋ก ์ฌ๊ฒํ ํด์ผ ํ๋ ๊ฒฝ์ฐ ์ค๋๋ PyTorch์ ์ด์ ๋ํ ์ ๋ก๊ฐ ์๊ธฐ ๋๋ฌธ์ ์ฝ๊ฐ์ ๋๊ด์ ๋ด์ฐฉํ๊ฒ ๋ฉ๋๋ค. Storage dtype์ ํญ์ Tensor dtype๊ณผ ๋์ํ์ต๋๋ค. ๊ทธ๋์ ๋ณต์กํ ์ฐฝ๊ณ ๋ฅผ ๋ง๋ ๋ค๋ฉด ํ๋กํธ ์ฐฝ๊ณ ๋ก ๊ฒํ ํ ๋ฐฉ๋ฒ์ด ์์ต๋๋ค.
๋ด๊ฐ ๊ฐ์ง ํ ๊ฐ์ง ์๊ฐ์ ์ฐ๋ฆฌ๊ฐ ์ด ๋ถ๋ณ๋์ ์ํํด์ผ ํ๋ค๋ ๊ฒ์ ๋๋ค. ์์ด๋์ด๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
ํ์ง๋ง ์ด ๋ถ๋ณ์ฑ์ ๋ฐ์์ํค๋ ค๋ฉด ์ผ๋ง๋ ๋ง์ ์ฝ๋๋ฅผ ๋ณ๊ฒฝํด์ผ ํ๋์ง ์ ๋ชจ๋ฅด๊ฒ ์ต๋๋ค.
@ezyang ,
๊ทธ๋ ์ด๊ฑด ์ด์ฉ ์ ์์๋ค...
์ด ๋ฎ์ ์์ค์ ๋ด๋น์๋ฅผ ์ ๋ฌํ ์ ์๋ fft/var๋ง ๊ตฌํํด์ผ ํ๋ ๊ฒฝ์ฐ ๊ฐ์ฅ ์ฝ์ต๋๋ค.
์, ๋ง์ ๊ฒฝ์ฐ์ ๊ฐ๋ฅํฉ๋๋ค. ํ ์ ๋ฐ์ดํฐ๋ฅผ std::vector๋ก ํด์ํ๋ ๋ฐฉ๋ฒ์ ๋ํ ์ฝ๋ ์ค๋ํซ์ ์ ๊ณตํ ์ ์์ต๋๊น?
๊ทธ๋ฌ๋ ๋ณต์กํ ํ ์๋ฅผ ๋ง ๊ทธ๋๋ก float ํ ์๋ก ๋ค์ ๋ด์ผ ํ๋ค๋ฉด....
๋ค๋ฅธ dtype์ ์ฌ์ฉํ์ฌ ํ
์๋ฅผ ๋ณด๋ ๊ฒ์ ๋๋ญ
๋๋ค. Tensor
set_dtype()
๋ฉ์๋๋ฅผ ๊ตฌํํ์ง๋ง ๋ช ๊ฐ์ง ์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค. ๋ํ ๋ชจ์์ ๋ณํ๋ฅผ ๋ฐ์ํ๊ธฐ ์ํด ๋ณดํญ์ ์
๋ฐ์ดํธํ์ง ์์์ต๋๋ค. dtype ์ค์ ์ด numpy์์ ์๋ํ๋ ์ด์ ๋ ํ์คํ์ง ์์ง๋ง(์ฐ์ฐ์ธ๊ฐ์?) ๋ฐ์ดํฐ๋ฅผ DAC(๋์งํธ-์๋ ๋ก๊ทธ ๋ณํ๊ธฐ)์ ์
๋ก๋ํ ๋ ์ข
์ข
์ค์ /ํ์ ๋ฐ์ดํฐ๊ฐ ์ธํฐ๋ฆฌ๋ธ๋ ๊ฒ์ผ๋ก ์์ํฉ๋๋ค. ์๋ง๋ ๊ทธ๊ฒ์ ๋น์ ์ด ์ ์ํ ๊ฒ์ฒ๋ผ ์คํ ๋ฆฌ์ง dtype์์ ํ
์ dtype์ ๋ถ๋ฆฌํด์ผ ํ ํ์์ฑ์ ๋๊ธฐ๋ฅผ ๋ถ์ฌํ ๊ฒ์
๋๋ค.
๋น๋ถ๊ฐ์ ์ด ์ผ์ ํผํ๊ฒ ์ต๋๋ค. ๋์๊ฒ ๋ค๋ฅธ ์ฑ๋ฅ ๋ณ๋ชฉ ํ์์ด ์๋ค๊ณ ํ์ ํฉ๋๋ค.
์, ๋ง์ ๊ฒฝ์ฐ์ ๊ฐ๋ฅํฉ๋๋ค. ํ ์ ๋ฐ์ดํฐ๋ฅผ std::vector๋ก ํด์ํ๋ ๋ฐฉ๋ฒ์ ๋ํ ์ฝ๋ ์ค๋ํซ์ ์ ๊ณตํ ์ ์์ต๋๊น?
์ ํํ std::vector๋ ์๋์ง๋ง ๋ค์๊ณผ ๊ฐ์ด ์์ํ๊ณ ์์ต๋๋ค.
Tensor complex_tensor;
assert(complex_tensor.is_contiguous());
std::complex<float>* cp = complex_tensor.data_ptr<std::complex<float>>();
float* fp = reinterpret_cast<float*>(cp);
auto num_floats = complex_tensor.numel() * 2;
Tensor์ฉ set_dtype() ๋ฉ์๋๋ฅผ ๊ตฌํํ์ง๋ง ๋ช ๊ฐ์ง ์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค. ๋ํ ๋ชจ์์ ๋ณํ๋ฅผ ๋ฐ์ํ๊ธฐ ์ํด ๋ณดํญ์ ์ ๋ฐ์ดํธํ์ง ์์์ต๋๋ค.
์, ๋ณดํญ๋ ์์ ํ์ง ์์ผ๋ฉด ์ด๊ฒ์ ์๋ง๋ ๋์ ์๊ฐ์ผ ๊ฒ์ ๋๋ค. ๋ํ ์ ๋ ๋ค๋ฅธ dtype์ผ๋ก ๋ณํํ๋ ํ ์์ ์ด๋ ฌํ ํฌ์ด ์๋๋๋ค. ๋ชจ๋ ๊ฒ์ ์ ์๋ฆฌ์์ ์ํํ๋ ๊ฒ์ด ์ข์ต๋๋ค. :)
๊ทธ๋ฌ๋ DAC(๋์งํธ-์๋ ๋ก๊ทธ ๋ณํ๊ธฐ)์ ๋ฐ์ดํฐ๋ฅผ ์ ๋ก๋ํ ๋ ์ข ์ข ์ค์ /ํ์ ๋ฐ์ดํฐ๊ฐ ์ธํฐ๋ฆฌ๋ธ๋ ๊ฒ์ผ๋ก ์์ํฉ๋๋ค. ์๋ง๋ ๊ทธ๊ฒ์ ๋น์ ์ด ์ ์ํ ๊ฒ์ฒ๋ผ ์คํ ๋ฆฌ์ง dtype์์ ํ ์ dtype์ ๋ถ๋ฆฌํด์ผ ํ ํ์์ฑ์ ๋๊ธฐ๋ฅผ ๋ถ์ฌํ ๊ฒ์ ๋๋ค.
์, ๊ถ๊ทน์ ์ผ๋ก ์ด๊ฒ์ด ์ฌ๋ฐ๋ฅธ ์ผ์ด์ง๋ง ์ง๊ธ์ ํ์ง ์๋ ๊ฒ์ด ๋ ์ฝ๋ค๋ ๋ฐ ๋์ํฉ๋๋ค.
@ezyang ,
๋ณต์์ CUDA ์ง์์ ์๋ง์ผ๋ก ๋ง๋ค๊ธฐ ์์ํ์ต๋๋ค.
๋ ๊ฐ์ง ๋ฐ์ด๋๋ฆฌ ํธํ ์ต์ ์ด ์์ต๋๋ค.
std::complex
์ ๋ํ ๊ต์ฒด๊ฐ ์ค๋จ๋์์ต๋๋ค.์ถ์ง๋ ฅ::๋ณต์กํ ์ปจํ
์ด๋๊ฐ ๊ฐ์ผ ํ ๊ธธ์ธ ๊ฒ ๊ฐ์ต๋๋ค. Thrust::Complex API ๋ thrust::complex<T>
์ปจํ
์ด๋๊ฐ ํธ์คํธ ๋ฐ ์ฅ์น ๋ฉ๋ชจ๋ฆฌ์ ํ ๋น๋ ์ ์๋ ๋ฐ๋ฉด std::complex<T>
๋ ํธ์คํธ ๋ฉ๋ชจ๋ฆฌ์๋ง ํ ๋น๋ ์ ์๋ค๊ณ ์ ์ํฉ๋๋ค.
__host__ __device__ thrust::complex< T >::complex (const complex< T > &z) //thrust container
__host__ thrust::complex< T >::complex (const std::complex< T > &z) //stl container.
์ด๊ฒ์ AT_DISPATCH_COMPLEX_TYPES๊ฐ using scalar_t = std::complex<double>
using scalar_t = thrust::complex<double>
๋ฅผ ์ค์ ํด์ผ ํจ์ ์๋ฏธํฉ๋๊น?
Pytorch๋ ์ค์ ๋ฐ์ดํฐ ์ ํ์ ๋ํด std::log
์ ํด๋นํ๋ CUDA๋ฅผ ์ด๋ป๊ฒ ์๋์ผ๋ก ํธ์ถํฉ๋๊น? ์ํ ์ปค๋์ ํด๋นํ๋ CUDA๊ฐ ์๋์ง ์ด๋ป๊ฒ ์ ์ ์์ต๋๊น?
thrust::complex<double>
๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ ์ด๋ ค์์ CPU ์ ์ฉ ๋น๋๋ฅผ ์ํํ๋ ๊ฒฝ์ฐ ์ค์ ๋ก ์ถ๋ ฅ์ ๋ํด ๋น๋ํ์ง ์๋๋ค๋ ๊ฒ์
๋๋ค. ๋๋ ๋ง์ ์ต์
์ด ์๋ค๊ณ ์๊ฐํฉ๋๋ค. ์ฐ๋ฆฌ๋ ์ฐ๋ฆฌ ์์ ์ ๋ณต์กํ ์ ํ์ ๊ตด๋ฆด ์ ์์ต๋๋ค(์ฐ๋ฆฌ๊ฐ ์ฐ๋ฆฌ ์์ ์ ํํ ์ ํ์ ๊ตด๋ฆฌ๋ ๋ฐฉ๋ฒ๊ณผ ์ ์ฌ). ๋๋ std::complex<>
๊ฐ ํน์ ๋ฐ์ด๋๋ฆฌ ๋ ์ด์์์ ๊ฐ๋๋ก ์ ์๋์ด ์๊ธฐ ๋๋ฌธ์ ์บ์คํธ๋ฅผ ์ฌํด์ํ์ฌ ์น๋ฆฌํ ์ ์์ต๋๋ค. ๊ทธ๊ฒ์ ๋น์ ์๊ฒ ๋ฌ๋ ค ์์ง๋ง, ์ง๊ธ์ ์ ํ ๊ฐ์ ์บ์คํ
์ ์ฌํด์ํ๋ ๊ฒ์ด ๋ ์ฌ์ ๋ณด์
๋๋ค.@iotamudelta ๋ #29547์์ C++11 ๊ท์ ์ค์ ๋ฌธ์ ๋ฅผ ์ ๊ธฐํ์ต๋๋ค.
std::real์ C++14์ constexpr์ผ ๋ฟ์ ๋๋ค.
๋ด๊ฐ ์ฌ๋ฐ๋ฅด๊ฒ ์ดํดํ๋ค๋ฉด std::real()
๋ constexpr
์ด์ด์ผ hcc ์ปดํ์ผ๋ฌ๊ฐ __device__
์ ๋ํ ๋ช
๋ น์ ์ปดํ์ผํ ์ ์์ต๋๋ค.
๊ฐ๋ฅํ ํด๊ฒฐ์ฑ :
complex<double>
๋ฅผ double
๋ก ๋ณํํ๋ ๋ค๋ฅธ ๋ฐฉ๋ฒ์ด๋ ํจ์๋ฅผ ์ฐพ์ผ์ญ์์ค.abs
๋ ์๋ํ์ง ์์ต๋๋ค. ํ
ํ๋ฆฟํจ์๋ฅผ ๋ํํ๋ ๋ฐฉ๋ฒ์ ์ฐพ์ผ์ญ์์ค.
std::real ์ ๋ํ ๋๋ถ๋ถ์ ํธ์ถ์ aten/src/ATen/native/cpu/zmath.h
์์ ์ด๋ฃจ์ด์ง๋๋ค. ์: inline
constexpr
:
inline VALUE_TYPE real_impl (SCALAR_TYPE z)
->
constexpr VALUE_TYPE real_impl (SCALAR_TYPE z)
inline std::complex<float> real_impl <std::complex<float>> (std::complex<float> z)
-> constexpr std::complex<float> real_impl <std::complex<float>> (std::complex<float> z)
inline std::complex<float> real_impl <std::complex<double>> (std::complex<float> z)
-> constexpr std::complex<float> real_impl <std::complex<double>> (std::complex<float> z)
constexpr
๊ฐ ์๋ std::real()
์ ๋ํ ์ค์ฒฉ ํธ์ถ์ด ์ฌ์ ํ ์๊ธฐ ๋๋ฌธ์ ์ปดํ์ผ๋์ง ์์ต๋๋ค.
3. std::complex๋ฅผ ์ฌ์ฉํ๋ ๊ฒฝ์ฐ
std::complex<double>
๋ฅผ double
๋ก ๋ณํํ ์ ์๋ ๋ค๋ฅธ ๋ฐฉ๋ฒ์ด ์์ต๋๊น?@iotamudelta , @bddppq , @ezyang ,
CUDA thrust::complex API์ ๋ณต์กํ UnaryOps ๋ฐ BinaryOps์ ๋ํ ์ง์์ ์ถ๊ฐํ์ง๋ง ์ ์ถํ๊ธฐ ์ ์ ๋ช ๊ฐ์ง ์ง๋ฌธ์ ํด์ผ ํฉ๋๋ค.
๋๋ ๋ณต์์๋ฅผ ๋ค๋ฃฐ ๋ thrust::complex ๋ฐ์ดํฐ ์ ํ์ ์ฌ์ฉํ ์ ์๋๋ก ํ๋ ํ
ํ๋ฆฟ ํจ์๋ฅผ ์ ์ํ์ต๋๋ค.
aten/src/ATen/native/cuda/zmath.cuh
#pragma once
#include <complex>
#include <thrust/complex.h>
namespace at { namespace native {
namespace {
template <typename TYPE>
struct ztype_cuda {
using value_t = TYPE; // Complex template type
using thrust_t = TYPE; // Equivalent thrust type
};
template <>
struct ztype_cuda<std::complex<float>> {
using value_t = float;
using thrust_t = thrust::complex<float>;
};
template <>
struct ztype_cuda<std::complex<double>> {
using value_t = double;
using thrust_t = thrust::complex<double>;
};
} // end namespace
}} //end at::native
๊ทธ๋ฐ ๋ค์ aten/src/ATen/native/cuda/BinaryOpsKernel.cu
์์
๋ฐ๊พธ๋ค:
void add_kernel_cuda(TensorIterator& iter, Scalar alpha_scalar) {
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.common_dtype(), "add_cuda/sub_cuda", [&]() {
auto alpha = alpha_scalar.to<scalar_t>();
gpu_kernel_with_scalars(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
return a + alpha * b;
});
});
}
์ ํจ๊ป:
void add_kernel_cuda(TensorIterator& iter, Scalar alpha_scalar) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBool, iter.dtype(), "add_cuda/sub_cuda", [&]() {
using thrust_t = typename ztype_cuda<scalar_t>::thrust_t;
auto alpha = thrust_t(alpha_scalar.to<scalar_t>());
gpu_kernel_with_scalars(iter, [alpha]GPU_LAMBDA(thrust_t a, thrust_t b) -> thrust_t {
return a + alpha * b;
});
});
}
thrust_t
๋ฅผ scalar_t_c
์ ๊ฐ์ด ๋ณต์์๊ฐ ์๋ ์ซ์์ ๋ ์น์ํ ์ด๋ฆ์ผ๋ก ๋ฐ๊ฟ ์ ์์ต๋๊น?thrust::complex
cuComplex
๋ฅผ ์ฌ์ฉํด์ผ ํ๋ ์ด์ ๊ฐ ์์ต๋๊น?hip_complex
๋ฅผ ์ฌ์ฉํด์ผ ํฉ๋๊น?cuComplex
๋ณด๋ค ๋ ๋ง์ ๊ธฐ๋ฅ์ ์ง์ํ๋ ๊ฒ ๊ฐ์ต๋๋ค.์ด๋ป๊ฒ ์๊ฐํ๋์ง ์๋ ค์ฃผ์ธ์.
@iotamudelta
std::real()์ ๋ํ ํ ๋ก ์ ์
๋ฐ์ดํธํ์ต๋๋ค. std::complex๋ฅผ ํ์ธํ ์ ์์ต๋๊น?
์๋ ํ์ธ์ @dylanbespalko ,
@iotamudelta ๊ฐ ๋ถํํ๋ ๊ฒ์ ๋ณต์กํ ์ ํ์ ๋ํ cast_and_store
์ C10_HOST_DEVICE
์ด ๋๋ฝ๋์ด ํด๋น ์ฝ๋ ๊ฒฝ๋ก๊ฐ GPU์์ ์คํ๋๋ ๊ฒฝ์ฐ UB๊ฐ ๋ ๊ฒ์
๋๋ค.
ํ์ฌ ์ด ๋์ ์บ์คํ
์ ํธ๋ฆฌํฐ๋ GPU TensorIterator์์๋ง ์ฌ์ฉ๋๋ฉฐ ์ ํ ์น๊ฒฉ์ด ์๋ ๊ฒฝ์ฐ์๋ง ์ฌ์ฉ๋ฉ๋๋ค. ํ์ฌ GPU์์ complex๊ฐ ์ง์๋์ง ์์๊ธฐ ๋๋ฌธ์ ํ์ฌ complex type์ ๋ํ cast_and_store
์๋ C10_HOST_DEVICE
ํ์ ์๊ฐ ์์ผ๋ฉฐ ํธ์คํธ์ ๋ํด ์์ ํ ๊ด์ฐฎ์ std::real
๋ฅผ ์ฌ์ฉํฉ๋๋ค. ์ ์ผํ ๊ธฐ๋ฅ. ์ฌ๊ธฐ์ UB๋ ์ฌ์ฉ๋์ง ์๊ณ ๊ฑฑ์ ํ ํ์๊ฐ ์๊ธฐ ๋๋ฌธ์ ์์ต๋๋ค.
ํ์ง๋ง ๋ณต์กํ ์ง์์ GPU์ ์ถ๊ฐํ๊ณ ์ถ๊ธฐ ๋๋ฌธ์ https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h#L398 ์์ ๋ณผ ์ ์๋ฏ์ด ๋ณตํฉ๋ฌผ์ ์ ํ ์น๊ฒฉ์ ์ํด ์ง์๋ฉ๋๋ค. L420, ์ด ์ฝ๋ ๊ฒฝ๋ก์ ๋ํด ๋งค์ฐ ์ฃผ์ํด์ผ ํ๋ฉฐ ์๋ํ๋ ค๋ฉด ๋ช ๊ฐ์ง ์์ ์ด ํ์ํ ์ ์์ต๋๋ค.
๋ฌผ๋ก https://github.com/pytorch/pytorch/pull/29547 ์์ @iotamudelta ๊ฐ ํ๋ ๋๋ก C10_HOST_DEVICE
๋ฅผ ์ถ๊ฐํด์ผ ํ์ง๋ง ๋จ์ํ C10_HOST_DEVICE
๋ฅผ ์ถ๊ฐํ๊ธฐ ๋๋ฌธ์ ์ถฉ๋ถํ์ง ์์ต๋๋ค. @iotamudelta ๊ฐ ์ธ๊ธํ ๊ฒ์ฒ๋ผ ๋ค๋ฅธ ๋ณ๊ฒฝ ์ฌํญ์ด ์์ผ๋ฉด ์ฌ์ ํ C++11์ UB์
๋๋ค. ์ข์ ์๋ฃจ์
์ ๊ทํ๊ฐ ์ธ๊ธํ ๊ฒ์ผ ์ ์์ต๋๋ค. std::real
std::complex::real()
๋ฅผ ์ฌ์ฉํ์ญ์์ค.
ํ์ง๋ง ๊ทธ ์ธ์๋ https://github.com/pytorch/pytorch/blob/master/c10/util/TypeCast.h ํ์ผ์ ๋ณด๋ฉด fetch_and_cast
๋ด๋ถ์ ๋ค์๊ณผ ๊ฐ์ ๋ด์ฉ์ด ์์ต๋๋ค.
#ifndef C10_HOST_DEVICE
AT_FORALL_COMPLEX_TYPES(FETCH_AND_CAST_COMPLEX_CASE)
#endif
์ด ์ฝ๋ ๊ฒฝ๋ก๋ GPU์์ ๋นํ์ฑํ๋์ด ์์ต๋๋ค. ํ์ฑํํ๊ณ ์๋ํ๋๋ก ํด์ผ ํฉ๋๋ค.
๋ํ fetch_and_cast
๋ฐ cast_and_store
$ ๋ด์์ $ complex<float>
์ complex<double>
์ฌ์ด์ ๋ณํ์ ๋ณด์ง ๋ชปํ์ต๋๋ค. ์ด๋ฅผ ์ํด ๋ณํ์ ์ถ๊ฐํด์ผ ํ ์๋ ์์ต๋๋ค. ๋ชจ๋ dtypes์ ์ด๋ฌํ ๊ธฐ๋ฅ ๋ฒ์๋ฅผ ์ฒ ์ ํ ํ
์คํธํ์ญ์์ค.
์ฐธ์กฐ: @ezyang ๋ฐ @bddppq
๋ํ @dylanbespalko , PR์์ TypeCast.h
์ ๋ณ๊ฒฝํ๋ ๊ฒฝ์ฐ ์ ๋ฅผ ์ฐธ์กฐํ์ญ์์ค.
์๊ฒ ์ต๋๋ค. ARM์์ torch.real()
๋ฐ torch.imag()
๋ก ์์ ํด์ผ ํ ๋ช ๊ฐ์ง ์์ ์ฌํญ์ด ์์ผ๋ฏ๋ก TypeCast.h
๋ฐ ๊ธฐํ ๋ช ๊ฐ์ง๋ฅผ ์์ ํ๋ ๋์ ์์ ํ๊ฒ ์ต๋๋ค. PR์์ ์ฐธ์กฐํ๊ฒ ์ต๋๋ค.
๋๊ธ๋ก ๋๋ผ์ด๋ธ: @smessmer ๋ ์ฐ๋ฆฌ๋ฅผ C++14๋ก ์ฎ๊ธฐ๊ณ ์์ผ๋ฉฐ, ๊ทธ ์์ ์์๋ UB๊ฐ ์๋๋๋ค. ์ด๊ฒ์ด ๊ณง ๋์ฌ ๊ฒ์ด๊ธฐ ๋๋ฌธ์ UB๊ฐ ์ค์ ๋ฌธ์ ๋ฅผ ์ผ์ผํค์ง ์๋๋ค๋ฉด ๋๋ ๊ทธ๊ฒ์ ๋ํด ๋๋ฌด ๊ฑฑ์ ํ์ง ์์ ๊ฒ์ ๋๋ค.
@ezyang : ๋ฐ๊ฐ์ต๋๋ค. Eigen๊ณผ ๊ฐ์ ํ์ฌ ์ ํ์ ๋๋ถ๋ถ์ ์ฌ์ ํ std::real()
๋งค์ฐ ์์ ๋กญ๊ฒ ํธ์ถํฉ๋๋ค.
๋ณต์์๊ฐ ์๋ ์ซ์์ ๊ฒฝ์ฐ scalar_t ๋ฐ thrust_t๋ ๋์ผํ ์ ํ์ ๋๋ค. ์๋ง๋ ๋ด๊ฐ ๋ณ์ ์ด๋ฆ thrust_t๋ฅผ scalar_t_c์ ๊ฐ์ ๋น๋ณต์์์ ๋ ์น์ํ ๊ฒ์ผ๋ก ๋ฐ๊ฟ ์ ์์ต๋๊น?
ํ์คํ์ง ์์ง๋ง scalar_t_c
๋ thrust_t
๋ณด๋ค ์ฝ๊ฐ ๋ ๋ช
ํํด ๋ณด์
๋๋ค( c
๋ ์ด์จ๋ ๋ฌด์์ ์๋ฏธํฉ๋๊น?) ์ฌ๊ธฐ์ ๋ฌธ์ ์ ์ ํ์ ์๋นํ ๊ตฌ์ฒด์ ์ผ๋ก ๋ณด์
๋๋ค ๊ทธ๋์ ์๋๋ฅผ ์ง์ ์ ์ผ๋ก ๋งํ๋ ์ด๋ฆ์ ์ฌ์ฉํ๋ ๊ฒ์ด ์ข์ ๊ฒ ๊ฐ์ต๋๋ค.
์๊ฒ ์ต๋๋ค. thrust_t
๋ฅผ ๊ณ์ ์ฌ์ฉํ๊ฒ ์ต๋๋ค. ๋๊ตฌ๋ ์ง ztype_cuda<>()
์ ๋ฐ์ด๋ค๋ฉด scalar_t
๊ฐ ๋ณต์กํ์ง ์์ ์ ํ์ ๊ฒฝ์ฐ thrust_t
์์ ์ฆ์ ์์์ฐจ๋ ค์ผ ํฉ๋๋ค.
https://github.com/pytorch/pytorch/pull/29612 ๋ฅผ ์ฐธ์กฐํ์ธ์.
์๋ ํ์ธ์ ์ฌ๋ฌ๋ถ! pytorch์ ๋ณต์กํ ์ง์์ ์ถ๊ฐํ๋ ๋ฐฉํฅ์ผ๋ก ์ข์ ์ง์ ์ด ์๋ ๊ฒ ๊ฐ์ต๋๋ค! ์ด์ ๋ํ ์ฃผ๋๊ถ์ ์ก๊ณ CUDA ์ง์๋ ์ถ๊ฐํด ์ฃผ์ @dylanbespalko ์๊ฒ ๊ฐ์ฌ๋๋ฆฝ๋๋ค! ๋์ ์์ค์์ ๋ณต์กํ ์ง์์ ํ์ฌ ์งํ ์ํฉ์ ์๊ณ ์ถ์ต๋๋ค. ์ ๋ ์ฃผ๋ก ๋ณต์กํ ํ ์(์ด์ง ์ฐ์ฐ)๋ฅผ ์ถ๊ฐํ๊ณ ๊ณฑํ๊ธฐ ์ํ CUDA ์ง์์ ์ํ ๋๋ต์ ์ธ ํ์๋ผ์ธ์ ๊ด์ฌ์ด ์์ต๋๋ค. ๊ฐ์ฌํฉ๋๋ค!
์๋ ํ์ธ์ @sunilkpai ์ ๋๋ค .
CUDA: #30295์์ ์ด์ง ๋ฐ ๋จํญ ์ฐ์ฐ์ ์ง์ํด์ผ ํ๋ ๊ณต๊ฐ PR์ด ์์ต๋๋ค.
๋ ๋ค๋ฅธ ๋ฌธ์ ๋ ์ญ์ ํ์ ๊ด๋ จ๋ ๊ฒ์
๋๋ค. ๋ณต์์ abs()
์ ๋ํจ์๋ ์ค์์ ๋ค๋ฅด๊ฒ ์ ์๋๋ค๊ณ ์๊ฐํฉ๋๋ค. ๊ทธ๊ฒ์ ๋ํด ๋ฌด์์ํด์ผํ ์ง ๋ชจ๋ฅด๊ฒ ์ง๋ง ํ์ ์ํ์ tools/autograd/derivatives.yaml
์ ์ ์๋์ด ์์ต๋๋ค.
๋๋ ๋ณต์์ /dz abs(z) = z/abs(z)
์ ๋ํด ์๊ฐํฉ๋๋ค. ์ด๊ฒ์ ์ค์์๋ ์ฌ์ฉํ ์ ์์ง๋ง sgn(z)
๋ณด๋ค ๋๋ฆด ๊ฒ์
๋๋ค.
@dylanbespalko ๋ด ๋ณด๊ณ ์ https://arxiv.org/pdf/1701.00392.pdf ์ ํ 4.1, 4.2 ๋ฐ 4.3์ด ํ์ ์ํ์ ์ ์ํ๋ ๋ฐ ๋์์ด ๋ ์ ์์ต๋๋ค.
๋ณต์กํ ๋ํจ์(wirtinger calculus)์ ๊ฒฝ์ฐ ๋ ๊ฐ์ง ์ต์
์ด ์์ต๋๋ค.
๋ฏธ๋ถ wrt z ๋๋ z ์ผค๋ ๊ณ์ฐ.
์ ๋ ๊ฐ์ธ์ ์ผ๋ก wrt z conjugate ๋ํจ์๋ฅผ ๋ ์ข์ํฉ๋๋ค.
ํ๋ ฌ ์ฐ์ฐ์ ๋ ์์ฐ์ค๋ฝ๊ฒ ๋๊ปด์ง๊ณ ๊ทธ๋ผ๋์ธํธ ์
๋ฐ์ดํธ์ ์ผค๋ ๊ฐ ํ์ํ์ง ์์ต๋๋ค.
์ด๋ค์ ์ ์๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
z = x + jy
#$์ ๋ํ ํ์ ์ํ z
$: dJ/dz = dJ/dx -j dJ/dy
z.conj
z = x + jy
: dJ/dz.conj = dJ/dx + j dJ/dy
๊ทํ์ ์๊ฒฌ์์ ๋ด ๊ฐ์ ์ ํ์ฌ z
ํ์ ์ํ์ ๊ณ์ฐํ๋ค๋ ๊ฒ์
๋๋ค.
์ด ๊ฒฝ์ฐ ํ์ ์ํ์ d abs(z) / d z = z.conj / abs(z)
์
๋๋ค. ๋ค๋ฅธ ์ ์๋ฅผ ์ฌ์ฉํ๋ฉด @Randl ์ ์์ ๋ฐ๋ฅผ ์ ์์ต๋๋ค.
๋ ์ค๋ช ํด์ผ ํ๋ ๊ฒฝ์ฐ ์๋ ค์ฃผ์ธ์. ๋ณต์กํ ํ์ ์ํ์ ๋ํ ๋ช ๊ฐ์ง numpy ๊ตฌํ๋ ์์ต๋๋ค.
์ ์ฉํ ๋ ๋ค๋ฅธ ์ฐ์ฐ(ํนํ ๋ณต์์ ์ง์์ด ํ์ํ ๋ฌผ๋ฆฌ ๊ณต๊ฐ์ ํ๋ก์ ํธ์ ๋ํด)์ exp()
์ฐ์ฐ์์ ๋ํ ์ฒ๋ฆฌ๊ธฐ์
๋๋ค. tensorflow์๋ tf.exp(x + iy) = tf.exp(x) * (tf.cos(y) + 1j * tf.sin(y))
๊ฐ ์์ต๋๋ค. pytorch์์๋ ๊ตฌํํ๊ธฐ๊ฐ ๊ฐ๋จํฉ๋๊น?
@sunilkpai , @boeddeker , @Randl ,
๋ณต์กํ ํ์ ์ํ์ ๋ํ ๋ณด๊ณ ์์ ๊ฐ์ฌ๋๋ฆฝ๋๋ค. ๋๋ ๊ทธ๊ฒ์ ๋ฐ๋ฅด๋ ค๊ณ ๋ ธ๋ ฅํ ๊ฒ์ด๊ณ ๋๋ ๋ค์ ์ฃผ์ ์ด๊ฒ์ ๋ค์ ํ ๊ฒ์ด๋ค. ์ฌ๊ธฐ์ ๋ช ๊ฐ์ง ๋งํฌ๋ฅผ ์ถ๊ฐํ๊ณ ํ๋ก์ ํธ ์ํ๋ฅผ ์ค๋ช ํด์ผ ํ๋ค๊ณ ์๊ฐํ์ต๋๋ค.
๋ณต์์์ ์ํ๋ ๋น๊ณต์์ ์ผ๋ก ์ง์๋๋ฉฐ PyTorch ํ์ฅ์ ํตํด ์ถ๊ฐํด์ผ ํฉ๋๋ค.
๊ฐ ํ์ฅ์๋ ๋ค์ ๋ ๊ฐ์ง๊ฐ ํฌํจ๋ฉ๋๋ค.
.cpp
.test/
ํด๋.๋ณต์กํ ํ ์๋ฅผ ์ฝ์์ ์ธ์ํ ์ ์๋ ์ด์ ๋ ๋ฌด์์ ๋๊น?
tensor.py
์ ๋ด์ฉ์ ์์ ํ์ฌ ์ธ์ ํ์์ ๋ฌด์ํ ์ ์์ต๋๋ค.ํ์ฌ ํ๋ก์ ํธ ์ํ:
</li>
<li>Complex number specific code is under 'aten/src/ATen/native/cpu/zmath.h
์๋ PyTorch ๋ด๋ถ์์ ๊ตฌํ๋ฉ๋๋ค.</li>
<li>Complex number specific code is under 'aten/src/ATen/native/cuda/zmath.cuh
์๋ PyTorch ๋ด๋ถ์์ ๊ตฌํ๋ฉ๋๋ค.thrust::complex<T>
๋ฐ์ดํฐ ์ ํ์ด ์ฌ์ฉ๋๋ฉฐ ์ฌ๊ธฐ์๋ ์ต์ ํ๋ ์ปค๋์ด ํฌํจ๋ฉ๋๋ค.ํ์ฌ ๊ฐ๋ฐ:
--
์ฐธ๊ณ ๋ก ๋ณต์กํ ํ์ ์ํ๊ณผ ๊ด๋ จํ์ฌ Julia์์ ๊ธด ๋
ผ์๊ฐ ์์๊ณ ์ด์ ChainRules ( http://www.juliadiff.org/ChainRules.jl/dev/api.html#ChainRulesCore.Wirtinger ์ฐธ์กฐ) ๋ฐ Zygote ์์ ๊ตฌํ์ด ์๋ฃ๋์์ต๋๋ค. . ์ผ๋ฐ์ ์ผ๋ก ์ฌ๋๋ค์๊ฒ ํ์ํ ๊ฒ์
\partial L/\partial adjoint(z)
๊ทธ๋ผ๋์ธํธ(์ ์์ ๊ฐ์ฅ ๋น ๋ฅธ ๊ฐ์ ๋ฐฉํฅ)๋ก ์ฌ์ฉ๋์ง๋ง ๋ํจ์๋ \partial L/\partial z
์ ๋ค๋ฅด๋ฉฐ ๋ณต์์ AD๋ฅผ ์๋ฒฝํ๊ฒ ์ง์ํ๋ ค๋ฉด ์ถ๊ฐ ์ธํฐํ์ด์ค๋ฅผ ์ถ๊ฐํด์ผ ํฉ๋๋ค. . ์์ธํ ๊ท์น์ ChainRules
๋๋ Zygote/lib
์ ๊ตฌํ๋ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค(์ผ๋ฐ ๊ท์น๋ง ์๊ธฐ ๋๋ฌธ์ ๋๋ถ๋ถ์ ์ฐ์ฐ์์ ๋ํ ๋ณต์์์ ๋ํ ๋ณ๋์ ๊ท์น์ด ์์ต๋๋ค. matmul
๋ ์ผ๋ฐ ์ ์๋ก ์์ฑ๋ฉ๋๋ค(์: adjoint(A) * B
).
๋ณต์กํ ํ ์๋ฅผ ์ฝ์์ ์ธ์ํ ์ ์๋ ์ด์ ๋ ๋ฌด์์ ๋๊น?
Tensor python ๊ฐ์ฒด์๋ ์ง์๋์ง ์๋ ์ผ๋ถ ๊ธฐ๋ฅ์ ํธ์ถํ๋ ์์ ์ธ์ ํ์์ด ์์ต๋๋ค.
tensor.py์ ๋ด์ฉ์ ์์ ํ์ฌ ์ธ์ ํ์์ ๋ฌด์ํ ์ ์์ต๋๋ค.
๋๋ ๊ฐ๋จํ Pytorch ํ ์๋ฅผ Numpy ๋ฐฐ์ด๋ก ๋ณํํ ๋ค์ ์ธ์ํ ์ ์์ต๋๋ค.
๋๋ ๋๋ฒ๊น ๋ฑ์ ์ํด https://github.com/Roger-luo/pytorch-complex ์์ ์ธ์์ ์ ์ด๋ ์ผ๋ถ๋ฅผ ์์ ํ๋ค๊ณ ์๊ฐํฉ๋๋ค. ์ฒ์์๋ ๋ง์คํฐ๊ฐ ๊ณผ๊ฑฐ์ ๋ง์ด ๋ณ๊ฒฝ๋์๊ธฐ ๋๋ฌธ์ ์ด๊ฒ์ด ๋์์ด ๋ ์ง ํ์คํ์ง ์์ต๋๋ค. ๋ ๋. ๋์์ด ๋๋ ๊ฒฝ์ฐ ๊ฐ์ ธ๊ฐ ์ ์์ต๋๋ค. ๋ ์ด์ ์ด ์์ ์ ์ํํ์ง ์๊ฒ ์ต๋๋ค.
@dylanbespalko ์ ๋ ๋ฐฐ์ฐ๊ธฐ ์์ํ์ง๋ง pytorch ๋ด๋ถ์ ์๋์ ์ผ๋ก ๊ฒฝํ์ด ์์ต๋๋ค! aten/src/ATen/cpu/vec256/*
์์ ๋ณธ ๊ฒ์ ๊ธฐ๋ฐ์ผ๋ก ํ๋ฉด ์ด ๋ณ๊ฒฝ์ ์๋ํ ์๋ ์์ง๋ง std::exp(std::complex) ์ ๊ธฐ๋ณธ ๋์์ด ๋ด๊ฐ ์ธ๊ธํ ๊ฒ๊ณผ ์ ํํ ์ผ์นํ๋ค๋ ์ ์ ๊ฐ์ํ ๋ ์ด๊ฒ์ด ํ์ํ์ง ํ์คํ์ง ์์ต๋๋ค. ๋ด ์ด์ ์๊ฒฌ์์ : https://en.cppreference.com/w/cpp/numeric/complex/exp ์ ๋ฉ๋ชจ๋ฅผ ์ฐธ์กฐํ์ญ์์ค. ๋ํ ์ด๊ฒ์ด CUDA์์ ์ด๋ฌํ ์์
์ ๊ตฌํํ๋ ๊ฒ์ผ๋ก ์ด๋ป๊ฒ ํด์๋๋์ง ์ ๋ชจ๋ฅด๊ฒ ์ต๋๋ค(ํ์ฌ ์ค์ , imag, conj ๋ฐ ๊ฐ๋๋ก ์ ํ๋์ด ์๋ ๊ฒ์ฒ๋ผ ๋ณด์
๋๊น?).
@sunilkpai ,
์ ๊ณต๋ ๋ฐฉ์ ์์ ์ฌ์ฉํ์ฌ exp()
์ ๋ํ AVX ์ง์์ ์ถ๊ฐํ์ต๋๋ค.
๋ํ PyTorch์ ์ต๊ทผ ๋ณ๊ฒฝ ์ฌํญ์ผ๋ก ์ธํด ๋ช ๊ฐ์ง ๋ฌธ์ ๊ฐ ๋ฐ์ํ์์ ์์์ต๋๋ค. #30871์์ ์์ ํ์ต๋๋ค.
@dylanbespalko
TH์์ ATen์ผ๋ก ์ด์ํ๋ ์ผ์ ์ด ์์ต๋๊น?
๋ด๊ฐ pytorch์ ๋ด๋ถ ์๋์ ์ ํตํ์ง ์๋ค๋ ์ฌ์ค์ ๊ฐ์ํ ๋ ๊ธฐ์ฌํ ์ ์๋ ๋ฐฉ๋ฒ์ด ์์ต๋๊น?
๋๋ arxiv ์์ ๋ณต์กํ svd์ ์ญ์ ํ๋ฅผ ์ํ ๊ณต์์ ์ฐพ์๊ณ ๊ทธ๊ฒ์ ๊ตฌํํ ์ ์์ต๋๋ค.
์์ ํด์ฃผ์ ์ ๊ฐ์ฌํฉ๋๋ค!
@์ผ์ฝฅ-์ธํ๋ฆฌ๋
https://github.com/pytorch/pytorch/wiki/TH-to-ATen-porting-guide
TH ์ปค๋์ C๋ก ๊ตฌํ๋๋ฉฐ ๊ณ ์ ํ ์ฐธ์กฐ ์นด์ดํ
๋ฌธ์ ๋ก ์ธํด ๋ณต์กํ ์ง์์ ์ถ๊ฐํ๋ ๋ฐ ๊ฑฐ์ ๊ด์ฌ์ด ์์ต๋๋ค. ๊ฐ ์ปค๋์ด ๋ฑ๋ก๋ aten/src/ATen/native/native_functions.yaml
์์ ์งํ ์ํฉ์ ์ถ์ ํ ์ ์์ต๋๋ค.
legacy::cpu::_th
๋ฅผ ๊ฒ์ํ๊ณ ํด๋น ์ซ์๋ฅผ 3์ผ๋ก ๋๋์ด ์ด์ TH ์ปค๋ ์๋ก ๋๋๋๋ค.
legacy::cpu::_thnn
๋ฅผ ๊ฒ์ํ๊ณ ์ด์ TH ์ ๊ฒฝ๋ง ์ปค๋์ ์๋ก ํด๋น ์ซ์๋ฅผ 3์ผ๋ก ๋๋๋๋ค.
๊ฐ ์ปค๋์ ์ผ๋ฐ์ ์ผ๋ก 3๊ฐ์ง ๋ค๋ฅธ ๋ฐฉ๋ฒ์ผ๋ก ๋ฑ๋ก๋ฉ๋๋ค.
1. ์ผ๋ฐ ์ปค๋ y = add(a, b)
2. ์ธํ๋ ์ด์ค ์ปค๋ a = add_(a, b)
3. ์ถ๋ ฅ ์ปค๋ add_out(a, b, out=y)
์ค์ ๊ตฌํ์ ํญ์ ์ถ๋ ฅ ์ปค๋์ ์๊ณ ๋ค๋ฅธ 2๊ฐ๋ ํด๋น ํจ์๋ฅผ ํธ์ถํฉ๋๋ค.
nn ์ปค๋์ ์ข ์ ์ปค๋์ด ๋ ์ ๊ธฐ ๋๋ฌธ์ ์ด์ํ๊ธฐ ์ฌ์ด ๊ฒฝํฅ์ด ์์ต๋๋ค. ๋ฐ๋ผ์ ์ปค๋์ ๊ตฌํํ ๋ฐฉ๋ฒ์ ์ญ์์ผ๋ก ์ปค๋์ ์ด์ํ ์ ์๋ค๋ฉด ์ ์ฒด ์์ ์ ๋ ์ํํ๊ฒ ๋ฉ๋๋ค.
ํฌํ ์ถ์ ๋ฌธ์ ํ์ธ https://github.com/pytorch/pytorch/issues/24507 , @VitalyFedyunin ๋ ์ฐธ์กฐ
๋ค์์ #32437์์ ์์ฒญํ ๋ณต์์ ์ง์์ ๋ํ ์ํ ์ ๋ฐ์ดํธ์ ๋๋ค. ์ ๋ ์ค๋ CPU ๊ด๋ จ ์ง์์ ์ํด ๋ค์ ์ผํ๊ณ ์์ต๋๋ค.
angle()
, real()
, imag()
, conj()
๋ชจ๋ ๊ตฌํ๋์์ต๋๋ค.abs()
๋ ๋ณต์์์ ๋ํด ๋ณ๋์ ๊ตฌํ์ด ํ์ํฉ๋๋ค. (์์ @boeddeker ๋ฐ @Randl ์ ๋ฉ๋ชจ ์ฐธ์กฐ)๋ณต์์ ์ง์์ ํ์ฌ ํธ๋ฆฌ ์ธ๋ถ์์ ๊ตฌํ๋ฉ๋๋ค. ์ด๊ฒ์ด ์๋ฏธํ๋ ๋ฐ๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
์ด ๋ฌธ์ ์ ๋ํ ์ถ๊ฐ ์ ๋ฐ์ดํธ: https://github.com/pytorch/pytorch/issues/33152
์ด๊ฒ์ ๋ณ๋์ ๋ฌธ์ ๊ฐ ๋ ์๋ ์๊ณ ์๋ ์๋ ์์ง๋ง, ํ์ฌ ๋ฌธ์์ 'ํ์ฌ pytorch๊ฐ ๋ณต์์์ ํจ๊ป ์๋ํ๋ ๋ฐฉ์'์ ์ค๋ช ํ๋ ๋ด์ฉ์ ํฌํจํ๋ ๊ฒ์ด ์ค์ง์ ์ผ๋ก ๋ ์ค์ํ๋ค๊ณ ์๊ฐํฉ๋๋ค. ์ผ๋ช ๋ง์ , ๊ณฑ์ , ์ผ์ข ์ ๊ท๋ฒ์ ํ ์ ์๊ณ ๋ณต์กํ ๊ฐ์ค์น๋ฅผ ๊ฐ์ง ์ ์์ต๋๋ค. ์ด ๋ชจ๋ ๊ฒ์ ์์ ์์ค์์ ์๋ํ ํ์ฌ ๋์์ ์ค๋ช ํ๋ ๋ช ์ค์ ๋ฌธ์๋ก ์์ฝ๋ ์ ์์ต๋๋ค.
์ด๊ฒ์ ๋ณ๋์ ๋ฌธ์ ๊ฐ ๋ ์๋ ์๊ณ ์๋ ์๋ ์์ง๋ง, ํ์ฌ ๋ฌธ์์ 'ํ์ฌ pytorch๊ฐ ๋ณต์์์ ํจ๊ป ์๋ํ๋ ๋ฐฉ์'์ ์ค๋ช ํ๋ ๋ด์ฉ์ ํฌํจํ๋ ๊ฒ์ด ์ค์ง์ ์ผ๋ก ๋ ์ค์ํ๋ค๊ณ ์๊ฐํฉ๋๋ค. ์ผ๋ช ๋ง์ , ๊ณฑ์ , ์ผ์ข ์ ๊ท๋ฒ์ ํ ์ ์๊ณ ๋ณต์กํ ๊ฐ์ค์น๋ฅผ ๊ฐ์ง ์ ์์ต๋๋ค. ์ด ๋ชจ๋ ๊ฒ์ ์์ ์์ค์์ ์๋ํ ํ์ฌ ๋์์ ์ค๋ช ํ๋ ๋ช ์ค์ ๋ฌธ์๋ก ์์ฝ๋ ์ ์์ต๋๋ค.
์๋
ํ์ธ์ @redwrasse ํผ๋๋ฐฑ ์ฃผ์
์ ๊ฐ์ฌํฉ๋๋ค! ํ์ฌ ๋ง์คํฐ์ ๋ณต์กํ ํ
์์ ๋ํด ์ง์๋๋ ์ผ๋ถ ํ ์น ๊ธฐ๋ณธ ๋ฐ ๋ณต์กํ ๊ธฐ๋ฅ์ ๋ํด ์ด์ผ๊ธฐํ๋ ๋ณต์์์ ๋ํ ๋ฉ๋ชจ๊ฐ ์์ต๋๋ค.
(๋๋ถ๋ถ์ 1.6 ๋ฆด๋ฆฌ์ค์ ํฌํจ๋จ) https://pytorch.org/docs/master/complex_numbers.html?highlight=complex. ๋ค๋ฅธ ์ด๋ค ๊ธฐ๋ฅ์ ๊ด์ฌ์ด ์๋์ง ๊ณต์ ํ ์ ์์ต๋๊น? ํ์ฌ ์ง์ ๋ฐ ํฅํ ๋ฆด๋ฆฌ์ค์ ๋ํ ๊ณํ์ ๋ํด ๋ ์ด์ผ๊ธฐํ๊ฒ ๋์ด ๊ธฐ์ฉ๋๋ค.
์ด๊ฒ์ ๋ณ๋์ ๋ฌธ์ ๊ฐ ๋ ์๋ ์๊ณ ์๋ ์๋ ์์ง๋ง, ํ์ฌ ๋ฌธ์์ 'ํ์ฌ pytorch๊ฐ ๋ณต์์์ ํจ๊ป ์๋ํ๋ ๋ฐฉ์'์ ์ค๋ช ํ๋ ๋ด์ฉ์ ํฌํจํ๋ ๊ฒ์ด ์ค์ง์ ์ผ๋ก ๋ ์ค์ํ๋ค๊ณ ์๊ฐํฉ๋๋ค. ์ผ๋ช ๋ง์ , ๊ณฑ์ , ์ผ์ข ์ ๊ท๋ฒ์ ํ ์ ์๊ณ ๋ณต์กํ ๊ฐ์ค์น๋ฅผ ๊ฐ์ง ์ ์์ต๋๋ค. ์ด ๋ชจ๋ ๊ฒ์ ์์ ์์ค์์ ์๋ํ ํ์ฌ ๋์์ ์ค๋ช ํ๋ ๋ช ์ค์ ๋ฌธ์๋ก ์์ฝ๋ ์ ์์ต๋๋ค.
์๋ ํ์ธ์ @redwrasse ํผ๋๋ฐฑ ์ฃผ์ ์ ๊ฐ์ฌํฉ๋๋ค! ํ์ฌ ๋ง์คํฐ์ ๋ณต์กํ ํ ์์ ๋ํด ์ง์๋๋ ์ผ๋ถ ํ ์น ๊ธฐ๋ณธ ๋ฐ ๋ณต์กํ ๊ธฐ๋ฅ์ ๋ํด ์ด์ผ๊ธฐํ๋ ๋ณต์์์ ๋ํ ๋ฉ๋ชจ๊ฐ ์์ต๋๋ค.
(๋๋ถ๋ถ์ 1.6 ๋ฆด๋ฆฌ์ค์ ํฌํจ๋จ) https://pytorch.org/docs/master/complex_numbers.html?highlight=complex. ๋ค๋ฅธ ์ด๋ค ๊ธฐ๋ฅ์ ๊ด์ฌ์ด ์๋์ง ๊ณต์ ํ ์ ์์ต๋๊น? ํ์ฌ ์ง์ ๋ฐ ํฅํ ๋ฆด๋ฆฌ์ค์ ๋ํ ๊ณํ์ ๋ํด ๋ ์ด์ผ๊ธฐํ๊ฒ ๋์ด ๊ธฐ์ฉ๋๋ค.
@anjali411 ๊ฐ์ฌํฉ๋๋ค. ์ด ๋ฌธ์๋ฅผ ๋ณด๋ ๋ฐ๊ฐ์ต๋๋ค. ์ด์ ์๋ ๋ชฐ๋์ต๋๋ค. ํ์ํ ๊ฒ์ '๋ณต์กํ ์ ๊ฒฝ๋ง์ ๋ํ ํ์ฌ ์ง์ ์ํ'์ ๋ช ์ค ์๊ณผ ์ค์์ ์๋ ๊ฒ ๊ฐ์ง๋ง ...
๋ณต์กํ autograd์ ๊ด์ฌ์ด ์๋ ์ฌ๋๋ค์ https://github.com/pytorch/pytorch/issues/41857 ์์ PyTorch๊ฐ ๋ฐ๋ฅผ ๊ท์น(JAX ๋๋ TF)์ ๋ํด ์ค๋ช ํฉ๋๋ค.
๊ฐ์ฅ ์ ์ฉํ ๋๊ธ
@sunilkpai , @boeddeker , @Randl ,
๋ณต์กํ ํ์ ์ํ์ ๋ํ ๋ณด๊ณ ์์ ๊ฐ์ฌ๋๋ฆฝ๋๋ค. ๋๋ ๊ทธ๊ฒ์ ๋ฐ๋ฅด๋ ค๊ณ ๋ ธ๋ ฅํ ๊ฒ์ด๊ณ ๋๋ ๋ค์ ์ฃผ์ ์ด๊ฒ์ ๋ค์ ํ ๊ฒ์ด๋ค. ์ฌ๊ธฐ์ ๋ช ๊ฐ์ง ๋งํฌ๋ฅผ ์ถ๊ฐํ๊ณ ํ๋ก์ ํธ ์ํ๋ฅผ ์ค๋ช ํด์ผ ํ๋ค๊ณ ์๊ฐํ์ต๋๋ค.
๋ณต์์์ ์ํ๋ ๋น๊ณต์์ ์ผ๋ก ์ง์๋๋ฉฐ PyTorch ํ์ฅ์ ํตํด ์ถ๊ฐํด์ผ ํฉ๋๋ค.
๊ฐ ํ์ฅ์๋ ๋ค์ ๋ ๊ฐ์ง๊ฐ ํฌํจ๋ฉ๋๋ค.
.cpp
.test/
ํด๋.ํ ์คํธ ์คํฌ๋ฆฝํธ์์ ์ง์๋๋ ์ปค๋๊ณผ ์ง์ํ์ง ์๋ ์ปค๋์ ํ์ธํ์ญ์์ค.
๋ณต์กํ ํ ์๋ฅผ ์ฝ์์ ์ธ์ํ ์ ์๋ ์ด์ ๋ ๋ฌด์์ ๋๊น?
tensor.py
์ ๋ด์ฉ์ ์์ ํ์ฌ ์ธ์ ํ์์ ๋ฌด์ํ ์ ์์ต๋๋ค.ํ์ฌ ํ๋ก์ ํธ ์ํ:
</li> <li>Complex number specific code is under 'aten/src/ATen/native/cpu/zmath.h
์๋ PyTorch ๋ด๋ถ์์ ๊ตฌํ๋ฉ๋๋ค.</li> <li>Complex number specific code is under 'aten/src/ATen/native/cuda/zmath.cuh
์๋ PyTorch ๋ด๋ถ์์ ๊ตฌํ๋ฉ๋๋ค.thrust::complex<T>
๋ฐ์ดํฐ ์ ํ์ด ์ฌ์ฉ๋๋ฉฐ ์ฌ๊ธฐ์๋ ์ต์ ํ๋ ์ปค๋์ด ํฌํจ๋ฉ๋๋ค.ํ์ฌ ๊ฐ๋ฐ:
--