๋์ ๋ฉ๋ชจ๋ฆฌ ๋ถ์กฑ ์ํฉ์์ ๋ค์์ ์ผ๋ฐ์ ์ผ๋ก ๋ฐ์ํฉ๋๋ค.
๋ฉ๋ชจ๋ฆฌ ๋ถ์กฑ์ผ๋ก ์ธํด ์ผ๋ฐ์ ์ธ ํด๊ฒฐ ๋ฐฉ๋ฒ์ ๋จผ์ ๋ค์์ ์ํํ๋ ๊ฒ์ ๋๋ค.
s = torch.load('my_file.pt', map_location=lambda storage, loc: storage)
๊ทธ๋ฐ ๋ค์ s
์ model
์ ๋ก๋ํฉ๋๋ค.
์ด๊ฒ์ ์ฐ๋ฆฌ๊ฐ ํผํ ์ ์์ด์ผ ํ๋ ๋งค์ฐ ์ผ๋ฐ์ ์ธ ์๋๋ฆฌ์ค์ด๋ฉฐ ์ด ์๋๋ฆฌ์ค์๋ ๋ช ๊ฐ์ง ํจ์ ์ด ์์ ์ ์์ต๋๋ค. ๋ถ๋ถ GPU ๋ถ๋ถ ๋ชจ๋ธ์์ ์ผ์ด๋๋ ์ผ, ๋ค์ค GPU ๋ชจ๋ธ์์ ์ผ์ด๋๋ ์ผ...
load_state_dict๊ฐ ํ์ผ ์ด๋ฆ์ ์ง์ ๊ฐ์ ธ์ค๋ฉด ๊ธฐ์กด ๋งค๊ฐ ๋ณ์ ์ ์ฅ์๋ฅผ ์ญ์ ํ๊ณ ์ฆ์ ์ ๋งค๊ฐ ๋ณ์๋ก ์ค์ ํ ์ ์์ผ๋ฏ๋ก ์ถ๊ฐ ๋ฉ๋ชจ๋ฆฌ๊ฐ ํ์ํ์ง ์์ต๋๋ค.
์ตํฐ๋ง์ด์ state_dicts์๋ ๋์ผํ๊ฒ ์ ์ฉ๋ฉ๋๋ค. Adagrad์ ๊ฐ์ ์ผ๋ถ ์ตํฐ๋ง์ด์ ์ ๊ฒฝ์ฐ ์ฒดํฌํฌ์ธํธ๊ฐ ํฌ๋ฉฐ ๋์ผํ ๋ฉ๋ชจ๋ฆฌ ์๋ ฅ ์ํฉ์ด ๋ฐ์ํ ์ ์์ต๋๋ค. ์ตํฐ๋ง์ด์ ์๋ .cuda()
์กฐ์ฐจ ์์ผ๋ฏ๋ก ๋จผ์ state_dict๋ฅผ CPU์ ์๋์ผ๋ก ๋ก๋ํ ๋ค์ ์๋์ผ๋ก ์ผ๋ถ๋ฅผ GPU์ ๋ณต์ฌํด์ผ ํฉ๋๋ค.
์ค๋ @aszlam ์ ๋์ฐ๋ฉด์ ์ด ๋ฌธ์ ๋ฅผ ๋ง๋ฌ์ต๋๋ค.
load_state_dict
๊ฐ ํ์ผ ์ด๋ฆ์ ์ฌ์ฉํ๋ ๊ฒฝ์ฐ map_location
๋งค๊ฐ๋ณ์๋ ํ์ฉํด์ผ ํฉ๋๋ค. ๋์๊ฒ ์ผ๋ฐ์ ์ธ ์ํฉ์ ํด๋ฌ์คํฐ ์์คํ
์ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ ์ฅํ ๋ค์ ๋ด ๋งฅ๋ถ์ ๋ก๋ํ๋ ๊ฒ์
๋๋ค(๋ฐ๋ผ์ CPU์ ๋งค๊ฐ๋ณ์๋ฅผ ๋ก๋ํด์ผ ํจ)
๋์ @szagoruyko ๋ ์ง๋ ฌํ๋ ๋ชจ๋ธ์ฉ HDF5 ํ์์ ํฌ์ ๋๋ค. ์ด ์ ์๊ณผ ์ ์ด์ธ๋ฆด ์ ์๋ค๋ฉด
๊ฐ์ฅ ์ ์ฉํ ๋๊ธ
load_state_dict
๊ฐ ํ์ผ ์ด๋ฆ์ ์ฌ์ฉํ๋ ๊ฒฝ์ฐmap_location
๋งค๊ฐ๋ณ์๋ ํ์ฉํด์ผ ํฉ๋๋ค. ๋์๊ฒ ์ผ๋ฐ์ ์ธ ์ํฉ์ ํด๋ฌ์คํฐ ์์คํ ์ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ ์ฅํ ๋ค์ ๋ด ๋งฅ๋ถ์ ๋ก๋ํ๋ ๊ฒ์ ๋๋ค(๋ฐ๋ผ์ CPU์ ๋งค๊ฐ๋ณ์๋ฅผ ๋ก๋ํด์ผ ํจ)