#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : rnn_model.py
# Create date : 2019-02-16 15:37
# Modified date : 2019-02-19 13:07
# Author : DARREN
# Describe : not set
# Email : lzygzh@126.com
#####################################
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn


class RNN(nn.Module):
    """A minimal single-layer recurrent cell built from two linear layers."""

    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()

        self.hidden_size = hidden_size

        # Both layers consume the concatenation of the current input and the
        # previous hidden state.
        self.input_to_hidden = nn.Linear(input_size + hidden_size, hidden_size)
        self.input_to_output = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        # One time step: combine the input with the previous hidden state,
        # then produce the next hidden state and the log-probabilities over
        # the output classes.
        combined = torch.cat((input, hidden), 1)
        hidden = self.input_to_hidden(combined)
        output = self.input_to_output(combined)
        output = self.softmax(output)
        return output, hidden

    def init_hidden(self):
        # Initial all-zero hidden state for a batch of one.
        return torch.zeros(1, self.hidden_size)
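

# --- Usage sketch (not part of the original file) ---
# A minimal, hedged example of driving the cell for one time step. The sizes
# below are illustrative assumptions (e.g. one-hot character features), not
# values taken from this repository.
if __name__ == "__main__":
    n_features = 57     # assumed input width (e.g. one-hot character vocabulary)
    n_hidden = 128      # assumed hidden state size
    n_classes = 18      # assumed number of output classes

    rnn = RNN(n_features, n_hidden, n_classes)
    hidden = rnn.init_hidden()

    # Feed a single one-hot vector through one recurrent step.
    x = torch.zeros(1, n_features)
    x[0][0] = 1.0
    output, hidden = rnn(x, hidden)

    print(output.shape)  # torch.Size([1, n_classes]) of log-probabilities
    print(hidden.shape)  # torch.Size([1, n_hidden])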