ΠΡΡΡ Π»ΠΈ ΡΠΏΠΎΡΠΎΠ± ΠΏΠ΅ΡΠ΅Π΄Π°ΡΡ Π΄ΠΎΠΏΠΎΠ»Π½ΠΈΡΠ΅Π»ΡΠ½ΡΡ ΡΡΠ½ΠΊΡΠΈΡ Π²ΠΌΠ΅ΡΡΠ΅ Ρ ΡΡΡΠ΅ΡΡΠ²ΡΡΡΠΈΠΌΠΈ ΡΠΎΠΊΠ΅Π½Π°ΠΌΠΈ ΡΠ»ΠΎΠ² Π² ΠΊΠ°ΡΠ΅ΡΡΠ²Π΅ Π²Ρ ΠΎΠ΄Π½ΡΡ Π΄Π°Π½Π½ΡΡ ΠΈ ΠΏΠ΅ΡΠ΅Π΄Π°ΡΡ ΠΈΡ Π² ΠΊΠΎΠ΄ΠΈΡΠΎΠ²ΡΠΈΠΊ RNN?
ΠΠ°Π²Π°ΠΉΡΠ΅ ΡΠ°ΡΡΠΌΠΎΡΡΠΈΠΌ ΠΏΡΠΎΠ±Π»Π΅ΠΌΡ NMT, Π΄ΠΎΠΏΡΡΡΠΈΠΌ, Ρ ΠΌΠ΅Π½Ρ Π΅ΡΡΡ Π΅ΡΠ΅ 2 ΡΡΠΎΠ»Π±ΡΠ° ΡΡΠ½ΠΊΡΠΈΠΉ Π΄Π»Ρ ΡΠΎΠΎΡΠ²Π΅ΡΡΡΠ²ΡΡΡΠ΅Π³ΠΎ ΠΈΡΡ ΠΎΠ΄Π½ΠΎΠ³ΠΎ ΡΠ»ΠΎΠ²Π°ΡΡ (Feature1 Π·Π΄Π΅ΡΡ). ΠΠ°ΠΏΡΠΈΠΌΠ΅Ρ, ΡΠ°ΡΡΠΌΠΎΡΡΠΈΠΌ ΡΡΠΎ Π½ΠΈΠΆΠ΅:
Feature1 Feature2 Feature3
word1 xa
word2 yb
word3 yc
.
.
ΠΡΠ»ΠΎ Π±Ρ Π·Π΄ΠΎΡΠΎΠ²ΠΎ, Π΅ΡΠ»ΠΈ Π±Ρ ΠΊΡΠΎ-Π½ΠΈΠ±ΡΠ΄Ρ ΡΠ°ΡΡΠΊΠ°Π·Π°Π», ΠΊΠ°ΠΊ ΠΏΡΠ°ΠΊΡΠΈΡΠ΅ΡΠΊΠΈ ΡΠ΅Π°Π»ΠΈΠ·ΠΎΠ²Π°ΡΡ / ΡΠ΄Π΅Π»Π°ΡΡ ΡΡΠΎ Π² pytorch. ΠΠ°ΡΠ°Π½Π΅Π΅ ΡΠΏΠ°ΡΠΈΠ±ΠΎ.
Π‘Π°ΠΌΡΠΉ ΠΏΡΠΎΡΡΠΎΠΉ ΡΠΏΠΎΡΠΎΠ± - ΠΎΠ±ΡΠ΅Π΄ΠΈΠ½ΠΈΡΡ ΠΎΠ±ΡΠ΅ΠΊΡΡ Π² ΠΎΠ΄ΠΈΠ½ Π²Ρ
ΠΎΠ΄Π½ΠΎΠΉ Π²Π΅ΠΊΡΠΎΡ. ΠΠ΄Π½Π°ΠΊΠΎ ΡΡΠΎ ΡΠ°Π±ΠΎΡΠ°Π΅Ρ ΡΠΎΠ»ΡΠΊΠΎ Π² ΡΠΎΠΌ ΡΠ»ΡΡΠ°Π΅, Π΅ΡΠ»ΠΈ Π²Π°ΡΠ° RNN ΠΏΡΠΈΠ½ΠΈΠΌΠ°Π΅Ρ Π²Π΅ΠΊΡΠΎΡΠ½ΡΠΉ Π²Π²ΠΎΠ΄, Π° Π½Π΅ Π΄ΠΈΡΠΊΡΠ΅ΡΠ½ΡΠΉ Π²Π²ΠΎΠ΄ (LongTensor) ΡΠ΅ΡΠ΅Π· ΡΠ»ΠΎΠΉ Π²ΡΡΡΠ°ΠΈΠ²Π°Π½ΠΈΡ. Π ΡΡΠΎΠΌ ΡΠ»ΡΡΠ°Π΅ Π²Ρ Π·Π°Ρ
ΠΎΡΠΈΡΠ΅ ΠΎΠ±ΡΠ΅Π΄ΠΈΠ½ΠΈΡΡ ΡΠ²ΠΎΠΈ Π΄ΠΎΠΏΠΎΠ»Π½ΠΈΡΠ΅Π»ΡΠ½ΡΠ΅ ΡΡΠ½ΠΊΡΠΈΠΈ ΠΏΠΎΡΠ»Π΅ ΡΠΎΠ³ΠΎ, ΠΊΠ°ΠΊ Π²Π²Π΅Π΄Π΅Π½Π½ΡΠ΅ Π΄Π°Π½Π½ΡΠ΅ Π±ΡΠ΄ΡΡ Π²ΡΡΡΠΎΠ΅Π½Ρ. ΠΡΠΎΠ³ΠΎ ΠΏΡΠΎΡΠ΅ Π²ΡΠ΅Π³ΠΎ Π΄ΠΎΡΡΠΈΡΡ, ΠΈΡΠΏΠΎΠ»ΡΠ·ΡΡ Π΄ΠΎΠΏΠΎΠ»Π½ΠΈΡΠ΅Π»ΡΠ½ΡΠΉ Π²Ρ
ΠΎΠ΄Π½ΠΎΠΉ Π°ΡΠ³ΡΠΌΠ΅Π½Ρ Π² ΠΊΠΎΠ΄ΠΈΡΠΎΠ²ΡΠΈΠΊΠ΅ ΠΈ ΠΈΡΠΏΠΎΠ»ΡΠ·ΡΡ torch.cat
(ΡΠΌ. ΠΠ²ΠΎΠ΄ ΠΊΠ°ΡΠ΅Π³ΠΎΡΠΈΠΈ Π² https://github.com/spro/practical-pytorch/blob/master/conditional-char -rnn / conditional-char-rnn.ipynb Π΄Π»Ρ ΠΏΡΠΎΡΡΠΎΠ³ΠΎ ΠΏΡΠΈΠΌΠ΅ΡΠ°, Ρ
ΠΎΡΡ ΠΎΠ½ Π½Π΅ ΠΈΡΠΏΠΎΠ»ΡΠ·ΡΠ΅Ρ ΡΡΠΎΠ²Π΅Π½Ρ Π²Π½Π΅Π΄ΡΠ΅Π½ΠΈΡ).
ΠΡΠ»ΠΈ Π²Π°ΡΠΈ ΡΡΠ½ΠΊΡΠΈΠΈ ΡΠ°ΠΊΠΆΠ΅ ΡΠ²Π»ΡΡΡΡΡ Π΄ΠΈΡΠΊΡΠ΅ΡΠ½ΡΠΌΠΈ, Π²Π°ΠΌ ΠΌΠΎΠΆΠ΅Ρ ΠΏΠΎΡΡΠ΅Π±ΠΎΠ²Π°ΡΡΡΡ Π½Π΅ΡΠΊΠΎΠ»ΡΠΊΠΎ ΡΠ»ΠΎΠ΅Π² Π²ΡΡΡΠ°ΠΈΠ²Π°Π½ΠΈΡ, ΠΏΠΎ ΠΎΠ΄Π½ΠΎΠΌΡ Π΄Π»Ρ ΠΊΠ°ΠΆΠ΄ΠΎΠ³ΠΎ, ΠΈ ΠΎΠ±ΡΠ΅Π΄ΠΈΠ½ΠΈΡΡ Π²ΡΠ΅ ΡΠ΅Π·ΡΠ»ΡΡΠ°ΡΡ.
Π¦ΠΈΡΠΈΡΡΠ΅ΠΌΡΠΉ ΡΠ°Π·Π΄Π΅Π» Π±ΠΎΠ»ΡΡΠ΅ ΠΊΠ°ΡΠ°Π΅ΡΡΡ Π΄Π»ΠΈΠ½Ρ Π²Ρ ΠΎΠ΄Π½ΠΎΠΉ ΠΈ ΡΠ΅Π»Π΅Π²ΠΎΠΉ ΠΏΠΎΡΠ»Π΅Π΄ΠΎΠ²Π°ΡΠ΅Π»ΡΠ½ΠΎΡΡΠΈ, Π° Π½Π΅ ΡΠ°Π·ΠΌΠ΅ΡΠ° ΠΎΠ±ΡΠ΅ΠΊΡΠ°.
ΠΠ»Ρ ΠΈΠ½ΠΈΡΠΈΠ°Π»ΠΈΠ·Π°ΡΠΎΡΠ° Π²Π°ΠΌ Π½ΡΠΆΠ½ΠΎ Π±ΡΠ΄Π΅Ρ Π΄ΠΎΠ±Π°Π²ΠΈΡΡ Π°ΡΠ³ΡΠΌΠ΅Π½ΡΡ Π΄Π»Ρ ΡΠ°Π·ΠΌΠ΅ΡΠΎΠ² Π²Π°ΡΠΈΡ
ΠΎΠ±ΡΠ΅ΠΊΡΠΎΠ² ΠΈ ΡΠΎΠ·Π΄Π°ΡΡ Π½ΠΎΠ²ΡΠΉ ΡΠ»ΠΎΠΉ Π²Π½Π΅Π΄ΡΠ΅Π½ΠΈΡ Π΄Π»Ρ ΠΊΠ°ΠΆΠ΄ΠΎΠ³ΠΎ ΠΎΡΠ΄Π΅Π»ΡΠ½ΠΎΠ³ΠΎ ΠΎΠ±ΡΠ΅ΠΊΡΠ°. Π ΠΌΠ΅ΡΠΎΠ΄Π΅ forward () Π²Ρ Π΄ΠΎΠ»ΠΆΠ½Ρ Π΄ΠΎΠ±Π°Π²ΠΈΡΡ Π°ΡΠ³ΡΠΌΠ΅Π½ΡΡ Π΄Π»Ρ Π΄ΠΎΠΏΠΎΠ»Π½ΠΈΡΠ΅Π»ΡΠ½ΡΡ
ΡΡΠ½ΠΊΡΠΈΠΉ, ΠΏΠ΅ΡΠ΅Π½Π°ΠΏΡΠ°Π²ΠΈΡΡ ΠΈΡ
ΡΠ΅ΡΠ΅Π· ΡΠΎΠΎΡΠ²Π΅ΡΡΡΠ²ΡΡΡΠΈΠ΅ ΡΠ»ΠΎΠΈ Π²ΡΡΡΠ°ΠΈΠ²Π°Π½ΠΈΡ ΠΈ ΠΎΠ±ΡΠ΅Π΄ΠΈΠ½ΠΈΡΡ ΠΈΡ
Π²ΡΠ΅ Π² ΠΎΠ΄ΠΈΠ½ Π²Π΅ΠΊΡΠΎΡ Π΄Π»Ρ RNN. RNN ΡΠ°ΠΊΠΆΠ΅ Π±ΡΠ΄Π΅Ρ ΠΈΠΌΠ΅ΡΡ ΡΠ°Π·ΠΌΠ΅Ρ Π²Π²ΠΎΠ΄Π° 3 * hidden_size
ΠΈΠ·-Π·Π° ΡΡΠ΅Ρ
ΡΠ»ΠΎΠΆΠ΅Π½Π½ΡΡ
Π²ΠΌΠ΅ΡΡΠ΅ Π²Π»ΠΎΠΆΠ΅Π½ΠΈΠΉ (ΠΈΠ»ΠΈ Π²Ρ ΠΌΠΎΠΆΠ΅ΡΠ΅ ΡΠΎΠ·Π΄Π°ΡΡ ΡΠ°Π·Π½ΡΠ΅ ΡΠ°Π·ΠΌΠ΅ΡΡ Π²ΡΡΡΠ°ΠΈΠ²Π°Π½ΠΈΡ Π΄Π»Ρ ΠΊΠ°ΠΆΠ΄ΠΎΠΉ ΡΡΠ½ΠΊΡΠΈΠΈ).
Π ΡΠ΅Π»ΠΎΠΌ ΡΡΠΎ Π΄ΠΎΠ»ΠΆΠ½ΠΎ Π²ΡΠ³Π»ΡΠ΄Π΅ΡΡ ΠΏΡΠΈΠΌΠ΅ΡΠ½ΠΎ ΡΠ°ΠΊ:
class EncoderRNN(nn.Module):
def __init__(..., word_size, feature2_size, feature3_size, hidden_size, ...):
...
self.word_embedding = nn.Embedding(word_size, hidden_size)
self.feature2_embedding = nn.Embedding(feature2_size, hidden_size)
self.feature3_embedding = nn.Embedding(feature3_size, hidden_size)
# Note: * 3 because the above 3 embeddings will be concatenated
self.gru = nn.GRU(hidden_size * 3, hidden_size, n_layers, dropout=self.dropout, bidirectional=True)
def forward(self, word_seqs, feature2_seqs, feature3_seqs, input_lengths, hidden=None):
# Note: we run this all at once (over multiple batches of multiple sequences)
word_embedded = self.word_embedding(word_seqs)
feature2_embedded = self.feature2_embedding(feature2_seqs)
feature3_embedded = self.feature3_embedding(feature3_seqs)
combined = torch.cat((word_embedded, feature2_embedded, feature3_embedded), 2)
packed = torch.nn.utils.rnn.pack_padded_sequence(combined, input_lengths)
...
@spro ΠΠΎΠ»ΡΡΠΎΠ΅ ΡΠΏΠ°ΡΠΈΠ±ΠΎ Π·Π° Π²Π°Ρ Π²ΠΊΠ»Π°Π΄. Π― Π½Π°ΠΏΠΈΡΠ°Π» ΡΠΎΠΎΠ±ΡΠ΅Π½ΠΈΠ΅ Π² Π±Π»ΠΎΠ³Π΅ ΠΎ ΡΠΎΠΌ ΠΆΠ΅ -
https://iamsiva11.github.io/extra-features-seq2seq/. ΠΠ°Π΄Π΅ΡΡΡ, ΡΡΠΎ Π±ΡΠ΄Π΅Ρ ΠΏΠΎΠ»Π΅Π·Π½ΠΎ ΠΌΠ½ΠΎΠ³ΠΈΠΌ Π΄ΡΡΠ³ΠΈΠΌ.
Π‘Π°ΠΌΡΠΉ ΠΏΠΎΠ»Π΅Π·Π½ΡΠΉ ΠΊΠΎΠΌΠΌΠ΅Π½ΡΠ°ΡΠΈΠΉ
ΠΠ»Ρ ΠΈΠ½ΠΈΡΠΈΠ°Π»ΠΈΠ·Π°ΡΠΎΡΠ° Π²Π°ΠΌ Π½ΡΠΆΠ½ΠΎ Π±ΡΠ΄Π΅Ρ Π΄ΠΎΠ±Π°Π²ΠΈΡΡ Π°ΡΠ³ΡΠΌΠ΅Π½ΡΡ Π΄Π»Ρ ΡΠ°Π·ΠΌΠ΅ΡΠΎΠ² Π²Π°ΡΠΈΡ ΠΎΠ±ΡΠ΅ΠΊΡΠΎΠ² ΠΈ ΡΠΎΠ·Π΄Π°ΡΡ Π½ΠΎΠ²ΡΠΉ ΡΠ»ΠΎΠΉ Π²Π½Π΅Π΄ΡΠ΅Π½ΠΈΡ Π΄Π»Ρ ΠΊΠ°ΠΆΠ΄ΠΎΠ³ΠΎ ΠΎΡΠ΄Π΅Π»ΡΠ½ΠΎΠ³ΠΎ ΠΎΠ±ΡΠ΅ΠΊΡΠ°. Π ΠΌΠ΅ΡΠΎΠ΄Π΅ forward () Π²Ρ Π΄ΠΎΠ»ΠΆΠ½Ρ Π΄ΠΎΠ±Π°Π²ΠΈΡΡ Π°ΡΠ³ΡΠΌΠ΅Π½ΡΡ Π΄Π»Ρ Π΄ΠΎΠΏΠΎΠ»Π½ΠΈΡΠ΅Π»ΡΠ½ΡΡ ΡΡΠ½ΠΊΡΠΈΠΉ, ΠΏΠ΅ΡΠ΅Π½Π°ΠΏΡΠ°Π²ΠΈΡΡ ΠΈΡ ΡΠ΅ΡΠ΅Π· ΡΠΎΠΎΡΠ²Π΅ΡΡΡΠ²ΡΡΡΠΈΠ΅ ΡΠ»ΠΎΠΈ Π²ΡΡΡΠ°ΠΈΠ²Π°Π½ΠΈΡ ΠΈ ΠΎΠ±ΡΠ΅Π΄ΠΈΠ½ΠΈΡΡ ΠΈΡ Π²ΡΠ΅ Π² ΠΎΠ΄ΠΈΠ½ Π²Π΅ΠΊΡΠΎΡ Π΄Π»Ρ RNN. RNN ΡΠ°ΠΊΠΆΠ΅ Π±ΡΠ΄Π΅Ρ ΠΈΠΌΠ΅ΡΡ ΡΠ°Π·ΠΌΠ΅Ρ Π²Π²ΠΎΠ΄Π°
3 * hidden_size
ΠΈΠ·-Π·Π° ΡΡΠ΅Ρ ΡΠ»ΠΎΠΆΠ΅Π½Π½ΡΡ Π²ΠΌΠ΅ΡΡΠ΅ Π²Π»ΠΎΠΆΠ΅Π½ΠΈΠΉ (ΠΈΠ»ΠΈ Π²Ρ ΠΌΠΎΠΆΠ΅ΡΠ΅ ΡΠΎΠ·Π΄Π°ΡΡ ΡΠ°Π·Π½ΡΠ΅ ΡΠ°Π·ΠΌΠ΅ΡΡ Π²ΡΡΡΠ°ΠΈΠ²Π°Π½ΠΈΡ Π΄Π»Ρ ΠΊΠ°ΠΆΠ΄ΠΎΠΉ ΡΡΠ½ΠΊΡΠΈΠΈ).Π ΡΠ΅Π»ΠΎΠΌ ΡΡΠΎ Π΄ΠΎΠ»ΠΆΠ½ΠΎ Π²ΡΠ³Π»ΡΠ΄Π΅ΡΡ ΠΏΡΠΈΠΌΠ΅ΡΠ½ΠΎ ΡΠ°ΠΊ: