36 lines
1.1 KiB
Python
36 lines
1.1 KiB
Python
#!/usr/bin/python
|
|
# -*- coding: utf-8 -*-
|
|
#####################################
|
|
# File name : rnn_model.py
|
|
# Create date : 2019-02-16 15:37
|
|
# Modified date : 2019-02-19 13:07
|
|
# Author : DARREN
|
|
# Describe : single-step recurrent cell (affine hidden update, log-softmax output head)
|
|
# Email : lzygzh@126.com
|
|
#####################################
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
class RNN(nn.Module):
    """Single-step recurrent cell with a log-softmax output head.

    Each call to ``forward`` combines one input step with the previous hidden
    state and returns ``(log_probs, new_hidden)``. Nothing in this class loops
    over a sequence; the caller threads the hidden state through successive
    ``forward`` calls.

    Args:
        input_size: Size of the last dimension of each input step.
        hidden_size: Size of the hidden state vector.
        output_size: Size of the output; the log-softmax normalizes over
            these entries.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()

        self.hidden_size = hidden_size

        # Both heads read the same concatenation [input, hidden], hence the
        # ``input_size + hidden_size`` fan-in of each linear layer.
        self.input_to_hidden = nn.Linear(input_size + hidden_size, hidden_size)
        self.input_to_output = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Advance the cell by one step.

        Args:
            input: Current input step; expected to be 2-D
                ``(batch, input_size)`` so it can be concatenated with
                ``hidden`` along dim 1.
            hidden: Previous hidden state, ``(batch, hidden_size)``; e.g. the
                result of ``init_hidden`` on the first step.

        Returns:
            tuple: ``(output, hidden)``. ``output`` holds log-probabilities
            (log-softmax over dim 1) of size ``output_size``. ``hidden`` is
            the new hidden state of size ``hidden_size``.
        """
        combined = torch.cat((input, hidden), 1)
        # NOTE(review): the hidden update is purely affine (no tanh/ReLU), so
        # the recurrence is linear in the hidden state. Kept unchanged to
        # preserve the model's behaviour and compatibility with trained weights.
        hidden = self.input_to_hidden(combined)
        output = self.input_to_output(combined)
        output = self.softmax(output)
        return output, hidden

    def init_hidden(self, batch_size=1):
        """Return an all-zero initial hidden state.

        The tensor is allocated on the same device and with the same dtype as
        the module's parameters. This lets it be concatenated with the input in
        ``forward`` after the model has been moved with ``.to(...)``,
        ``.cuda()`` or ``.double()``. A bare ``torch.zeros`` would always be a
        CPU float32 tensor.

        Args:
            batch_size: Number of rows in the returned state. Defaults to 1,
                matching the previous fixed shape ``(1, hidden_size)``.

        Returns:
            torch.Tensor: Zeros of shape ``(batch_size, hidden_size)``.
        """
        weight = self.input_to_hidden.weight
        return torch.zeros(batch_size, self.hidden_size,
                           device=weight.device, dtype=weight.dtype)
|