update
# Updated model.modules: # Fixed bugs in Conv1dBanks module # Optimized the debugging scheme # Implement the CBHG module
This commit is contained in:
parent
2194cdc1f3
commit
250b0dc1fd
|
@ -1,6 +1,5 @@
|
|||
import torch
|
||||
import torch.nn.functional as Func
|
||||
from torch.nn import Module
|
||||
from torch.nn import Module, GRU
|
||||
from torch.nn import Linear, Conv1d, MaxPool1d, Dropout, BatchNorm1d, ReLU, Sigmoid
|
||||
|
||||
|
||||
|
@ -129,8 +128,6 @@ class Conv1dBanks(Module):
|
|||
def __init__(self, k, in_dims, out_dims, activation):
|
||||
super(Conv1dBanks, self).__init__()
|
||||
|
||||
self.conv1_norm_outputs = []
|
||||
|
||||
self.k = k
|
||||
self.in_dims = in_dims
|
||||
self.out_dims = out_dims
|
||||
|
@ -138,22 +135,20 @@ class Conv1dBanks(Module):
|
|||
|
||||
def forward(self, inputs):
|
||||
# inputs : (N, in_dims , L_in)
|
||||
conv1_norm_outputs = []
|
||||
for k_size in range(1, 1 + self.k):
|
||||
conv1d_norm = Conv1dNorm(self.in_dims, self.out_dims, k_size, self.activation)
|
||||
# conv1d_norm_outputs : (N, out_dims, L_out)
|
||||
# L_in == L_out
|
||||
conv1d_norm_outputs = conv1d_norm(inputs)
|
||||
self.conv1_norm_outputs.append(conv1d_norm_outputs)
|
||||
conv1_norm_outputs.append(conv1d_norm_outputs)
|
||||
|
||||
# conv1d_banks : (N, k*out_dims, L_out)
|
||||
conv1d_banks = torch.cat(self.conv1_norm_outputs, 1)
|
||||
conv1d_banks = torch.cat(conv1_norm_outputs, 1)
|
||||
|
||||
return conv1d_banks
|
||||
|
||||
|
||||
# TODO : Add GRU Model
|
||||
|
||||
|
||||
class CBHG(Module):
|
||||
def __init__(self, k, num_highway_blocks, in_dims, out_dims, activation):
|
||||
super(CBHG, self).__init__()
|
||||
|
@ -172,34 +167,45 @@ class CBHG(Module):
|
|||
|
||||
self.highway = HighwayNet(in_dims=out_dims)
|
||||
|
||||
self.gru = GRU(out_dims, out_dims, batch_first=True, bidirectional=True)
|
||||
|
||||
def forward(self, inputs):
|
||||
# inputs : (N, in_dims, L_in)
|
||||
# conv1d_banks_outputs : (N, k*out_dims, L_in)
|
||||
conv1d_banks_outputs = self.conv1d_banks(inputs)
|
||||
print("conv1d_banks_outputs : " + str(conv1d_banks_outputs.shape))
|
||||
|
||||
# Cut out the rest
|
||||
max_pool1d_outputs = self.max_pool1d(conv1d_banks_outputs)
|
||||
max_pool1d_outputs = max_pool1d_outputs[:, :, :-1]
|
||||
# max_pool1d_outputs : (N, k*out_dims, L_in)
|
||||
print("max_pool1d_outputs : " + str(max_pool1d_outputs.shape))
|
||||
|
||||
assert conv1d_banks_outputs.shape == max_pool1d_outputs.shape
|
||||
|
||||
# projection1_outputs : (N, out_dims, L_in)
|
||||
projection1_outputs = self.projection1(max_pool1d_outputs)
|
||||
print("projection1_outputs : " + str(projection1_outputs.shape))
|
||||
# projection2_outputs : (N, out_dims, L_in)
|
||||
projection2_outputs = self.projection2(projection1_outputs)
|
||||
print("projection2_outputs : " + str(projection2_outputs.shape))
|
||||
|
||||
# residual_connections : (N, out_dims, L_in)
|
||||
residual_connections = projection2_outputs + inputs
|
||||
highway_data = residual_connections.transpose(2, 1)
|
||||
# highway_data : (N, L_in, out_dims)
|
||||
|
||||
# highway_data : (N, L_in, out_dims)
|
||||
for i in range(self.num_highway_blocks):
|
||||
highway_data = self.highway(highway_data)
|
||||
|
||||
# TODO : Add GRU Model
|
||||
# gru_output : (N, L_in, out_dims*2)
|
||||
gru_output, _ = self.gru(highway_data)
|
||||
|
||||
return highway_data
|
||||
debug = False
|
||||
if debug:
|
||||
print("conv1d_banks_outputs : " + str(conv1d_banks_outputs.shape))
|
||||
print("max_pool1d_outputs : " + str(max_pool1d_outputs.shape))
|
||||
print("projection1_outputs : " + str(projection1_outputs.shape))
|
||||
print("projection2_outputs : " + str(projection2_outputs.shape))
|
||||
print("highway_data_1 : " + str(highway_data.shape))
|
||||
print("highway_data_2 : " + str(highway_data.shape))
|
||||
print("gru_output : " + str(gru_output.shape))
|
||||
|
||||
return gru_output
|
||||
|
|
Loading…
Reference in New Issue