@RXminuS
Created February 1, 2019 00:53
Gensim Bug
from typing import Dict, Iterable, List, Any
import json
import argparse
import logging
import multiprocessing
import os
import shutil
import sys
sys.path.insert(0, os.path.dirname(
    os.path.abspath(os.path.join(__file__, os.pardir))))

if True:  # workaround for isort
    import torch
    from tensorboardX import SummaryWriter
    from allennlp.common.util import prepare_environment, prepare_global_logging
    from allennlp.common import Params
    from allennlp.common.checks import ConfigurationError
    from allennlp.data import Instance, Token
    from allennlp.data.fields import TextField
    from allennlp.data.dataset_readers.dataset_reader import DatasetReader
    from allennlp.common.util import import_submodules
    from allennlp.models.archival import CONFIG_NAME
    from gensim.models import FastText
    from gensim.utils import RULE_KEEP
    from gensim.models.callbacks import CallbackAny2Vec
    from gensim.models.fasttext import FAST_VERSION

import_submodules("donnanlp")

if os.environ.get("ALLENNLP_DEBUG"):
    LEVEL = logging.DEBUG
else:
    LEVEL = logging.INFO

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    level=LEVEL)
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
class GensimDatasetIterator(object):
    """
    Gensim requires the dataset to be iterable multiple times (once to build the
    vocabulary and once per training epoch), so we wrap the normal dataset_reader
    in this iterator class to support that.
    """

    def __init__(self, dataset_reader: DatasetReader, path: str):
        self._dataset_reader = dataset_reader
        self._path = path

    def __iter__(self):
        i: Instance
        for i in self._dataset_reader.read(self._path):
            tf: TextField = i.fields["text"]
            tokens: List[Token] = tf.tokens
            yield [t.text for t in tokens]
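# Usage sketch (illustrative only; ``my_reader`` and the data path below are
# hypothetical, not part of this gist). Gensim walks the corpus once for
# ``build_vocab`` and again for every training epoch, which a plain generator
# cannot support; this wrapper can, because each ``iter()`` call re-reads the
# file through the dataset reader:
#
#     corpus = GensimDatasetIterator(my_reader, "data/train.jsonl")
#     model.build_vocab(corpus)              # first pass over the data
#     model.train(corpus, epochs=5,          # each epoch calls iter(corpus) again
#                 total_examples=model.corpus_count)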
class TensorboardWriter:
    """
    Wraps a ``SummaryWriter`` instance but is a no-op if it is ``None``.
    Allows TensorBoard logging without always checking for ``None`` first.
    """

    def __init__(self, train_log: SummaryWriter = None) -> None:
        self._train_log = train_log

    @staticmethod
    def _item(value: Any):
        if hasattr(value, 'item'):
            val = value.item()
        else:
            val = value
        return val

    def add_train_scalar(self, name: str, value: float, global_step: int) -> None:
        # Unwrap tensors / numpy scalars to a plain Python number before logging.
        if self._train_log is not None:
            self._train_log.add_scalar(
                name, self._item(value), global_step)

    def add_train_histogram(self, name: str, values: torch.Tensor, global_step: int) -> None:
        if self._train_log is not None:
            if isinstance(values, torch.Tensor):
                values_to_write = values.cpu().data.numpy().flatten()
                self._train_log.add_histogram(
                    name, values_to_write, global_step)

    def add_embedding(self, values: torch.Tensor, metadata: List[str], global_step: int) -> None:
        if self._train_log is not None:
            if isinstance(values, torch.Tensor):
                self._train_log.add_embedding(
                    values, metadata, global_step=global_step)
class GensimCallback(CallbackAny2Vec):
    """
    Saves an intermediate model checkpoint and logs the loss, learning rate and
    embeddings to TensorBoard at the end of every epoch.
    """

    def __init__(self, serialization_dir: str, writer: TensorboardWriter) -> None:
        self._serialization_dir = serialization_dir
        self._writer = writer
        self._global_step = 0

    def on_epoch_end(self, model: FastText):
        # TODO: Somehow this broke in the latest version of gensim
        model.save(os.path.join(self._serialization_dir,
                                f"_model.{self._global_step}.gensim"), separately=[])
        model.wv.save_word2vec_format(os.path.join(
            self._serialization_dir, f"_model.{self._global_step}.txt"))
        vectors = torch.FloatTensor(model.wv.vectors)
        metadata = model.wv.index2word
        self._writer.add_embedding(vectors, metadata, self._global_step)
        self._writer.add_train_scalar(
            "loss", model.running_training_loss, self._global_step)
        self._writer.add_train_scalar(
            "learning_rate", model.alpha, self._global_step)
        self._global_step += 1
def main(params: Params, serialization_dir: str, iterations: int = 5,
         file_friendly_logging: bool = False, force: bool = False):
    prepare_environment(params)
    create_serialization_dir(params, serialization_dir, force)
    prepare_global_logging(serialization_dir, file_friendly_logging)
    params.to_file(os.path.join(serialization_dir, CONFIG_NAME))

    dataset = dataset_from_params(params)

    summary_writer = SummaryWriter(
        os.path.join(serialization_dir, "log", "train"))
    tensorboard_writer = TensorboardWriter(summary_writer)

    gensim_params: Params = params.pop('gensim')
    model_type = gensim_params.pop_choice(
        "type", ["fasttext", "word2vec"], True)
    skip_gram = gensim_params.pop_bool('skip_gram', True)
    softmax = gensim_params.pop_bool('softmax', True)
    max_vocab_size = gensim_params.pop_int('vocab_size', None)
    dimension = gensim_params.pop_int('vector_size', 100)
    window = gensim_params.pop_int('window_size', 5)
    seed = gensim_params.pop_int("random_seed", 133700)
    min_count = gensim_params.pop_int("min_count", 0)
    sample = gensim_params.pop_float("sample", 1e-3)
    workers = gensim_params.pop_int("workers", multiprocessing.cpu_count())
    negative = gensim_params.pop_int("negative_samples", 5)
    min_n = gensim_params.pop_int("min_character_ngram_length", 3)
    max_n = gensim_params.pop_int("max_character_ngram_length", 8)
    sorted_vocab = gensim_params.pop_bool("sorted_vocab", True)

    # FAST_VERSION is -1 when gensim falls back to the pure-Python implementation.
    if FAST_VERSION < 0:
        logger.error(
            "Using the slow version of this library. Please check your dependencies")
    else:
        logger.info(
            "Using accelerated version of this library with %d workers", workers)

    if not softmax and negative <= 0:
        raise ConfigurationError(
            "When not using hierarchical softmax, the number of negative samples must be > 0")

    # word_ngrams=1 enables character n-grams (fastText); 0 reduces the model to
    # plain word2vec behaviour.
    word_ngrams = 0
    if model_type == "fasttext":
        word_ngrams = 1

    # TODO: Continue training with an equal config
    model = FastText(sg=skip_gram, hs=softmax, size=dimension, window=window,
                     min_count=min_count, max_vocab_size=max_vocab_size,
                     sample=sample, seed=seed, workers=workers, negative=negative,
                     iter=iterations, min_n=min_n, max_n=max_n,
                     sorted_vocab=sorted_vocab, word_ngrams=word_ngrams,
                     callbacks=[GensimCallback(serialization_dir, tensorboard_writer)])

    # TODO: Make sure we obey the token indexer and create tokens for the entire
    # vocab (not just the ones present in the dataset)
    # model.build_vocab(([sp.IdToPiece(i)]
    #                    for i in range(sp.GetPieceSize())), keep_raw_vocab=True)
    model.build_vocab(dataset)

    # TODO: LearningRateScheduler
    model.train(dataset, epochs=iterations, total_examples=model.corpus_count,
                start_alpha=0.025, end_alpha=0.0001)
    model.callbacks = ()

    logger.info("Saving Models")
    model.save(os.path.join(serialization_dir, "model.gensim"), separately=[])
    model.wv.save_word2vec_format(os.path.join(serialization_dir, "model.txt"))

    logger.info("Logging Embeddings")
    # TODO: Allow for external vocab (i.e. out-of-vocabulary embedding generation)
    vectors = torch.FloatTensor(model.wv.vectors)
    metadata = model.wv.index2word
    tensorboard_writer.add_embedding(vectors, metadata, iterations)
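# Example parameter file for ``main`` (a minimal sketch: the dataset reader type
# and the data path are hypothetical, but the keys correspond to the ``pop``
# calls above and in ``dataset_from_params``; unlisted gensim keys fall back to
# the defaults shown above):
#
#     {
#         "dataset_reader": {"type": "my_reader"},
#         "train_data_path": "data/train.jsonl",
#         "gensim": {
#             "type": "fasttext",
#             "skip_gram": true,
#             "vector_size": 100,
#             "window_size": 5,
#             "min_count": 5
#         }
#     }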
def create_serialization_dir(
        params: Params,
        serialization_dir: str,
        force: bool) -> None:
    """
    Creates the serialization directory if it doesn't exist, and raises an error
    if it already exists and is non-empty (unless ``force`` is given, in which
    case the existing directory is removed first).

    Parameters
    ----------
    params: ``Params``
        A parameter object specifying an AllenNLP Experiment.
    serialization_dir: ``str``
        The directory in which to save results and logs.
    force: ``bool``
        If ``True``, delete any existing serialization directory and start from
        an empty one.
    """
    if os.path.exists(serialization_dir) and force:
        shutil.rmtree(serialization_dir)
    if os.path.exists(serialization_dir) and os.listdir(serialization_dir):
        raise ConfigurationError(
            f"Serialization directory ({serialization_dir}) already exists and is not empty.")
    else:
        os.makedirs(serialization_dir, exist_ok=True)


def dataset_from_params(params: Params) -> GensimDatasetIterator:
    """
    Build the lazy training-data iterator specified by the config.
    """
    dataset_reader = DatasetReader.from_params(params.pop('dataset_reader'))
    train_data_path = params.pop('train_data_path')
    logger.info("Reading training data from %s", train_data_path)
    return GensimDatasetIterator(dataset_reader, train_data_path)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Run Gensim Training")
    parser.add_argument('param_path',
                        type=str,
                        help='path to parameter file describing the model to be trained')
    parser.add_argument('-s', '--serialization-dir',
                        required=True,
                        type=str,
                        help='directory in which to save the model and its logs')
    parser.add_argument('-n', '--iterations',
                        default=5,
                        type=int,
                        help='how many epochs to train for')
    parser.add_argument('-f', '--force',
                        action='store_true',
                        required=False,
                        help='overwrite the output directory if it exists')
    parser.add_argument('--file-friendly-logging',
                        action='store_true',
                        default=False,
                        help='outputs tqdm status on separate lines and slows tqdm refresh rate')
    parser.add_argument('-o', '--overrides',
                        type=str,
                        default="",
                        help='a JSON structure used to override the experiment configuration')
    args = parser.parse_args()

    params = Params.from_file(args.param_path, args.overrides)
    main(params, args.serialization_dir, iterations=args.iterations,
         file_friendly_logging=args.file_friendly_logging, force=args.force)
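# Invocation sketch (the script filename and config path are hypothetical;
# the flags match the argparse definitions above):
#
#     python train_gensim_embeddings.py experiments/fasttext.jsonnet \
#         -s runs/fasttext -n 5 --force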