| from pathlib import Path |
| import os |
| import json |
|
|
def get_config(path=None):
    """Return the training configuration as a dictionary.

    When ``path`` is truthy and points to an existing file, that file is
    parsed as JSON, checked for the required keys, and returned as-is.
    In every other case (no path given, or the path does not exist) the
    built-in default configuration is returned.

    Args:
        path: Optional location of a JSON configuration file.

    Returns:
        dict: The loaded configuration, or the built-in defaults.

    Raises:
        ValueError: If the JSON file lacks one or more required keys. The
            message lists the missing keys separated by ", ".
    """
    if path and Path(path).exists():
        with open(path, "r", encoding="utf-8") as f:
            config = json.load(f)

        required = (
            "batch_size",
            "num_epochs",
            "lr",
            "seq_len",
            "d_model",
            "d_ff",
            "N",
            "h",
            "model_folder",
            "model_basename",
            "preload",
            "tokenizer_file",
            "experiment_name",
        )
        missing = [key for key in required if key not in config]
        if missing:
            # Join with a separator so the individual field names stay
            # readable in the error message.
            raise ValueError(
                f"Field(s) missing in config file : {', '.join(missing)}"
            )
        return config

    # Fallback: built-in defaults, used when no usable config file is given.
    return {
        "batch_size": 4,
        "num_epochs": 30,
        # NOTE(review): 3**-4 evaluates to 1/81 (~0.0123); confirm that
        # 10**-4 was not the intended learning rate.
        "lr": 3**-4,
        "seq_len": 360,
        "d_model": 512,
        "N": 6,
        "h": 8,
        "d_ff": 2048,
        "lang_src": "en",
        "lang_tgt": "it",
        "model_folder": "weights",
        "datasource": "opus_books",
        "model_basename": "tmodel_",
        "preload": 25,
        "tokenizer_file": "tokenizer_{0}.json",
        "experiment_name": "runs/tmodel",
    }
| |
def get_weights_file_path(config, epoch: str):
    """Build the path of the weights file for a given epoch.

    The file name is ``config['model_basename']`` followed by ``epoch`` and
    the ``.pt`` suffix; it is placed in ``config['model_folder']`` under the
    current working directory.

    Args:
        config: Mapping providing the ``model_folder`` and
            ``model_basename`` entries.
        epoch: Epoch identifier interpolated into the file name.

    Returns:
        str: The resulting path rendered as a string.
    """
    folder = config['model_folder']
    filename = f"{config['model_basename']}{epoch}.pt"
    weights_path = Path.cwd().joinpath(folder, filename)
    return str(weights_path)
| |
|
|