import numpy as np
import torch
import torch.nn as nn
from sklearn import metrics
from skorch import NeuralNet
from skorch.callbacks import EarlyStopping, EpochScoring, LRScheduler
from torch.utils.data import DataLoader, Dataset

class AttentiveRecurrentDecoder(nn.Module):
    """Bidirectional RNN encoder followed by multi-head attention pooling
    and a linear projection to the label space."""

    def __init__(self,
                 input_dim,
                 output_dim,
                 d_model,
                 rnn_type="lstm",
                 num_layers=1,
                 rnn_dropout=0.1,
                 num_heads=8,
                 attn_dropout=0.1,
                 attn_bias=True):
        super().__init__()
        RNNClass = getattr(nn, rnn_type.upper())  # nn.LSTM, nn.GRU or nn.RNN
        self.encoder = RNNClass(
            input_dim,
            d_model,
            num_layers,
            # RNN dropout only applies between layers, so disable it for a
            # single layer to avoid the PyTorch warning
            dropout=rnn_dropout if num_layers > 1 else 0.,
            bidirectional=True,
            batch_first=True
        )
        # The bidirectional encoder doubles the feature dimension
        self.attention = nn.MultiheadAttention(
            embed_dim=d_model * 2,
            num_heads=num_heads,
            dropout=attn_dropout,
            bias=attn_bias
        )
        self.output_layer = nn.Linear(d_model * 2, output_dim)

    def forward(self, inputs, input_lengths, labels=None, return_attention=False):
        # pack_padded_sequence expects the lengths on the CPU, even when the
        # model and inputs live on the GPU
        packed_inputs = nn.utils.rnn.pack_padded_sequence(
            inputs, input_lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_outputs, hidden = self.encoder(packed_inputs)
        # outputs: (batch, max_len, 2 * d_model)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs, batch_first=True)
        if isinstance(hidden, tuple):  # LSTM returns (hidden state, cell state)
            hidden = hidden[1]  # take the cell state
        # Concatenate the last forward and backward layers: (batch, 2 * d_model)
        hidden = torch.cat([hidden[-1], hidden[-2]], dim=1)
        # nn.MultiheadAttention expects (seq_len, batch, embed_dim)
        query = hidden.unsqueeze(0).contiguous()            # (1, batch, 2 * d_model)
        key = outputs.transpose(0, 1).contiguous()          # (max_len, batch, 2 * d_model)
        linear_combination, energy = self.attention(query, key, key)
        linear_combination = linear_combination.squeeze(0)  # (batch, 2 * d_model)
        logits = self.output_layer(linear_combination)      # (batch, output_dim)
        if return_attention:
            return logits, energy
        return logits
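

# Minimal smoke test (not part of the original gist; the sizes below are
# arbitrary assumptions for illustration): push a random padded batch through
# the decoder and check the output shapes.
def demo_decoder_shapes():
    model = AttentiveRecurrentDecoder(input_dim=16, output_dim=4, d_model=32)
    inputs = torch.randn(2, 10, 16)    # (batch, max_len, input_dim)
    lengths = torch.tensor([10, 7])    # true sequence lengths
    logits, attention = model(inputs, lengths, return_attention=True)
    print(logits.shape)     # torch.Size([2, 4]): one logit per label
    print(attention.shape)  # torch.Size([2, 1, 10]): weights over time steps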


class SequenceDataset(Dataset):

    def __init__(self, features, labels):
        self.features = features
        self.feature_lengths = [len(feats) for feats in self.features]
        self.labels = labels
        self.input_dim = self.features[0].shape[-1]
        self.output_dim = self.labels[0].shape[-1]

    def __getitem__(self, index):
        inputs = torch.tensor(self.features[index])
        input_lengths = torch.tensor(self.feature_lengths[index])
        labels = torch.tensor(self.labels[index])
        return inputs, input_lengths, labels

    def __len__(self):
        return len(self.features)

    @staticmethod
    def data_collator(batch):
        """
        batch is a list of (sequence, length, target) tuples, as returned by
        __getitem__. Returns the padded sequences together with their true
        lengths; sequences are left in batch order, since the model packs
        them with enforce_sorted=False and no sorting is needed here.
        """
        features, lengths, labels = map(list, zip(*batch))
        features = nn.utils.rnn.pad_sequence(features, batch_first=True, padding_value=0.)
        lengths = torch.stack(lengths, 0)
        labels = torch.stack(labels, 0)
        # See https://skorch.readthedocs.io/en/latest/user/neuralnet.html#multiple-input-arguments
        # Cast to float32 so the inputs match the model weights and the
        # labels match the dtype expected by BCEWithLogitsLoss
        return {"inputs": features.float(), "input_lengths": lengths}, labels.float()


def error_rate(y_true, y_pred):
    """Exact-match (subset) error: a sample only counts as correct if every label matches."""
    assert y_true is not None
    # The model outputs raw logits (it is trained with BCEWithLogitsLoss),
    # so a positive logit corresponds to a sigmoid probability above 0.5
    y_pred = y_pred > 0
    return 1 - (y_true == y_pred).all(-1).mean()


def accuracy(y_true, y_pred):
    """Element-wise accuracy over all labels of all samples."""
    assert y_true is not None
    y_pred = y_pred > 0
    return (y_true == y_pred).mean()
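

# Worked example (illustrative, with made-up numbers) of how the two metrics
# differ: error_rate scores exact matches per sample, accuracy scores
# individual labels.
def demo_metrics():
    y_true = np.array([[1, 0, 1], [0, 1, 0]])
    logits = np.array([[2.0, -1.0, 3.0], [-2.0, 1.0, 0.5]])
    print(error_rate(y_true, logits))  # 0.5: the second sample misses one label
    print(accuracy(y_true, logits))    # ~0.833: 5 of the 6 labels are correct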


def generate_random_data():
    """Build a synthetic multi-label task: M variable-length sequences of
    D-dimensional features, each tagged with exactly 3 of K labels."""
    M, s, S, D, K = 1000, 10, 50, 128, 32
    input_lengths = np.random.randint(s, S, M)
    features = [np.random.randn(l, D) for l in input_lengths]
    target = np.zeros((M, K))
    for i, ks in enumerate([np.random.choice(np.arange(0, K), 3, replace=False) for _ in range(M)]):
        target[i, ks] = 1
    return features, target


def main():
    features, target = generate_random_data()
    dataset = SequenceDataset(features, target)
    net = NeuralNet(
        module=AttentiveRecurrentDecoder,
        module__input_dim=dataset.input_dim,
        module__output_dim=dataset.output_dim,
        module__d_model=64,
        criterion=nn.BCEWithLogitsLoss,
        iterator_train__collate_fn=dataset.data_collator,
        iterator_valid__collate_fn=dataset.data_collator,
        batch_size=64,
        max_epochs=50,
        lr=0.2,
        callbacks=[
            EpochScoring(scoring=metrics.make_scorer(error_rate), lower_is_better=True),
            EpochScoring(scoring=metrics.make_scorer(accuracy), lower_is_better=False),
            EarlyStopping(monitor="valid_loss", patience=5),
            LRScheduler(policy="ReduceLROnPlateau", patience=3)
        ],
        device="cuda" if torch.cuda.is_available() else "cpu"
    )
    net.fit(dataset)


if __name__ == "__main__":
    main()
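
# Note (an assumption about skorch's behaviour, not stated in the gist):
# NeuralNet.fit accepts a torch Dataset with y=None; skorch's default
# train_split then holds out part of it as the validation fold consumed by
# EarlyStopping and the EpochScoring callbacks above.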