์๋ ,
์ฌ๊ธฐ ์์์ ๊ฐ์ด Gumbel ๋
ธ์ด์ฆ๋ฅผ ์ถ๊ฐํ๋ ค๊ณ ์๋ํ์ง๋ง ์ฑ๊ณตํ์ง ๋ชปํ์ต๋๋ค.
์ฌ๋ฌ ๋ชจ๋์ด ์ฌ์ ํ ๋๋ฝ๋ ๊ฒ ๊ฐ์ต๋๊น(์: nn.Uniform()๊ณผ ๊ฐ์ ๊ธฐ๋ณธ ๋ณ์ ํจ์), ์๋๋ฉด ์ ๊ฐ ์๋ชป ์๊ณ ์์ต๋๊น? ์๋ฅผ ๋ค์ด ๋ค์ ํ๊ณผ ๊ฐ์ด pytorch์์ ๊ตฌํํ๋ ๋ฐฉ๋ฒ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
``` -- ๋
ธ์ด์ฆ ฮต ์ํ ๋ชจ๋ ์์ฑ
๋ก์ปฌ ๋
ธ์ด์ฆ ๋ชจ๋ = nn.Sequential()
no iseModule:add (nn.Uniform(0, 1)) -- U(0, 1)์ ์ํ
-- ๊ท ์ผํ ์ํ์ Gumbel ์ํ๋ก ๋ณํ
no iseModule:add (nn.AddConstant(1e-9, true)) -- ์์น์ ์์ ์ฑ ํฅ์
iseModule: ์ถ๊ฐ (nn.Log()) ์์
iseModule: ์ถ๊ฐ (nn.MulConstant(-1, true)) ์์
no iseModule:add (nn.AddConstant(1e-9, true)) -- ์์น์ ์์ ์ฑ ํฅ์
iseModule: ์ถ๊ฐ (nn.Log()) ์์
iseModule: ์ถ๊ฐ (nn.MulConstant(-1, true)) ์์
-- ์ํ๋ฌ ์์ฑ q(z) = G(z) = softmax((log(ฯ) + ฮต)/ฯ) (์ฌ๋งค๊ฐ๋ณ์ํ ํธ๋ฆญ)
๋ก์ปฌ ์ํ๋ฌ = nn.Sequential()
๋ก์ปฌ ์ํ๋ฌ ๋ด๋ถ = nn.ConcatTable()
sample rInternal:add (nn.Identity()) -- ์ ๊ทํ๋์ง ์์ ๋ก๊ทธ ํ๋ฅ log(ฯ)
sample rInternal:add (noiseModule) -- ๋
ธ์ด์ฆ ฮต ์์ฑ
์ํ๋ฌ:์ถ๊ฐ (์ํ๋ฌ ๋ด๋ถ)
์ํ๋ฌ:์ถ๊ฐ (nn.CAddTable())
self.temperature = nn.MulConstant(1 / self.tau, true) -- softmax์ ์จ๋ ฯ
์ํ๋ฌ:์ถ๊ฐ (์์ฒด ์จ๋)
sampler:add (nn.View(-1, self.k)) -- k์์ ์๋ํ๋๋ก ํฌ๊ธฐ ์กฐ์
์ํ๋ฌ:์ถ๊ฐ (nn.SoftMax())
sampler:add (nn.View(-1, self.N * self.k)) -- ๋ค์ ํฌ๊ธฐ ์กฐ์
```
์ฌ๊ธฐ ์์ต๋๋ค. ํจ์ฌ ๋ ์ฝ๊ธฐ ์ฝ๊ณ ๋ชจ๋์ด ํ์ํ์ง ์์ต๋๋ค.
import torch.nn.functional as F
from torch.autograd import Variable
def sampler(input, tau, temperature):
noise = torch.rand(input.size())
noise.add_(1e-9).log_().neg_()
noise.add_(1e-9).log_().neg_()
noise = Variable(noise)
x = (input + noise) / tau + temperature
x = F.softmax(x.view(input.size(0), -1))
return x.view_as(input)
์ฐ๋ฆฌ๋ ๋ฒ๊ทธ ๋ณด๊ณ ์ฉ์ผ๋ก๋ง GitHub๋ฅผ ์ฌ์ฉํ๊ณ ์์ต๋๋ค. ์ง๋ฌธ์ด ์๋ ๊ฒฝ์ฐ ํฌ๋ผ ์ ๊ฒ์ํด ์ฃผ์ธ์.
์ ์ํ ๋ต๋ณ์ ๊ฐ์ฌ๋๋ฆฝ๋๋ค! ๋ค์ ๋ฒ์ ํฌ๋ผ์ ๊ธ์ ์ธ ๊ฒ์ ๋๋ค.
๋ต๋ณ ํด์ฃผ์ ์ ๊ฐ์ฌํฉ๋๋ค!
๊ฐ์ฅ ์ ์ฉํ ๋๊ธ
์ฌ๊ธฐ ์์ต๋๋ค. ํจ์ฌ ๋ ์ฝ๊ธฐ ์ฝ๊ณ ๋ชจ๋์ด ํ์ํ์ง ์์ต๋๋ค.
์ฐ๋ฆฌ๋ ๋ฒ๊ทธ ๋ณด๊ณ ์ฉ์ผ๋ก๋ง GitHub๋ฅผ ์ฌ์ฉํ๊ณ ์์ต๋๋ค. ์ง๋ฌธ์ด ์๋ ๊ฒฝ์ฐ ํฌ๋ผ ์ ๊ฒ์ํด ์ฃผ์ธ์.