Gensim Bug
from typing import Dict, Iterable, List, Any
import json
import argparse
import logging
import multiprocessing
import os
import shutil
import sys

sys.path.insert(0, os.path.dirname(
    os.path.abspath(os.path.join(__file__, os.pardir))))
if True:  # workaround for isort
    import torch
    from tensorboardX import SummaryWriter
    from allennlp.common.util import prepare_environment, prepare_global_logging
    from allennlp.common import Params
    from allennlp.common.checks import ConfigurationError
    from allennlp.data import Instance, Token
    from allennlp.data.fields import TextField
    from allennlp.data.dataset_readers.dataset_reader import DatasetReader
    from allennlp.common.util import import_submodules
    from allennlp.models.archival import CONFIG_NAME
    from gensim.models import FastText
    from gensim.utils import RULE_KEEP
    from gensim.models.callbacks import CallbackAny2Vec
    from gensim.models.fasttext import FAST_VERSION

import_submodules("donnanlp")
if os.environ.get("ALLENNLP_DEBUG"):
    LEVEL = logging.DEBUG
else:
    LEVEL = logging.INFO

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    level=LEVEL)
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
class GensimDatasetIterator(object):
    """
    Gensim needs to iterate over the corpus more than once (once to build the
    vocabulary and once per training epoch), so we wrap the normal
    ``dataset_reader`` in this restartable iterable.
    """

    def __init__(self, dataset_reader: DatasetReader, path: str):
        self._dataset_reader = dataset_reader
        self._path = path

    def __iter__(self):
        i: Instance
        for i in self._dataset_reader.read(self._path):
            tf: TextField = i.fields["text"]
            tokens: List[Token] = tf.tokens
            yield [t.text for t in tokens]
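

# A minimal sketch (hypothetical reader and path) of how gensim consumes this
# corpus: ``build_vocab()`` makes one full pass and ``train()`` makes one pass
# per epoch, so a plain generator would be exhausted after the first pass,
# while this wrapper restarts the read every time ``__iter__`` is called.
#
#     corpus = GensimDatasetIterator(my_dataset_reader, "data/train.jsonl")
#     for _ in range(2):
#         for sentence in corpus:
#             assert isinstance(sentence, list)  # List[str] of token texts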
class TensorboardWriter:
    """
    Wraps a ``SummaryWriter`` instance but is a no-op if it is ``None``.
    Allows Tensorboard logging without always checking for ``None`` first.
    """

    def __init__(self, train_log: SummaryWriter = None) -> None:
        self._train_log = train_log

    @staticmethod
    def _item(value: Any):
        if hasattr(value, 'item'):
            val = value.item()
        else:
            val = value
        return val

    def add_train_scalar(self, name: str, value: float, global_step: int) -> None:
        # get the scalar
        if self._train_log is not None:
            self._train_log.add_scalar(
                name, self._item(value), global_step)

    def add_train_histogram(self, name: str, values: torch.Tensor, global_step: int) -> None:
        if self._train_log is not None:
            if isinstance(values, torch.Tensor):
                values_to_write = values.cpu().data.numpy().flatten()
                self._train_log.add_histogram(
                    name, values_to_write, global_step)

    def add_embedding(self, values: torch.Tensor, metadata: List[str], global_step: int) -> None:
        if self._train_log is not None:
            if isinstance(values, torch.Tensor):
                self._train_log.add_embedding(
                    values, metadata, global_step=global_step)
class GensimCallback(CallbackAny2Vec):
    def __init__(self, serialization_dir: str, writer: TensorboardWriter) -> None:
        self._serialization_dir = serialization_dir
        self._writer = writer
        self._global_step = 0

    def on_epoch_end(self, model: FastText):
        # TODO: Somehow this broke in the latest version of gensim
        model.save(os.path.join(self._serialization_dir,
                                f"_model.{self._global_step}.gensim"), separately=[])
        model.wv.save_word2vec_format(os.path.join(
            self._serialization_dir, f"_model.{self._global_step}.txt"))
        vectors = torch.FloatTensor(model.wv.vectors)
        metadata = model.wv.index2word
        self._writer.add_embedding(vectors, metadata, self._global_step)
        self._writer.add_train_scalar(
            "loss", model.running_training_loss, self._global_step)
        self._writer.add_train_scalar(
            "learning_rate",
            model.alpha, self._global_step
        )
        self._global_step += 1
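

# Gensim invokes the ``CallbackAny2Vec`` hooks (``on_epoch_end`` among others)
# once per training epoch for every callback passed via the ``callbacks``
# argument, so the class above writes a checkpoint, the word vectors, and
# Tensorboard summaries after each epoch. The checkpoint file names
# (_model.<step>.gensim / .txt) are this script's own convention, not a gensim
# requirement.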
def main(params: Params, serialization_dir: str, iterations: int = 5,
         file_friendly_logging: bool = False, force: bool = False):
    prepare_environment(params)
    create_serialization_dir(params, serialization_dir, force)
    prepare_global_logging(serialization_dir, file_friendly_logging)
    params.to_file(os.path.join(serialization_dir, CONFIG_NAME))
    dataset = dataset_from_params(params)
    summary_writer = SummaryWriter(
        os.path.join(serialization_dir, "log", "train"))
    tensorboard_writer = TensorboardWriter(summary_writer)
    gensim_params: Params = params.pop('gensim')
    model_type = gensim_params.pop_choice(
        "type", ["fasttext", "word2vec"], True)
    skip_gram = gensim_params.pop_bool('skip_gram', True)
    softmax = gensim_params.pop_bool('softmax', True)
    max_vocab_size = gensim_params.pop_int('vocab_size', None)
    dimension = gensim_params.pop_int('vector_size', 100)
    window = gensim_params.pop_int('window_size', 5)
    seed = gensim_params.pop_int("random_seed", 133700)
    min_count = gensim_params.pop_int("min_count", 0)
    sample = gensim_params.pop_float("sample", 1e-3)
    workers = gensim_params.pop_int("workers", multiprocessing.cpu_count())
    negative = gensim_params.pop_int("negative_samples", 5)
    min_n = gensim_params.pop_int("min_character_ngram_length", 3)
    max_n = gensim_params.pop_int("max_character_ngram_length", 8)
    sorted_vocab = gensim_params.pop_bool("sorted_vocab", True)
    # FAST_VERSION is -1 when gensim falls back to the pure-Python code path.
    if FAST_VERSION < 0:
        logger.error(
            "Using the slow version of this library. Please check your dependencies")
    else:
        logger.info(
            "Using accelerated version of this library with %d workers", workers)
    if not softmax and negative <= 0:
        raise ConfigurationError(
            "When hierarchical softmax is disabled, negative sampling is used, "
            "so 'negative_samples' must be > 0")
    word_ngrams = 0
    if model_type == "fasttext":
        word_ngrams = 1
    # TODO: Continue Training with equal config
    model = FastText(sg=skip_gram, hs=softmax, size=dimension, window=window,
                     min_count=min_count, max_vocab_size=max_vocab_size,
                     sample=sample, seed=seed, workers=workers, negative=negative,
                     iter=iterations, min_n=min_n, max_n=max_n,
                     sorted_vocab=sorted_vocab, word_ngrams=word_ngrams,
                     callbacks=[GensimCallback(serialization_dir, tensorboard_writer)])
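
    # Note: the keyword arguments above follow the gensim 3.x FastText API.
    # In gensim 4.x, ``size`` became ``vector_size``, ``iter`` became
    # ``epochs``, ``word_ngrams`` was dropped, and ``wv.index2word`` became
    # ``wv.index_to_key``; that API change is likely related to the breakage
    # mentioned in the TODO inside GensimCallback.on_epoch_end above.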
    # TODO: Make sure we obey the token indexer and create tokens for the entire vocab (not just the ones present in the dataset)
    # model.build_vocab(([sp.IdToPiece(i)]
    #                    for i in range(sp.GetPieceSize())), keep_raw_vocab=True)
    model.build_vocab(dataset)
    # TODO: LearningRateScheduler
    model.train(dataset, epochs=iterations, total_examples=model.corpus_count,
                start_alpha=0.025, end_alpha=0.0001)
    model.callbacks = ()
    logger.info("Saving Models")
    model.save(os.path.join(serialization_dir, "model.gensim"), separately=[])
    model.wv.save_word2vec_format(os.path.join(serialization_dir, "model.txt"))
    logger.info("Logging Embeddings")
    # TODO: Allow for external vocab (i.e. out of vocab embedding generation!)
    vectors = torch.FloatTensor(model.wv.vectors)
    metadata = model.wv.index2word
    tensorboard_writer.add_embedding(vectors, metadata, iterations)
def create_serialization_dir(
        params: Params,
        serialization_dir: str,
        force: bool) -> None:
    """
    Creates the serialization directory if it doesn't exist. If it already
    exists and is non-empty, a ``ConfigurationError`` is raised, unless
    ``force`` is set, in which case the existing directory is removed first.

    Parameters
    ----------
    params: ``Params``
        A parameter object specifying an AllenNLP Experiment.
    serialization_dir: ``str``
        The directory in which to save results and logs.
    force: ``bool``
        If ``True``, an existing serialization directory is deleted and recreated.
    """
    if os.path.exists(serialization_dir) and force:
        shutil.rmtree(serialization_dir)
    if os.path.exists(serialization_dir) and os.listdir(serialization_dir):
        raise ConfigurationError(
            f"Serialization directory ({serialization_dir}) already exists and is not empty.")
    else:
        os.makedirs(serialization_dir, exist_ok=True)
def dataset_from_params(params: Params) -> GensimDatasetIterator:
    """
    Load the training dataset specified by the config.
    """
    dataset_reader = DatasetReader.from_params(params.pop('dataset_reader'))
    train_data_path = params.pop('train_data_path')
    logger.info("Reading training data from %s", train_data_path)
    return GensimDatasetIterator(dataset_reader, train_data_path)
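

# For reference, the experiment config this script expects looks roughly like
# the sketch below (keys inferred from the ``pop`` calls above; the reader
# type and data path are placeholders, and the ``gensim`` block may omit any
# key that has a default):
#
#     {
#         "dataset_reader": {"type": "my_reader"},
#         "train_data_path": "data/train.jsonl",
#         "gensim": {
#             "type": "fasttext",
#             "skip_gram": true,
#             "softmax": false,
#             "vector_size": 100,
#             "window_size": 5,
#             "negative_samples": 5
#         }
#     }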
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Run Gensim Training")
    parser.add_argument('param_path',
                        type=str,
                        help='path to parameter file describing the model to be trained')
    parser.add_argument('-s', '--serialization-dir',
                        required=True,
                        type=str,
                        help='directory in which to save the model and its logs')
    parser.add_argument('-n', '--iterations',
                        default=5,
                        type=int,
                        help='how many epochs to train for')
    parser.add_argument('-f', '--force',
                        action='store_true',
                        required=False,
                        help='overwrite the output directory if it exists')
    parser.add_argument('--file-friendly-logging',
                        action='store_true',
                        default=False,
                        help='outputs tqdm status on separate lines and slows tqdm refresh rate')
    parser.add_argument('-o', '--overrides',
                        type=str,
                        default="",
                        help='a JSON structure used to override the experiment configuration')
    args = parser.parse_args()

    params = Params.from_file(args.param_path, args.overrides)
    main(params, args.serialization_dir, iterations=args.iterations,
         file_friendly_logging=args.file_friendly_logging, force=args.force)
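
# Example invocation (the script name and paths are placeholders):
#     python train_gensim.py experiments/fasttext.jsonnet -s runs/fasttext -n 5 -f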