Compared with a vanilla RNN, the GRU computes its hidden layer differently. The GRU update equations are:

\[
\begin{align*}
z_t &= \sigma(W_{zh} \cdot h_{t-1} + W_{zx} \cdot x_t + b_z) \quad \text{(update gate)} \\
r_t &= \sigma(W_{rh} \cdot h_{t-1} + W_{rx} \cdot x_t + b_r) \quad \text{(reset gate)} \\
\tilde{h}_t &= \tanh(W_{hh} \cdot (r_t \odot h_{t-1}) + W_{hx} \cdot x_t + b_h) \quad \text{(candidate hidden state)} \\
h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t \quad \text{(final hidden state)}
\end{align*}
\]
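To make the equations concrete, here is a minimal from-scratch sketch of a single GRU step that follows the formulas above. The weight names and shapes are illustrative only, not taken from any library:

import torch

def gru_step(x_t, h_prev, params):
    # unpack illustrative weights; shapes: W_*h (hidden, hidden), W_*x (input, hidden), b_* (hidden,)
    W_zh, W_zx, b_z, W_rh, W_rx, b_r, W_hh, W_hx, b_h = params
    z_t = torch.sigmoid(h_prev @ W_zh + x_t @ W_zx + b_z)            # update gate
    r_t = torch.sigmoid(h_prev @ W_rh + x_t @ W_rx + b_r)            # reset gate
    h_tilde = torch.tanh((r_t * h_prev) @ W_hh + x_t @ W_hx + b_h)   # candidate hidden state
    return (1 - z_t) * h_prev + z_t * h_tilde                        # final hidden state

# shape check with placeholder tensors: batch=4, input=28, hidden=64
x_t, h_prev = torch.randn(4, 28), torch.zeros(4, 64)
params = [torch.randn(64, 64), torch.randn(28, 64), torch.zeros(64)] * 3  # same placeholders reused for all three gates
h_t = gru_step(x_t, h_prev, params)
print(h_t.shape)   # torch.Size([4, 64])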

PyTorch's nn.GRU

As of PyTorch 2.1, the GRU implementation chooses the sigmoid function for \(\sigma\) (see the official documentation):

\[
\mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}
\]
Why sigmoid? Because the gate values \(z_t\) and \(r_t\) must lie in \((0, 1)\): they act as soft switches that interpolate between keeping the previous hidden state and adopting the candidate state, and the sigmoid maps any real input into exactly that range.
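For reference, a short sketch of calling nn.GRU directly to confirm the tensor shapes it expects; the sizes match the training script below and are otherwise illustrative:

import torch
from torch import nn

gru = nn.GRU(input_size=28, hidden_size=64, num_layers=1)
x = torch.randn(35, 32, 28)        # (num_steps, batch_size, input_size), default seq-first layout
h0 = torch.zeros(1, 32, 64)        # (num_layers, batch_size, hidden_size)
out, hn = gru(x, h0)
print(out.shape, hn.shape)         # torch.Size([35, 32, 64]) torch.Size([1, 32, 64])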

Training the Model

import torch
from torch import nn
from d2l import torch as d2l
from torch.nn import functional as F
import math

batch_size = 32
num_steps = 35   # sequence length; the second argument of load_data_time_machine is num_steps, not the vocabulary size
epochs, lr = 2000, 1

train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
test_iter, _ = d2l.load_data_time_machine(1, num_steps)

class GRU(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # len(vocab) == 28 for the time machine character-level dataset
        self.rnn = nn.GRU(input_size=28, hidden_size=64, num_layers=1)
        self.l = nn.Linear(64, 28)

    def forward(self, x, state):
        # (batch, steps) -> (steps, batch, 28) one-hot input expected by nn.GRU
        x = F.one_hot(x.T, 28).type(torch.float32)
        h, state = self.rnn(x, state)
        # merge the step and batch dimensions before the output layer
        o = self.l(h.reshape(-1, 64))
        return o, state

net = GRU()
loss_fn = nn.CrossEntropyLoss()
updater = torch.optim.SGD(params=net.parameters(), lr=lr)
state = torch.zeros((1, 32, 64), dtype=torch.float32)   # (num_layers, batch, hidden_size)

if __name__ == '__main__':
    device = torch.device('cuda')
    net.to(device)
    state = state.to(device)
    for e in range(epochs):
        for x, y in train_iter:
            x, y = x.to(device), y.to(device)
            # detach the state so gradients do not flow across minibatches
            state = state.detach()
            y_hat, state = net(x, state)
            # flatten the labels to match the (steps*batch, vocab) logits
            y = y.T.reshape(-1)
            updater.zero_grad()
            loss = loss_fn(y_hat, y.long())
            loss.backward()
            updater.step()
        if e % 100 == 0:
            print('Perplexity: ', math.exp(loss.item()))
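One caveat about this training loop: unlike d2l's reference RNN trainer, it does not clip gradients, and with plain SGD at lr = 1 the loss can occasionally spike. A minimal sketch of adding clipping, assuming d2l's grad_clipping helper is available, would change the last lines of the inner loop to:

            loss.backward()
            d2l.grad_clipping(net, 1)   # rescale gradients so their L2 norm is at most 1
            updater.step()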
# Evaluate perplexity on the held-out iterator (batch size 1)
state = torch.zeros((1, 1, 64), dtype=torch.float32)
state = state.to(device)
perplexity = 0
n = 0
with torch.no_grad():   # no gradients needed for evaluation
    for x, y in test_iter:
        x, y = x.to(device), y.to(device)
        y = y.reshape(-1)   # with batch size 1 this matches the (steps*batch) logit order
        y_hat, state = net(x, state)
        loss = loss_fn(y_hat, y)
        perplexity += math.exp(loss.item())   # per-batch perplexity
        n += 1
print(perplexity / n)
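A small design note: this script averages the per-batch values of \(\exp(\text{loss})\). The more common definition of corpus perplexity is \(\exp\) of the average loss, which by Jensen's inequality is never larger than the per-batch average used here; either form works for a rough comparison as long as both models are measured the same way.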

The evaluation outputs a perplexity of 2.93, whereas the earlier RNN implementation produced a perplexity of 4.21.