Created March 30, 2022 15:41
Train a Hugging Face pipeline (tokenizer + model) on a dataset from the HF Hub. Adapted from the Hugging Face course.
import argparse

from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
from transformers import AutoTokenizer
from transformers import DataCollatorWithPadding

LABEL_COL = "label"
TEXT_COL = "text"


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset_args", nargs='+',
                        help="List of arguments to load the Dataset to train on (available on HF Hub)")
    parser.add_argument("--pretrained_pipeline", default="distilbert-base-uncased",
                        type=str, help="Pretrained pipeline to download (tokenizer and model)")
    parser.add_argument("--text_column", default=TEXT_COL)
    parser.add_argument("--label_column", default=LABEL_COL)
    parser.add_argument("--ckpt_path", default="./ckpt")
    return parser.parse_args()


def main(args):
    # Load the dataset from the HF Hub and check that the expected columns exist.
    ds = load_dataset(*args.dataset_args)
    if args.label_column not in ds["train"].column_names or args.text_column not in ds["train"].column_names:
        raise ValueError(f"Expecting {args.label_column} and {args.text_column} in dataset,"
                         f" found {ds['train'].column_names}")

    # Normalize column names so the rest of the script can rely on TEXT_COL/LABEL_COL.
    if args.text_column != TEXT_COL:
        ds = ds.rename_column(args.text_column, TEXT_COL)
    if args.label_column != LABEL_COL:
        ds = ds.rename_column(args.label_column, LABEL_COL)

    # The label column must be a ClassLabel feature for num_classes to be available.
    num_classes = ds["train"].features[LABEL_COL].num_classes

    # Tokenize with truncation only; padding is deferred to the data collator.
    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_pipeline)

    def preprocess_function(examples):
        return tokenizer(examples[TEXT_COL], truncation=True)

    tokenized_ds = ds.map(preprocess_function, batched=True)
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    # Load the pretrained backbone with a classification head sized for this dataset.
    model = AutoModelForSequenceClassification.from_pretrained(args.pretrained_pipeline,
                                                               num_labels=num_classes)
    training_args = TrainingArguments(
        output_dir=args.ckpt_path,
        learning_rate=2e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=5,
        weight_decay=0.01,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_ds["train"],
        eval_dataset=tokenized_ds["test"],
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    trainer.train()


if __name__ == '__main__':
    main(parse_args())
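A minimal usage sketch, assuming the gist is saved locally as train_pipeline.py (the filename is hypothetical) and the target dataset, e.g. imdb on the HF Hub, has train and test splits plus a ClassLabel label column:

    $ pip install datasets transformers
    $ python train_pipeline.py imdb --text_column text --label_column label --ckpt_path ./ckpt

After training, the Trainer writes checkpoints under --ckpt_path, which can be loaded back into an inference pipeline. The checkpoint directory name below is an assumption; the actual step number depends on dataset size and batch size.

    from transformers import pipeline

    # Hypothetical checkpoint directory produced by the Trainer run above.
    ckpt_dir = "./ckpt/checkpoint-500"

    classifier = pipeline("text-classification", model=ckpt_dir, tokenizer=ckpt_dir)
    print(classifier("This movie was surprisingly good."))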