Created
June 7, 2020 18:13
-
-
Save gautham20/4536513280de7bfd6c802079c48a23d1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class StoreItemDataset(Dataset):
    """Dataset that serves one row of a sequence DataFrame per item.

    The DataFrame supplied through ``load_sequence_data`` must contain an
    ``x_sequence`` column and a ``y_sequence`` column whose cells hold
    array-likes, plus one column per entry in ``cat_columns`` and
    ``num_columns``.

    Args:
        cat_columns: Names of categorical columns. Defaults to no columns.
        num_columns: Names of per-row scalar numeric columns that are
            broadcast along the time axis and appended as extra features.
            Defaults to no columns.
        embed_vector_size: Optional mapping ``column -> embedding width``;
            columns that are missing from it get a width of 50.
        decoder_input: When true, ``y_sequence`` must be 2-D: column 0 is
            returned as the target and the remaining columns are returned
            as the decoder input.
        ohe_cat_columns: When true, categorical columns are one-hot encoded
            and appended to the input features instead of being returned as
            a separate tensor of integer codes.
    """

    # Placeholder category used for missing / unseen categorical values.
    _NA_CATEGORY = '#NA#'

    def __init__(self, cat_columns=None, num_columns=None, embed_vector_size=None,
                 decoder_input=True, ohe_cat_columns=False):
        super().__init__()
        self.sequence_data = None
        # ``None`` defaults avoid sharing one list object between instances.
        self.cat_columns = cat_columns if cat_columns is not None else []
        self.num_columns = num_columns if num_columns is not None else []
        self.cat_classes = {}
        self.cat_embed_shape = []
        self.cat_embed_vector_size = embed_vector_size if embed_vector_size is not None else {}
        self.pass_decoder_input = decoder_input
        self.ohe_cat_columns = ohe_cat_columns
        # Opt-in switch (set on the instance after construction) that also
        # appends the one-hot categorical features to the decoder input.
        self.cat_columns_to_decoder = False

    def get_embedding_shape(self):
        """Return the ``(num_categories, embedding_width)`` pair of every
        categorical column, in ``cat_columns`` order.

        The list is empty until ``process_cat_columns`` has been called.
        """
        return self.cat_embed_shape

    def load_sequence_data(self, processed_data):
        """Attach the DataFrame the items are read from (stored, not copied)."""
        self.sequence_data = processed_data

    def process_cat_columns(self, column_map=None):
        """Convert the categorical columns to pandas ``category`` dtype and
        record their embedding shapes.

        Args:
            column_map: Optional mapping ``column -> categories``. For a
                mapped column the categories are set to exactly that list
                and values outside it become ``'#NA#'``, so the list must
                already contain ``'#NA#'``. Unmapped columns keep their
                observed categories and get ``'#NA#'`` added.

        The method modifies the loaded DataFrame in place and rebuilds
        ``cat_embed_shape`` from scratch, so calling it again does not
        accumulate duplicate entries.
        """
        column_map = column_map if column_map is not None else {}
        embed_shape = []
        for col in self.cat_columns:
            series = self.sequence_data[col].astype('category')
            if col in column_map:
                series = series.cat.set_categories(column_map[col]).fillna(self._NA_CATEGORY)
            elif self._NA_CATEGORY not in series.cat.categories:
                # ``add_categories(..., inplace=True)`` no longer exists in
                # pandas 2.x; assign the returned Series instead.
                series = series.cat.add_categories(self._NA_CATEGORY)
            self.sequence_data[col] = series
            embed_shape.append(
                (len(series.cat.categories), self.cat_embed_vector_size.get(col, 50))
            )
        self.cat_embed_shape = embed_shape

    def __len__(self):
        """Return the number of rows in the loaded DataFrame."""
        return len(self.sequence_data)

    def __getitem__(self, idx):
        """Build the model inputs and target for row ``idx``.

        Returns:
            ``(x, y)``, where ``x`` is the 2-D float feature tensor when it
            is the only input, or otherwise a tuple made of that tensor
            followed by the categorical code tensor (only when categorical
            columns exist and are not one-hot encoded) and then the decoder
            input (only when decoder input is enabled).
        """
        # A one-row frame (rather than a Series) keeps each column's dtype,
        # which the ``.cat`` accessor used below depends on.
        row = self.sequence_data.iloc[[idx]]
        target_values = row['y_sequence'].values[0]
        features = torch.tensor(row['x_sequence'].values[0], dtype=torch.float32)

        decoder_input = None
        if self.pass_decoder_input:
            # Column 0 is the target; the remaining columns feed the decoder.
            decoder_input = torch.tensor(target_values[:, 1:], dtype=torch.float32)
            y = torch.tensor(target_values[:, 0], dtype=torch.float32)
        else:
            y = torch.tensor(target_values, dtype=torch.float32)

        for col in self.num_columns:
            num_tensor = torch.tensor([row[col].values[0]], dtype=torch.float32)
            # Repeat the scalar once per time step and append it as a column.
            features = torch.cat(
                (features, num_tensor.repeat(features.size(0)).unsqueeze(1)), dim=1
            )
            # Without decoder input there is nothing to extend; touching it
            # unconditionally used to raise UnboundLocalError here.
            if decoder_input is not None:
                decoder_input = torch.cat(
                    (decoder_input, num_tensor.repeat(decoder_input.size(0)).unsqueeze(1)),
                    dim=1,
                )

        x_inputs = [features]
        if len(self.cat_columns) > 0:
            if self.ohe_cat_columns:
                for ci, (num_classes, _) in enumerate(self.cat_embed_shape):
                    code = int(row[self.cat_columns[ci]].cat.codes.values[0])
                    col_tensor = torch.zeros(num_classes, dtype=torch.float32)
                    col_tensor[code] = 1.0
                    x_inputs[0] = torch.cat(
                        (x_inputs[0], col_tensor.repeat(x_inputs[0].size(0), 1)), dim=1
                    )
                    if decoder_input is not None and self.cat_columns_to_decoder:
                        decoder_input = torch.cat(
                            (decoder_input, col_tensor.repeat(decoder_input.size(0), 1)),
                            dim=1,
                        )
            else:
                # Integer category codes, one per categorical column.
                cat_tensor = torch.tensor(
                    [row[col].cat.codes.values[0] for col in self.cat_columns],
                    dtype=torch.long,
                )
                x_inputs.append(cat_tensor)

        if decoder_input is not None:
            x_inputs.append(decoder_input)

        if len(x_inputs) > 1:
            return tuple(x_inputs), y
        return x_inputs[0], y
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment