diff --git a/GRU.lua b/GRU.lua
new file mode 100644
index 00000000..dc85d63c
--- /dev/null
+++ b/GRU.lua
@@ -0,0 +1,252 @@
+require 'torch'
+require 'nn'
+
+
+local layer, parent = torch.class('nn.GRU', 'nn.Module')
+
+--[[
+If we add up the sizes of all the tensors for output, gradInput, weights,
+gradWeights, and temporary buffers, we get that an nn.GRU stores this many
+scalar values:
+
+NTD + 4NTH + 6NH + 6H^2 + 6DH + 7H
+
+Note that this class doesn't own input or gradOutput, so you'll
+see a bit higher memory usage in practice.
+--]]
+
+function layer:__init(input_dim, hidden_dim)
+  parent.__init(self)
+
+  local D, H = input_dim, hidden_dim
+  self.input_dim, self.hidden_dim = D, H
+
+  self.weight = torch.Tensor(D + H, 3 * H)
+  self.gradWeight = torch.Tensor(D + H, 3 * H):zero()
+  self.bias = torch.Tensor(3 * H)
+  self.gradBias = torch.Tensor(3 * H):zero()
+  self:reset()
+
+  self.gates = torch.Tensor()         -- This will be (N, T, 3H)
+  self.buffer1 = torch.Tensor()       -- This will be (N, H)
+  self.buffer2 = torch.Tensor()       -- This will be (N, H)
+  self.buffer3 = torch.Tensor()       -- This will be (H,)
+  self.grad_a_buffer = torch.Tensor() -- This will be (N, 3H)
+  self.h0 = torch.Tensor()
+  self.remember_states = false
+  self.grad_h0 = torch.Tensor()
+  self.grad_x = torch.Tensor()
+  self.gradInput = {self.grad_h0, self.grad_x}
+end
+
+
+function layer:reset(std)
+  if not std then
+    std = 1.0 / math.sqrt(self.hidden_dim + self.input_dim)
+  end
+  self.bias:normal(0, std)
+  self.weight:normal(0, std)
+  return self
+end
+
+
+function layer:resetStates()
+  self.h0 = self.h0.new()
+end
+
+
+local function check_dims(x, dims)
+  assert(x:dim() == #dims)
+  for i, d in ipairs(dims) do
+    assert(x:size(i) == d)
+  end
+end
+
+
+function layer:_unpack_input(input)
+  local h0, x = nil, nil
+  if torch.type(input) == 'table' and #input == 2 then
+    h0, x = unpack(input)
+  elseif torch.isTensor(input) then
+    x = input
+  else
+    assert(false, 'invalid input')
+  end
+  return h0, x
+end
+
+
+function layer:_get_sizes(input, gradOutput)
+  local h0, x = self:_unpack_input(input)
+  local N, T = x:size(1), x:size(2)
+  local H, D = self.hidden_dim, self.input_dim
+  check_dims(x, {N, T, D})
+  if h0 then
+    check_dims(h0, {N, H})
+  end
+  if gradOutput then
+    check_dims(gradOutput, {N, T, H})
+  end
+  return N, T, D, H
+end
+
+
+--[[
+Input:
+- h0: Initial hidden state, (N, H)
+- x: Input sequence, (N, T, D)
+
+Output:
+- h: Sequence of hidden states, (N, T, H)
+--]]
+function layer:updateOutput(input)
+  local h0, x = self:_unpack_input(input)
+  local N, T, D, H = self:_get_sizes(input)
+
+  self._return_grad_h0 = (h0 ~= nil)
+
+  if not h0 then
+    h0 = self.h0
+    if h0:nElement() == 0 or not self.remember_states then
+      h0:resize(N, H):zero()
+    else
+      local prev_N, prev_T = self.output:size(1), self.output:size(2)
+      assert(prev_N == N, 'batch sizes must be the same to remember states')
+      h0:copy(self.output[{{}, prev_T}])
+    end
+  end
+
+  local bias_expand = self.bias:view(1, 3 * H):expand(N, 3 * H)
+  local Wx = self.weight[{{1, D}}]
+  local Wh = self.weight[{{D + 1, D + H}}]
+
+  local h = self.output
+  h:resize(N, T, H):zero()
+  local prev_h = h0
+  self.gates:resize(N, T, 3 * H):zero()
+  for t = 1, T do
+    local cur_x = x[{{}, t}]
+    local next_h = h[{{}, t}]
+    local cur_gates = self.gates[{{}, t}]
+    cur_gates:addmm(bias_expand, cur_x, Wx)
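+    -- At this point cur_gates holds b + x_t * Wx for all three gate blocks.
+    -- The recurrent term prev_h * Wh is added next: immediately for the
+    -- update and reset blocks, but only after r is known for the candidate
+    -- block, since the candidate sees r .* prev_h rather than prev_h.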
+    cur_gates[{{}, {1, 2 * H}}]:addmm(prev_h, Wh[{{}, {1, 2 * H}}])
+    cur_gates[{{}, {1, 2 * H}}]:sigmoid()
+
+    local u = cur_gates[{{}, {1, H}}]          -- update gate: u = sigmoid(x * Wxu + prev_h * Whu + bu)
+    local r = cur_gates[{{}, {H + 1, 2 * H}}]  -- reset gate:  r = sigmoid(x * Wxr + prev_h * Whr + br)
+    next_h:cmul(r, prev_h)                     -- use next_h as a temporary buffer for r .* prev_h
+    cur_gates[{{}, {2 * H + 1, 3 * H}}]:addmm(next_h, Wh[{{}, {2 * H + 1, 3 * H}}])
+    local hc = cur_gates[{{}, {2 * H + 1, 3 * H}}]:tanh()  -- candidate: hc = tanh(x * Wxhc + (r .* prev_h) * Whhc + bhc)
+    next_h:addcmul(prev_h, -1, u, prev_h)
+    next_h:addcmul(u, hc)                      -- next_h = (1 - u) .* prev_h + u .* hc
+    prev_h = next_h
+  end
+
+  return self.output
+end
+
+
+function layer:backward(input, gradOutput, scale)
+  scale = scale or 1.0
+  local h0, x = self:_unpack_input(input)
+  if not h0 then h0 = self.h0 end
+
+  local grad_h0, grad_x = self.grad_h0, self.grad_x
+  local h = self.output
+  local grad_h = gradOutput
+
+  local N, T, D, H = self:_get_sizes(input, gradOutput)
+  local Wx = self.weight[{{1, D}}]
+  local Wh = self.weight[{{D + 1, D + H}}]
+  local grad_Wx = self.gradWeight[{{1, D}}]
+  local grad_Wh = self.gradWeight[{{D + 1, D + H}}]
+  local grad_b = self.gradBias
+
+  grad_h0:resizeAs(h0):zero()
+  grad_x:resizeAs(x):zero()
+  local grad_next_h = self.buffer1:resizeAs(h0):zero()
+  local temp_buffer = self.buffer2:resizeAs(h0):zero()
+  for t = T, 1, -1 do
+    local prev_h = nil
+    if t == 1 then
+      prev_h = h0
+    else
+      prev_h = h[{{}, t - 1}]
+    end
+    grad_next_h:add(grad_h[{{}, t}])
+
+    local u = self.gates[{{}, t, {1, H}}]
+    local r = self.gates[{{}, t, {H + 1, 2 * H}}]
+    local hc = self.gates[{{}, t, {2 * H + 1, 3 * H}}]
+
+    local grad_a = self.grad_a_buffer:resize(N, 3 * H):zero()
+    local grad_au = grad_a[{{}, {1, H}}]
+    local grad_ar = grad_a[{{}, {H + 1, 2 * H}}]
+    local grad_ahc = grad_a[{{}, {2 * H + 1, 3 * H}}]
+
+    -- grad_au doubles as a temporary buffer while computing grad_ahc and
+    -- grad_ar below; its real value is filled in last.
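+    -- Backprop through next_h = (1 - u) .* prev_h + u .* hc:
+    --   grad_hc  = grad_next_h .* u, and since hc = tanh(a_hc),
+    --   grad_ahc = grad_hc .* (1 - hc^2).
+    -- The reset gate only enters through hc's recurrent term, so
+    --   grad_r  = (grad_ahc * Whc^T) .* prev_h, where Whc is the candidate
+    --   block of Wh, and with sigmoid'(a) = r .* (1 - r),
+    --   grad_ar = grad_r .* r .* (1 - r).
+    -- Similarly grad_au = grad_next_h .* (hc - prev_h) .* u .* (1 - u).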
+    local grad_hc = grad_au:fill(0):add(grad_next_h):cmul(u)
+    grad_ahc:fill(1):addcmul(-1, hc, hc):cmul(grad_hc)
+    local grad_r = grad_au:fill(0):addmm(grad_ahc, Wh[{{}, {2 * H + 1, 3 * H}}]:t()):cmul(prev_h)
+    grad_ar:fill(1):add(-1, r):cmul(r):cmul(grad_r)
+
+    temp_buffer:fill(0):add(hc):add(-1, prev_h)
+    grad_au:fill(1):add(-1, u):cmul(u):cmul(temp_buffer):cmul(grad_next_h)
+    grad_x[{{}, t}]:mm(grad_a, Wx:t())
+    grad_Wx:addmm(scale, x[{{}, t}]:t(), grad_a)
+    grad_Wh[{{}, {1, 2 * H}}]:addmm(scale, prev_h:t(), grad_a[{{}, {1, 2 * H}}])
+
+    local grad_a_sum = self.buffer3:resize(H):sum(grad_a, 1)
+    grad_b:add(scale, grad_a_sum)
+    -- The candidate block of Wh sees (r .* prev_h) as its input, not prev_h
+    temp_buffer:fill(0):add(prev_h):cmul(r)
+    grad_Wh[{{}, {2 * H + 1, 3 * H}}]:addmm(scale, temp_buffer:t(), grad_ahc)
+    -- Carry the gradient back to prev_h: through the (1 - u) path, through
+    -- the update/reset gate pre-activations, and through hc's recurrent
+    -- term, which is gated by r
+    grad_next_h:addcmul(-1, u, grad_next_h)
+    grad_next_h:addmm(grad_a[{{}, {1, 2 * H}}], Wh[{{}, {1, 2 * H}}]:t())
+    temp_buffer:fill(0):addmm(grad_a[{{}, {2 * H + 1, 3 * H}}], Wh[{{}, {2 * H + 1, 3 * H}}]:t()):cmul(r)
+    grad_next_h:add(temp_buffer)
+  end
+  grad_h0:copy(grad_next_h)
+
+  if self._return_grad_h0 then
+    self.gradInput = {self.grad_h0, self.grad_x}
+  else
+    self.gradInput = self.grad_x
+  end
+
+  return self.gradInput
+end
+
+
+function layer:updateGradInput(input, gradOutput)
+  return self:backward(input, gradOutput, 0)
+end
+
+
+function layer:accGradParameters(input, gradOutput, scale)
+  self:backward(input, gradOutput, scale)
+end
+
+
+function layer:clearState()
+  self.gates:set()
+  self.buffer1:set()
+  self.buffer2:set()
+  self.buffer3:set()
+  self.grad_a_buffer:set()
+
+  self.grad_h0:set()
+  self.grad_x:set()
+  self.output:set()
+end
diff --git a/LanguageModel.lua b/LanguageModel.lua
index d6248184..ee5dee43 100644
--- a/LanguageModel.lua
+++ b/LanguageModel.lua
@@ -3,6 +3,7 @@ require 'nn'
 
 require 'VanillaRNN'
 require 'LSTM'
+require 'GRU'
 
 local utils = require 'util.utils'
 
@@ -42,6 +43,8 @@ function LM:__init(kwargs)
       rnn = nn.VanillaRNN(prev_dim, H)
     elseif self.model_type == 'lstm' then
       rnn = nn.LSTM(prev_dim, H)
+    elseif self.model_type == 'gru' then
+      rnn = nn.GRU(prev_dim, H)
     end
     rnn.remember_states = true
     table.insert(self.rnns, rnn)
@@ -148,11 +151,9 @@ end
 --[[
 Sample from the language model. Note that this will reset the states of the
 underlying RNNs.
-
 Inputs:
 - init: String of length T0
 - max_length: Number of characters to sample
-
 Returns:
 - sampled: (1, max_length) array of integers, where the first part is init.
 --]]
diff --git a/test/GRU_test.lua b/test/GRU_test.lua
new file mode 100644
index 00000000..0f3dfd13
--- /dev/null
+++ b/test/GRU_test.lua
@@ -0,0 +1,199 @@
+require 'torch'
+require 'nn'
+
+require 'GRU'
+local gradcheck = require 'util.gradcheck'
+local tests = {}
+local tester = torch.Tester()
+
+
+local function check_size(x, dims)
+  tester:assert(x:dim() == #dims)
+  for i, d in ipairs(dims) do
+    tester:assert(x:size(i) == d)
+  end
+end
+
+
+function tests.testForward()
+  local N, T, D, H = 3, 4, 5, 6
+
+  local h0 = torch.randn(N, H)
+  local x = torch.randn(N, T, D)
+
+  local gru = nn.GRU(D, H)
+  local h = gru:forward{h0, x}
+
+  -- Do a naive forward pass and check that it matches the module
+  local naive_h = torch.Tensor(N, T, H)
+
+  -- Unpack the weights and biases for each gate
+  local Wxu = gru.weight[{{1, D}, {1, H}}]
+  local Wxr = gru.weight[{{1, D}, {H + 1, 2 * H}}]
+  local Wxhc = gru.weight[{{1, D}, {2 * H + 1, 3 * H}}]
+
+  local Whu = gru.weight[{{D + 1, D + H}, {1, H}}]
+  local Whr = gru.weight[{{D + 1, D + H}, {H + 1, 2 * H}}]
+  local Whhc = gru.weight[{{D + 1, D + H}, {2 * H + 1, 3 * H}}]
+
+  local bu = gru.bias[{{1, H}}]:view(1, H):expand(N, H)
+  local br = gru.bias[{{H + 1, 2 * H}}]:view(1, H):expand(N, H)
+  local bhc = gru.bias[{{2 * H + 1, 3 * H}}]:view(1, H):expand(N, H)
+
+  local prev_h = h0:clone()
+  for t = 1, T do
+    local xt = x[{{}, t}]
+    local u = torch.sigmoid(torch.mm(xt, Wxu) + torch.mm(prev_h, Whu) + bu)
+    local r = torch.sigmoid(torch.mm(xt, Wxr) + torch.mm(prev_h, Whr) + br)
+    local hc = torch.tanh(torch.mm(xt, Wxhc) + torch.mm(torch.cmul(prev_h, r), Whhc) + bhc)
+    local next_h = torch.cmul(hc, u) + prev_h - torch.cmul(prev_h, u)
+    naive_h[{{}, t}] = next_h
+    prev_h = next_h
+  end
+
+  tester:assertTensorEq(naive_h, h, 1e-10)
+end
+
+
+function tests.gradcheck()
+  local N, T, D, H = 2, 3, 4, 5
+
+  local x = torch.randn(N, T, D)
+  local h0 = torch.randn(N, H)
+
+  local gru = nn.GRU(D, H)
+  local h = gru:forward{h0, x}
+  local dh = torch.randn(#h)
+
+  gru:zeroGradParameters()
+  local dh0, dx = unpack(gru:backward({h0, x}, dh))
+  local dw = gru.gradWeight:clone()
+  local db = gru.gradBias:clone()
+
+  local function fx(x) return gru:forward{h0, x} end
+  local function fh0(h0) return gru:forward{h0, x} end
+
+  local function fw(w)
+    local old_w = gru.weight
+    gru.weight = w
+    local out = gru:forward{h0, x}
+    gru.weight = old_w
+    return out
+  end
+
+  local function fb(b)
+    local old_b = gru.bias
+    gru.bias = b
+    local out = gru:forward{h0, x}
+    gru.bias = old_b
+    return out
+  end
+
+  local dx_num = gradcheck.numeric_gradient(fx, x, dh)
+  local dh0_num = gradcheck.numeric_gradient(fh0, h0, dh)
+  local dw_num = gradcheck.numeric_gradient(fw, gru.weight, dh)
+  local db_num = gradcheck.numeric_gradient(fb, gru.bias, dh)
+
+  local dx_error = gradcheck.relative_error(dx_num, dx)
+  local dh0_error = gradcheck.relative_error(dh0_num, dh0)
+  local dw_error = gradcheck.relative_error(dw_num, dw)
+  local db_error = gradcheck.relative_error(db_num, db)
+
+  tester:assertle(dh0_error, 1e-4)
+  tester:assertle(dx_error, 1e-5)
+  tester:assertle(dw_error, 1e-4)
+  tester:assertle(db_error, 1e-5)
+end
+
+
+-- Make sure that everything works correctly when we pass both an initial
+-- hidden state and an input sequence; unlike the LSTM, a GRU has no cell
+-- state to worry about
+function tests.noCellTest()
+  local N, T, D, H = 4, 5, 6, 7
+  local gru = nn.GRU(D, H)
+
+  for t = 1, 3 do
+    local x = torch.randn(N, T, D)
+    local h0 = torch.randn(N, H)
+    local dout = torch.randn(N, T, H)
+
+    local out = gru:forward{h0, x}
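+    -- With table input {h0, x}, backward should return the table
+    -- {grad_h0, grad_x}; the size checks below verify both pieces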
+    local din = gru:backward({h0, x}, dout)
+
+    tester:assert(torch.type(din) == 'table')
+    tester:assert(#din == 2)
+    check_size(din[1], {N, H})
+    check_size(din[2], {N, T, D})
+  end
+end
+
+
+-- Make sure that everything works when we don't pass an initial hidden
+-- state; in this case we only pass an input sequence of vectors
+function tests.noHiddenTest()
+  local N, T, D, H = 4, 5, 6, 7
+  local gru = nn.GRU(D, H)
+
+  for t = 1, 3 do
+    local x = torch.randn(N, T, D)
+    local dout = torch.randn(N, T, H)
+
+    local out = gru:forward(x)
+    local din = gru:backward(x, dout)
+
+    tester:assert(torch.isTensor(din))
+    check_size(din, {N, T, D})
+
+    -- Make sure the initial hidden state got reset to zero
+    tester:assertTensorEq(gru.h0, torch.zeros(N, H), 0)
+  end
+end
+
+
+function tests.rememberStatesTest()
+  local N, T, D, H = 5, 6, 7, 8
+  local gru = nn.GRU(D, H)
+  gru.remember_states = true
+
+  local final_h = nil
+  for t = 1, 4 do
+    local x = torch.randn(N, T, D)
+    local dout = torch.randn(N, T, H)
+    local out = gru:forward(x)
+    local din = gru:backward(x, dout)
+
+    if t == 1 then
+      tester:assertTensorEq(gru.h0, torch.zeros(N, H), 0)
+    else
+      tester:assertTensorEq(gru.h0, final_h, 0)
+    end
+    final_h = out[{{}, T}]:clone()
+  end
+
+  -- Initial states should reset to zero after we call resetStates
+  gru:resetStates()
+  local x = torch.randn(N, T, D)
+  local dout = torch.randn(N, T, H)
+  gru:forward(x)
+  gru:backward(x, dout)
+  tester:assertTensorEq(gru.h0, torch.zeros(N, H), 0)
+end
+
+
+tester:add(tests)
+tester:run()
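For reference, a minimal usage sketch of the new module (not part of the patch; the sizes below are arbitrary). It exercises both input forms that updateOutput accepts: a {h0, x} table with an explicit initial hidden state, and a bare x tensor, where h0 defaults to zeros and remember_states carries the hidden state across calls.

require 'torch'
require 'nn'
require 'GRU'

local N, T, D, H = 2, 5, 10, 16   -- batch size, sequence length, input dim, hidden dim
local gru = nn.GRU(D, H)

-- Form 1: table input {h0, x}; backward returns the table {grad_h0, grad_x}
local x = torch.randn(N, T, D)
local h0 = torch.randn(N, H)
local h = gru:forward{h0, x}      -- (N, T, H)
local grads = gru:backward({h0, x}, torch.randn(N, T, H))
print(grads[1]:size(), grads[2]:size())  -- (N, H) and (N, T, D)

-- Form 2: tensor input x; with remember_states the final hidden state of
-- this call seeds h0 on the next forward call
gru.remember_states = true
local h2 = gru:forward(x)
local grad_x = gru:backward(x, torch.randn(N, T, H))
print(grad_x:size())              -- (N, T, D)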