Pytorch: ๊ธฐ๋Šฅ ์š”์ฒญ: load_state_dict๋Š” ํŒŒ์ผ ์ด๋ฆ„์„ ๊ฐ€์ ธ์™€์•ผ ํ•ฉ๋‹ˆ๋‹ค.

์— ๋งŒ๋“  2017๋…„ 05์›” 31์ผ  ยท  3์ฝ”๋ฉ˜ํŠธ  ยท  ์ถœ์ฒ˜: pytorch/pytorch

๋†’์€ ๋ฉ”๋ชจ๋ฆฌ ๋ถ€์กฑ ์ƒํ™ฉ์—์„œ ๋‹ค์Œ์€ ์ผ๋ฐ˜์ ์œผ๋กœ ๋ฐœ์ƒํ•ฉ๋‹ˆ๋‹ค.

  1. ๋ชจ๋ธ ์ƒ์„ฑ
  2. ์ฒดํฌํฌ์ธํŠธ ํŒŒ์ผ์—์„œ state_dict ์ฝ๊ธฐ(GPU์— ๋กœ๋“œ)
  3. model.load_state_dict(๋“ค)

๋ฉ”๋ชจ๋ฆฌ ๋ถ€์กฑ์œผ๋กœ ์ธํ•ด ์ผ๋ฐ˜์ ์ธ ํ•ด๊ฒฐ ๋ฐฉ๋ฒ•์€ ๋จผ์ € ๋‹ค์Œ์„ ์ˆ˜ํ–‰ํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.

s = torch.load('my_file.pt', map_location=lambda storage, loc: storage)

๊ทธ๋Ÿฐ ๋‹ค์Œ s ์„ model ์— ๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค.

์ด๊ฒƒ์€ ์šฐ๋ฆฌ๊ฐ€ ํ”ผํ•  ์ˆ˜ ์žˆ์–ด์•ผ ํ•˜๋Š” ๋งค์šฐ ์ผ๋ฐ˜์ ์ธ ์‹œ๋‚˜๋ฆฌ์˜ค์ด๋ฉฐ ์ด ์‹œ๋‚˜๋ฆฌ์˜ค์—๋Š” ๋ช‡ ๊ฐ€์ง€ ํ•จ์ •์ด ์žˆ์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๋ถ€๋ถ„ GPU ๋ถ€๋ถ„ ๋ชจ๋ธ์—์„œ ์ผ์–ด๋‚˜๋Š” ์ผ, ๋‹ค์ค‘ GPU ๋ชจ๋ธ์—์„œ ์ผ์–ด๋‚˜๋Š” ์ผ...

load_state_dict๊ฐ€ ํŒŒ์ผ ์ด๋ฆ„์„ ์ง์ ‘ ๊ฐ€์ ธ์˜ค๋ฉด ๊ธฐ์กด ๋งค๊ฐœ ๋ณ€์ˆ˜ ์ €์žฅ์†Œ๋ฅผ ์‚ญ์ œํ•˜๊ณ  ์ฆ‰์‹œ ์ƒˆ ๋งค๊ฐœ ๋ณ€์ˆ˜๋กœ ์„ค์ •ํ•  ์ˆ˜ ์žˆ์œผ๋ฏ€๋กœ ์ถ”๊ฐ€ ๋ฉ”๋ชจ๋ฆฌ๊ฐ€ ํ•„์š”ํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.

feature nn triaged

๊ฐ€์žฅ ์œ ์šฉํ•œ ๋Œ“๊ธ€

load_state_dict ๊ฐ€ ํŒŒ์ผ ์ด๋ฆ„์„ ์‚ฌ์šฉํ•˜๋Š” ๊ฒฝ์šฐ map_location ๋งค๊ฐœ๋ณ€์ˆ˜๋„ ํ—ˆ์šฉํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ๋‚˜์—๊ฒŒ ์ผ๋ฐ˜์ ์ธ ์ƒํ™ฉ์€ ํด๋Ÿฌ์Šคํ„ฐ ์‹œ์Šคํ…œ์— ์ฒดํฌํฌ์ธํŠธ๋ฅผ ์ €์žฅํ•œ ๋‹ค์Œ ๋‚ด ๋งฅ๋ถ์— ๋กœ๋“œํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค(๋”ฐ๋ผ์„œ CPU์— ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ๋กœ๋“œํ•ด์•ผ ํ•จ)

๋ชจ๋“  3 ๋Œ“๊ธ€

์˜ตํ‹ฐ๋งˆ์ด์ € state_dicts์—๋„ ๋™์ผํ•˜๊ฒŒ ์ ์šฉ๋ฉ๋‹ˆ๋‹ค. Adagrad์™€ ๊ฐ™์€ ์ผ๋ถ€ ์˜ตํ‹ฐ๋งˆ์ด์ €์˜ ๊ฒฝ์šฐ ์ฒดํฌํฌ์ธํŠธ๊ฐ€ ํฌ๋ฉฐ ๋™์ผํ•œ ๋ฉ”๋ชจ๋ฆฌ ์••๋ ฅ ์ƒํ™ฉ์ด ๋ฐœ์ƒํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์˜ตํ‹ฐ๋งˆ์ด์ €์—๋Š” .cuda() ์กฐ์ฐจ ์—†์œผ๋ฏ€๋กœ ๋จผ์ € state_dict๋ฅผ CPU์— ์ˆ˜๋™์œผ๋กœ ๋กœ๋“œํ•œ ๋‹ค์Œ ์ˆ˜๋™์œผ๋กœ ์ผ๋ถ€๋ฅผ GPU์— ๋ณต์‚ฌํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

์˜ค๋Š˜ @aszlam ์„ ๋„์šฐ๋ฉด์„œ ์ด ๋ฌธ์ œ๋ฅผ ๋งŒ๋‚ฌ์Šต๋‹ˆ๋‹ค.

load_state_dict ๊ฐ€ ํŒŒ์ผ ์ด๋ฆ„์„ ์‚ฌ์šฉํ•˜๋Š” ๊ฒฝ์šฐ map_location ๋งค๊ฐœ๋ณ€์ˆ˜๋„ ํ—ˆ์šฉํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ๋‚˜์—๊ฒŒ ์ผ๋ฐ˜์ ์ธ ์ƒํ™ฉ์€ ํด๋Ÿฌ์Šคํ„ฐ ์‹œ์Šคํ…œ์— ์ฒดํฌํฌ์ธํŠธ๋ฅผ ์ €์žฅํ•œ ๋‹ค์Œ ๋‚ด ๋งฅ๋ถ์— ๋กœ๋“œํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค(๋”ฐ๋ผ์„œ CPU์— ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ๋กœ๋“œํ•ด์•ผ ํ•จ)

๋‚˜์™€ @szagoruyko ๋Š” ์ง๋ ฌํ™”๋œ ๋ชจ๋ธ์šฉ HDF5 ํ˜•์‹์˜ ํŒฌ์ž…๋‹ˆ๋‹ค. ์ด ์ œ์•ˆ๊ณผ ์ž˜ ์–ด์šธ๋ฆด ์ˆ˜ ์žˆ๋‹ค๋ฉด

์ด ํŽ˜์ด์ง€๊ฐ€ ๋„์›€์ด ๋˜์—ˆ๋‚˜์š”?
0 / 5 - 0 ๋“ฑ๊ธ‰