import math

import torch
from torch import nn
from torch.nn import functional as F

from d2l import torch as d2l
def get_params(vocab_size, num_hiddens, device):
    # Inputs and outputs are one-hot vectors over the vocabulary,
    # so both sizes equal the vocabulary size.
    input_size = output_size = vocab_size

    def normal(shape):
        # Small random initialization to keep early activations stable.
        return torch.randn(size=shape, device=device) * 0.01

    # Hidden-layer parameters
    w_xh = normal((input_size, num_hiddens))
    w_hh = normal((num_hiddens, num_hiddens))
    b_h = normal((num_hiddens,))
    # Output-layer parameters
    w_ho = normal((num_hiddens, output_size))
    b_o = normal((output_size,))

    # Enable gradient tracking so the parameters can be trained.
    params = [w_xh, w_hh, b_h, w_ho, b_o]
    for param in params:
        param.requires_grad_(True)
    return params
def init_state(batch_size, num_hiddens, device):
    # Return the state as a tuple so the same interface also fits
    # recurrent cells that carry more than one state tensor.
    return (torch.zeros((batch_size, num_hiddens), device=device),)
def rnn(inputs: torch.Tensor, params, state) -> tuple[torch.Tensor, torch.Tensor]:
    # inputs shape: (num_steps, batch_size, vocab_size)
    w_xh, w_hh, b_h, w_ho, b_o = params
    H, = state
    outputs = []
    for X in inputs:  # one time step at a time
        H = F.relu(torch.matmul(X, w_xh) + torch.matmul(H, w_hh) + b_h)
        Y = torch.matmul(H, w_ho) + b_o
        outputs.append(Y)
    # Stacked outputs shape: (num_steps, batch_size, vocab_size)
    return torch.stack(outputs, dim=0), H
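# A quick shape check for the step function above: a minimal sketch,
# assuming arbitrary illustrative sizes (vocab_size=28, num_hiddens=512,
# num_steps=5, batch_size=2) that are not part of the original listing.
vocab_size, num_hiddens, num_steps, batch_size = 28, 512, 5, 2
device = torch.device('cpu')
params = get_params(vocab_size, num_hiddens, device)
state = init_state(batch_size, num_hiddens, device)
X = torch.zeros((num_steps, batch_size, vocab_size), device=device)
Y, H = rnn(X, params, state)
print(Y.shape)  # torch.Size([5, 2, 28]): (num_steps, batch_size, vocab_size)
print(H.shape)  # torch.Size([2, 512]): (batch_size, num_hiddens)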
class RNN:
    """A from-scratch RNN language model that wires together the
    parameter initializer, state initializer, and forward function."""

    def __init__(self, vocab_size, num_hiddens, batch_size, device,
                 init_state, get_params, forward_fn):
        self.vocab_size, self.num_hiddens = vocab_size, num_hiddens
        self.state = init_state(batch_size, num_hiddens, device)
        self.params = get_params(vocab_size, num_hiddens, device)
        self.forward_fn = forward_fn

    def __call__(self, X: torch.Tensor, state=None):
        if state is not None:
            self.state = state
        # Detach the carried state so backpropagation through time is
        # truncated at the boundary between consecutive calls.
        for s in self.state:
            s.detach_()
        # X has shape (batch_size, num_steps); transposing puts time on
        # the leading axis before one-hot encoding the token indices.
        inputs = F.one_hot(X.T, self.vocab_size).type(torch.float32)
        outputs, H = self.forward_fn(inputs, self.params, self.state)
        self.state = (H,)
        return outputs, self.state
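# A minimal end-to-end usage sketch for the RNN class, assuming it
# consumes integer token indices of shape (batch_size, num_steps); the
# sizes here are illustrative, not taken from the original listing.
vocab_size, num_hiddens, batch_size, num_steps = 28, 512, 2, 5
device = torch.device('cpu')
net = RNN(vocab_size, num_hiddens, batch_size, device,
          init_state, get_params, rnn)
X = torch.randint(0, vocab_size, (batch_size, num_steps))
outputs, new_state = net(X)
print(outputs.shape)  # torch.Size([5, 2, 28])
print(new_state[0].shape)  # torch.Size([2, 512])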