diff --git a/model/modules.py b/model/modules.py index 6423a2d..5849ed9 100644 --- a/model/modules.py +++ b/model/modules.py @@ -1,6 +1,5 @@ import torch -import torch.nn.functional as Func -from torch.nn import Module +from torch.nn import Module, GRU from torch.nn import Linear, Conv1d, MaxPool1d, Dropout, BatchNorm1d, ReLU, Sigmoid @@ -129,8 +128,6 @@ class Conv1dBanks(Module): def __init__(self, k, in_dims, out_dims, activation): super(Conv1dBanks, self).__init__() - self.conv1_norm_outputs = [] - self.k = k self.in_dims = in_dims self.out_dims = out_dims @@ -138,22 +135,20 @@ class Conv1dBanks(Module): def forward(self, inputs): # inputs : (N, in_dims , L_in) + conv1_norm_outputs = [] for k_size in range(1, 1 + self.k): conv1d_norm = Conv1dNorm(self.in_dims, self.out_dims, k_size, self.activation) # conv1d_norm_outputs : (N, out_dims, L_out) # L_in == L_out conv1d_norm_outputs = conv1d_norm(inputs) - self.conv1_norm_outputs.append(conv1d_norm_outputs) + conv1_norm_outputs.append(conv1d_norm_outputs) # conv1d_banks : (N, k*out_dims, L_out) - conv1d_banks = torch.cat(self.conv1_norm_outputs, 1) + conv1d_banks = torch.cat(conv1_norm_outputs, 1) return conv1d_banks -# TODO : Add GRU Model - - class CBHG(Module): def __init__(self, k, num_highway_blocks, in_dims, out_dims, activation): super(CBHG, self).__init__() @@ -172,34 +167,45 @@ class CBHG(Module): self.highway = HighwayNet(in_dims=out_dims) + self.gru = GRU(out_dims, out_dims, batch_first=True, bidirectional=True) + def forward(self, inputs): # inputs : (N, in_dims, L_in) # conv1d_banks_outputs : (N, k*out_dims, L_in) conv1d_banks_outputs = self.conv1d_banks(inputs) - print("conv1d_banks_outputs : " + str(conv1d_banks_outputs.shape)) # Cut out the rest max_pool1d_outputs = self.max_pool1d(conv1d_banks_outputs) max_pool1d_outputs = max_pool1d_outputs[:, :, :-1] # max_pool1d_outputs : (N, k*out_dims, L_in) - print("max_pool1d_outputs : " + str(max_pool1d_outputs.shape)) assert conv1d_banks_outputs.shape == max_pool1d_outputs.shape # projection1_outputs : (N, out_dims, L_in) projection1_outputs = self.projection1(max_pool1d_outputs) - print("projection1_outputs : " + str(projection1_outputs.shape)) # projection2_outputs : (N, out_dims, L_in) projection2_outputs = self.projection2(projection1_outputs) - print("projection2_outputs : " + str(projection2_outputs.shape)) # residual_connections : (N, out_dims, L_in) residual_connections = projection2_outputs + inputs highway_data = residual_connections.transpose(2, 1) + # highway_data : (N, L_in, out_dims) + # highway_data : (N, L_in, out_dims) for i in range(self.num_highway_blocks): highway_data = self.highway(highway_data) - # TODO : Add GRU Model + # gru_output : (N, L_in, out_dims*2) + gru_output, _ = self.gru(highway_data) - return highway_data + debug = False + if debug: + print("conv1d_banks_outputs : " + str(conv1d_banks_outputs.shape)) + print("max_pool1d_outputs : " + str(max_pool1d_outputs.shape)) + print("projection1_outputs : " + str(projection1_outputs.shape)) + print("projection2_outputs : " + str(projection2_outputs.shape)) + print("highway_data_1 : " + str(highway_data.shape)) + print("highway_data_2 : " + str(highway_data.shape)) + print("gru_output : " + str(gru_output.shape)) + + return gru_output