from transformers import AutoTokenizer, BertForTokenClassification
import torch
import wandb
import os
import yaml
from datetime import datetime


def train(json_path: str):

    # Load environment variables (e.g. W&B / Hugging Face credentials) from
    # .env before any external service is touched.
    from dotenv import load_dotenv
    load_dotenv(".env")

    # Tokenizer from the upstream jjzha checkpoint; model weights from the
    # fine-tuned Robzy checkpoint that this run continues training.
    tokenizer = AutoTokenizer.from_pretrained("jjzha/jobbert_knowledge_extraction")
    model = BertForTokenClassification.from_pretrained("Robzy/jobbert_knowledge_extraction")

    # Training hyperparameters come from the local config file.
    with open("./config.yaml", "r") as file:
        config = yaml.safe_load(file)

    num_epochs = config['training']['epochs']
    batch_size = config['training']['batch_size']
    lr = config['training']['learning_rate']
    current_time = datetime.now()
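
    # The config reads above assume a config.yaml shaped like the following
    # (placeholder values, for illustration only):
    #
    # training:
    #   epochs: 3
    #   batch_size: 16
    #   learning_rate: 5.0e-5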

    run = wandb.init(
        project="in-demand",
        config={
            "learning_rate": lr,
            "architecture": "BERT",
            "epochs": num_epochs,
            "batch_size": batch_size,
            "notes": "Datetime: " + current_time.strftime("%m/%d/%Y, %H:%M:%S")
        }
    )
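
    # wandb.init assumes W&B credentials are already available in the
    # environment (e.g. WANDB_API_KEY, possibly supplied via the .env loaded
    # above).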

    from torch.utils.data import DataLoader
    import torch.nn as nn
    from typing import List
    from datasets import load_dataset

    # Each JSONL record provides pre-split "tokens" and a parallel
    # "tags_knowledge" list of label strings; conversion to ids happens in
    # collate_fn below.
    dataset = load_dataset("json", data_files=json_path)
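
    # Illustrative shape of one JSONL record (field names taken from the code;
    # the values here are invented):
    #   {"tokens": ["Experience", "in", "Python"], "tags_knowledge": ["O", "O", "B"]}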

    def pad(list_of_lists, pad_value=0):
        # Right-pad every sequence to the batch maximum and build matching
        # attention masks (1 = real token, 0 = padding).
        max_len = max(len(lst) for lst in list_of_lists)
        padded_lists = [lst + [pad_value] * (max_len - len(lst)) for lst in list_of_lists]
        attention_masks = [[1] * len(lst) + [0] * (max_len - len(lst)) for lst in list_of_lists]
        return torch.tensor(padded_lists), torch.tensor(attention_masks)

    # Label maps come from the model config; bind them here because collate_fn
    # needs label2id when the first batch is drawn.
    id2label = model.config.id2label
    label2id = model.config.label2id

    def collate_fn(batch: List[dict]):
        # Convert pre-split tokens to ids and tag strings to label ids, padding
        # both to the longest sequence in the batch.
        input_ids, attention_mask = pad([tokenizer.convert_tokens_to_ids(x['tokens']) for x in batch])
        tags_knowledge, _ = pad([[label2id[t] for t in b['tags_knowledge']] for b in batch])
        return {"input_ids": input_ids, "tags_knowledge": tags_knowledge, "attention_mask": attention_mask}
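
    # Illustrative collate_fn output for a batch of two records (shapes only;
    # max_len is the longest sequence in the batch):
    #   {"input_ids":      LongTensor of shape (2, max_len),
    #    "tags_knowledge": LongTensor of shape (2, max_len),
    #    "attention_mask": LongTensor of shape (2, max_len)}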

    train_dataloader = DataLoader(dataset['train'], batch_size=batch_size, collate_fn=collate_fn)

    from tqdm.auto import tqdm
    from torch.optim import AdamW
    from transformers import get_scheduler

    # Select the device and move the model there before training; batches are
    # moved to the same device inside the loop.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)
    model.train()

    # CrossEntropyLoss skips positions labelled IGNORE_INDEX; padding tokens
    # are mapped to it below so they do not contribute to the loss.
    IGNORE_INDEX = -100
    criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)

    optimizer = AdamW(model.parameters(), lr=lr)

    num_training_steps = num_epochs * len(train_dataloader)
    lr_scheduler = get_scheduler(
        name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
    )
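
    # With num_warmup_steps=0 the "linear" schedule simply decays the learning
    # rate from `lr` to 0 over num_training_steps, scaling it each step by
    # (num_training_steps - step) / num_training_steps.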

    import logging
    logging.basicConfig(level=logging.INFO)
    logging.info("Initiating training")

    progress_bar = tqdm(range(num_epochs), desc="Epochs")
    for epoch in range(num_epochs):

        logging.info(f"Epoch #{epoch}")

        for batch_count, batch in enumerate(train_dataloader, start=1):

            logging.info(f"Batch #{batch_count} / {len(train_dataloader)}")

            tokens = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            tags_knowledge = batch['tags_knowledge'].to(device)

            outputs = model(tokens, attention_mask=attention_mask)

            # Flatten logits to (batch * seq_len, num_labels) and mask padded
            # positions in the labels with IGNORE_INDEX so they drop out of the loss.
            pred = outputs.logits.reshape(-1, model.config.num_labels)
            label = torch.where(attention_mask == 0, torch.tensor(IGNORE_INDEX, device=device), tags_knowledge).reshape(-1)

            # Token-level accuracy over non-padding positions only.
            _, predicted_labels = torch.max(pred, dim=1)
            non_pad_elements = label != IGNORE_INDEX
            correct_predictions = (predicted_labels[non_pad_elements] == label[non_pad_elements]).sum().item()
            total_predictions = non_pad_elements.sum().item()
            accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0

            loss = criterion(pred, label)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            wandb.log({"epoch": epoch, "accuracy": accuracy, "loss": loss.item()})

        progress_bar.update(1)

    print("Training complete")

    # Push the fine-tuned weights back to the Hugging Face Hub.
    model.push_to_hub("Robzy/jobbert_knowledge_extraction")

    # Also log the raw state dict to W&B as a versioned artifact.
    artifact = wandb.Artifact(name="jobbert-knowledge-extraction", type="BERT")
    state_dict = model.state_dict()
    with artifact.new_file('model.pth', mode='wb') as f:
        torch.save(state_dict, f)

    wandb.log_artifact(artifact)
    run.finish()


def train_today():
    # Train on the tag file generated for today's date (DD-MM-YYYY).
    date = datetime.today().strftime('%d-%m-%Y')
    json_path = os.path.join(os.getcwd(), f'data/tags-{date}.jsonl')
    print(f"Training on {json_path}")
    train(json_path=json_path)
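
# Running the module directly trains on today's tag file, e.g.
# ./data/tags-18-06-2025.jsonl (the date shown is only an example of the
# DD-MM-YYYY pattern produced by train_today).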

if __name__ == "__main__":
    train_today()