1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74
| import torch import torch.nn as nn import torch.nn.functional as F import copy
class Gate(nn.Module):
    """One gate of a recurrent cell.

    Holds two linear projections into ``hidden_dim``: one applied to the
    previous hidden state and one applied to the current input. The two
    projections are summed and passed through an activation chosen by the
    caller at call time.
    """

    def __init__(self, input_size, hidden_dim):
        """
        :param input_size: size of the last dimension of the input ``x``
        :param hidden_dim: size of the last dimension of the hidden state
        """
        super().__init__()
        # Creation order (hidden projection first) is kept so that seeded
        # parameter initialisation stays reproducible.
        self.linear1 = nn.Linear(hidden_dim, hidden_dim)
        self.linear2 = nn.Linear(input_size, hidden_dim)

    def forward(self, x, h_pre, active_func):
        """
        :param x: input tensor whose last dimension is ``input_size``
        :param h_pre: previous hidden state whose last dimension is ``hidden_dim``
        :param active_func: callable applied to the summed projections
        :return: ``active_func(linear1(h_pre) + linear2(x))``
        """
        recurrent_term = self.linear1(h_pre)
        input_term = self.linear2(x)
        return active_func(recurrent_term + input_term)
def clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``.

    :param module: the layer to replicate
    :param N: number of copies to produce
    :return: ``nn.ModuleList`` of length ``N``; each entry is a separate
        ``copy.deepcopy`` of ``module``, so the copies do not share parameters
    """
    replicas = nn.ModuleList()
    for _ in range(N):
        replicas.append(copy.deepcopy(module))
    return replicas
class LSTMCell(nn.Module):
    """A single LSTM step assembled from four independent ``Gate`` modules.

    ``self.gate`` holds, in order: forget gate, input gate, candidate
    (cell-input) gate and output gate.
    """

    def __init__(self, input_size, hidden_dim):
        """
        :param input_size: feature size of the per-step input
        :param hidden_dim: feature size of the hidden and cell states
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_dim = hidden_dim
        self.gate = clones(Gate(input_size, hidden_dim), 4)

    def forward(self, x, h_pre, c_pre):
        """
        Advance the cell by one time step.

        :param x: (batch, input_size)
        :param h_pre: (batch, hidden_dim)
        :param c_pre: (batch, hidden_dim)
        :return: h_next(batch, hidden_dim), c_next(batch, hidden_dim)
        """
        # Activation paired with each gate, in the order the gates are stored.
        activations = (torch.sigmoid, torch.sigmoid, torch.tanh, torch.sigmoid)
        forget, write, candidate, expose = (
            gate(x, h_pre, act) for gate, act in zip(self.gate, activations)
        )

        retained = forget * c_pre      # share of the old cell state that is kept
        incoming = write * candidate   # new content admitted into the cell
        c_next = retained + incoming
        h_next = expose * torch.tanh(c_next)
        return h_next, c_next
class LSTM(nn.Module):
    """Single-layer, unidirectional LSTM that unrolls one ``LSTMCell`` over time."""

    def __init__(self, input_size, hidden_dim):
        """
        :param input_size: feature size of each time step of the input
        :param hidden_dim: feature size of the hidden and cell states
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_dim = hidden_dim
        self.lstm_cell = LSTMCell(input_size, hidden_dim)

    def forward(self, x):
        """
        Run the cell over every time step, starting from all-zero states.

        :param x: (seq_len, batch, input_size)
        :return: output (seq_len, batch, hidden_dim)
                 h_n (1, batch, hidden_dim)
                 c_n (1, batch, hidden_dim)

        For an empty sequence (``seq_len == 0``) ``output`` is empty and
        ``h_n`` / ``c_n`` are the zero initial states.
        """
        _, batch, _ = x.shape

        # Allocate the initial states from ``x`` so they share its dtype and
        # device; plain ``torch.zeros`` would always be default-dtype on the
        # default device and clash with a ``.double()`` / ``.cuda()`` model.
        h = x.new_zeros(batch, self.hidden_dim)
        c = x.new_zeros(batch, self.hidden_dim)

        # Collect per-step hidden states and stack them, so the output keeps
        # the dtype the cell produced (no silent downcast into a fixed buffer).
        hidden_steps = []
        for inp in x:  # iterates over the leading (time) dimension
            h, c = self.lstm_cell(inp, h, c)
            hidden_steps.append(h)

        if hidden_steps:
            output = torch.stack(hidden_steps)
        else:
            # torch.stack rejects an empty list; return an empty sequence.
            output = x.new_zeros(0, batch, self.hidden_dim)

        # The final hidden state equals output[-1] when the sequence is
        # non-empty, and stays (1, batch, hidden_dim) when it is empty.
        return output, (h.unsqueeze(0), c.unsqueeze(0))
|