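# 5-fold cross-validation of a German text classification model with FARM:
# fine-tunes a BERT (or ELECTRA) model on the GermEval 2018 coarse-grained
# offensive-language task (labels OTHER / OFFENSE) and prints per-fold
# dev-set and test-set f1_macro scores.
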
import logging
from pathlib import Path

import torch

from farm.data_handler.data_silo import DataSilo, DataSiloForCrossVal
from farm.data_handler.processor import TextClassificationProcessor
from farm.modeling.optimization import initialize_optimizer
from farm.modeling.adaptive_model import AdaptiveModel
from farm.modeling.language_model import LanguageModel
from farm.modeling.prediction_head import TextClassificationHead
from farm.modeling.tokenization import Tokenizer
from farm.train import Trainer, EarlyStopping
from farm.utils import set_all_seeds, MLFlowLogger, initialize_device_settings

# lang_model = "./models/dbmdz-bert-base-german-uncased"
lang_model = "bert-base-german-dbmdz-uncased"
# lang_model = "./models/german-nlp-group-electra-base-german-uncased"
# lang_model = "german-nlp-group/electra-base-german-uncased"

def doc_classification_crossvalidation():
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO)
    logging.getLogger('transformers').setLevel(logging.WARNING)
    # Cross-validation and training settings
    xval_folds = 5
    xval_stratified = True
    metric_name = "f1_macro"
    save_dir = Path("./saved_models/electra-bert-test")

    set_all_seeds(seed=42)
    device, n_gpu = initialize_device_settings(use_cuda=True)
    n_epochs = 3
    batch_size = 32
    evaluate_every = 100
    use_amp = None

    tokenizer = Tokenizer.load(pretrained_model_name_or_path=lang_model)

    # GermEval 2018 coarse-grained task: binary classification OTHER vs. OFFENSE
    label_list = ["OTHER", "OFFENSE"]
    processor = TextClassificationProcessor(tokenizer=tokenizer,
                                            max_seq_len=64,
                                            data_dir=Path("./data/germeval18"),
                                            label_list=label_list,
                                            metric=metric_name,
                                            label_column_name="coarse_label")

    # Build the full DataSilo once, then derive one silo per cross-validation fold
    data_silo = DataSilo(processor=processor, batch_size=batch_size)
    silos = DataSiloForCrossVal.make(data_silo, n_splits=xval_folds)
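
    # Trains one fold on its own data splits and returns the trained model
    # so the caller can move it off the GPU afterwards.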
    def train_on_split(silo_to_use, n_fold, save_dir):
        logger.info(f"############ Crossvalidation: Fold {n_fold} ############")

        # Load a fresh model for every fold so no weights leak between folds
        language_model = LanguageModel.load(lang_model)
        prediction_head = TextClassificationHead(
            class_weights=data_silo.calculate_class_weights(task_name="text_classification"),
            num_labels=len(label_list))
        model = AdaptiveModel(
            language_model=language_model,
            prediction_heads=[prediction_head],
            embeds_dropout_prob=0.2,
            lm_output_types=["per_sequence"],
            device=device)

        model, optimizer, lr_schedule = initialize_optimizer(
            model=model,
            learning_rate=0.5e-5,
            device=device,
            n_batches=len(silo_to_use.loaders["train"]),
            n_epochs=n_epochs,
            use_amp=use_amp)
        # Stop when the dev-set metric has not improved for 5 consecutive
        # evaluations and keep the best checkpoint in save_dir
        earlystopping = EarlyStopping(
            metric=metric_name,
            mode="max",
            save_dir=save_dir,
            patience=5,
        )

        trainer = Trainer(
            model=model,
            optimizer=optimizer,
            data_silo=silo_to_use,
            epochs=n_epochs,
            n_gpu=n_gpu,
            lr_schedule=lr_schedule,
            evaluate_every=evaluate_every,
            device=device,
            early_stopping=earlystopping,
        )

        trainer.train()

        # Best dev-set score seen by early stopping vs. this fold's test-set score
        es_result = earlystopping.best_so_far
        result = trainer.test_result[0][metric_name]
        print('result from early stopping (on dev set)', es_result)
        print('result from test set (with best loaded trial)', result)
        input("Please compare result from early stopping (on dev set) and result from test set (with best loaded trial)...")
        return model
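
    # Train one model per cross-validation fold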
    for num_fold, silo in enumerate(silos):
        model = train_on_split(silo, num_fold, save_dir)

        # empty cache to avoid memory leak and cuda OOM across multiple folds
        model.cpu()
        torch.cuda.empty_cache()


if __name__ == "__main__":
    doc_classification_crossvalidation()