This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch | |
class SimpleRNN(nn.Module): | |
""" | |
Neural Network Module with an embedding layer, a RNN module and an output layer | |
Arguments: | |
input_size(int) -- length of the dictionary of embeddings |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch import * | |
import torch.nn.functional as F | |
import torch.nn as nn | |
class RNNModel(nn.Module): | |
""" | |
Neural Network Module with an embedding layer, a recurent module and an output linear layer | |
Arguments: | |
rnn_type(str) -- type of rnn module to use options are ['LSTM', 'GRU', 'RNN_TANH', 'RNN_RELU'] |