# Updated model.modules:
#     Fixed bugs in Conv1dBanks module
#     Optimized the debugging scheme
#     Implemented the CBHG module
This commit is contained in:
miaomiaomiao-LJY 2021-04-13 10:53:08 +08:00
parent 2194cdc1f3
commit 250b0dc1fd
1 changed files with 21 additions and 15 deletions

View File

@ -1,6 +1,5 @@
import torch
import torch.nn.functional as Func
from torch.nn import Module
from torch.nn import Module, GRU
from torch.nn import Linear, Conv1d, MaxPool1d, Dropout, BatchNorm1d, ReLU, Sigmoid
@ -129,8 +128,6 @@ class Conv1dBanks(Module):
def __init__(self, k, in_dims, out_dims, activation):
super(Conv1dBanks, self).__init__()
self.conv1_norm_outputs = []
self.k = k
self.in_dims = in_dims
self.out_dims = out_dims
@ -138,22 +135,20 @@ class Conv1dBanks(Module):
def forward(self, inputs):
# Run `inputs` through one Conv1dNorm per kernel size 1..self.k and
# concatenate the per-kernel outputs along dim 1.
# NOTE(review): this listing is a flattened diff -- indentation and +/- markers
# are missing, so the pre-change and post-change forms of a statement appear
# on consecutive lines below.
# inputs : (N, in_dims , L_in)
# Per-call accumulator for the outputs of each kernel size.
conv1_norm_outputs = []
for k_size in range(1, 1 + self.k):
# NOTE(review): a fresh Conv1dNorm is constructed on every call of forward.
# If Conv1dNorm holds learnable parameters (its definition is not shown here),
# they are re-created each call and are not attributes of this module --
# consider building the k layers once in __init__ instead. TODO confirm.
conv1d_norm = Conv1dNorm(self.in_dims, self.out_dims, k_size, self.activation)
# conv1d_norm_outputs : (N, out_dims, L_out)
# L_in == L_out
conv1d_norm_outputs = conv1d_norm(inputs)
# NOTE(review): the next two lines look like the old and new versions of the
# same statement. The `self.` form appends to an instance-level list that this
# method never clears; the local form uses the list created at the top of
# this call.
self.conv1_norm_outputs.append(conv1d_norm_outputs)
conv1_norm_outputs.append(conv1d_norm_outputs)
# conv1d_banks : (N, k*out_dims, L_out)
# NOTE(review): same old/new pairing as above; the second assignment overwrites
# the first, so the returned tensor is the concatenation of the local list.
conv1d_banks = torch.cat(self.conv1_norm_outputs, 1)
conv1d_banks = torch.cat(conv1_norm_outputs, 1)
return conv1d_banks
# TODO : Add GRU Model
class CBHG(Module):
def __init__(self, k, num_highway_blocks, in_dims, out_dims, activation):
super(CBHG, self).__init__()
@ -172,34 +167,45 @@ class CBHG(Module):
self.highway = HighwayNet(in_dims=out_dims)
self.gru = GRU(out_dims, out_dims, batch_first=True, bidirectional=True)
def forward(self, inputs):
# Pipeline: conv1d banks -> max pooling -> two projections -> residual add
# with `inputs` -> repeated highway layer -> GRU.
# NOTE(review): this listing is a flattened diff -- indentation and +/- markers
# are missing, so lines from before and after the change are interleaved
# (e.g. the unconditional prints below vs. the `if debug:` block, and the two
# return statements).
# inputs : (N, in_dims, L_in)
# conv1d_banks_outputs : (N, k*out_dims, L_in)
conv1d_banks_outputs = self.conv1d_banks(inputs)
# NOTE(review): unconditional shape print; the same message is printed again
# inside the `if debug:` block further down.
print("conv1d_banks_outputs : " + str(conv1d_banks_outputs.shape))
# Cut out the rest
max_pool1d_outputs = self.max_pool1d(conv1d_banks_outputs)
# Drop the last position along dim 2.
# NOTE(review): presumably the pooling layer is configured so that its output
# is one step longer than its input; self.max_pool1d's settings are not
# visible here -- TODO confirm.
max_pool1d_outputs = max_pool1d_outputs[:, :, :-1]
# max_pool1d_outputs : (N, k*out_dims, L_in)
print("max_pool1d_outputs : " + str(max_pool1d_outputs.shape))
# Sanity check that pooling plus the slice preserved the shape.
# NOTE(review): `assert` is stripped when Python runs with -O.
assert conv1d_banks_outputs.shape == max_pool1d_outputs.shape
# projection1_outputs : (N, out_dims, L_in)
projection1_outputs = self.projection1(max_pool1d_outputs)
print("projection1_outputs : " + str(projection1_outputs.shape))
# projection2_outputs : (N, out_dims, L_in)
projection2_outputs = self.projection2(projection1_outputs)
print("projection2_outputs : " + str(projection2_outputs.shape))
# residual_connections : (N, out_dims, L_in)
# NOTE(review): going by the shape comments above, this addition needs
# in_dims == out_dims (or broadcast-compatible shapes) -- TODO confirm.
residual_connections = projection2_outputs + inputs
# Swap dims 1 and 2 before the highway layer.
highway_data = residual_connections.transpose(2, 1)
# highway_data : (N, L_in, out_dims)
# highway_data : (N, L_in, out_dims)
# The same self.highway object is applied num_highway_blocks times in sequence.
for i in range(self.num_highway_blocks):
highway_data = self.highway(highway_data)
# TODO : Add GRU Model
# NOTE(review): the TODO above looks stale -- the GRU call follows immediately.
# gru_output : (N, L_in, out_dims*2)
# The second value returned by self.gru is discarded.
gru_output, _ = self.gru(highway_data)
# NOTE(review): read literally, this return makes everything below unreachable;
# it is presumably the pre-change return that `return gru_output` replaces.
return highway_data
# Local switch for the shape printouts below; hard-coded off.
debug = False
if debug:
print("conv1d_banks_outputs : " + str(conv1d_banks_outputs.shape))
print("max_pool1d_outputs : " + str(max_pool1d_outputs.shape))
print("projection1_outputs : " + str(projection1_outputs.shape))
print("projection2_outputs : " + str(projection2_outputs.shape))
# NOTE(review): both highway prints read the same variable after the loop, so
# they report the same shape under two different labels.
print("highway_data_1 : " + str(highway_data.shape))
print("highway_data_2 : " + str(highway_data.shape))
print("gru_output : " + str(gru_output.shape))
return gru_output