import os

import sagemaker
from sagemaker.huggingface import HuggingFace
from sagemaker.local import LocalSession

# IAM role that SageMaker assumes to run the training job. Set SAGEMAKER_ROLE
# (a role ARN) when running outside a SageMaker-hosted notebook/Studio, where
# get_execution_role() cannot discover the role on its own. `or` short-circuits,
# so get_execution_role() is only called when the variable is unset or empty.
ROLE = os.environ.get('SAGEMAKER_ROLE') or sagemaker.get_execution_role()

# Hyperparameters handed to the estimator's entry-point script (train.py).
hyperparameters = dict(
    epochs=1,
    per_device_train_batch_size=32,
    do_train=True,
    model_name_or_path='distilbert-base-uncased',
    # /opt/ml/checkpoints is SageMaker's default local checkpoint directory
    # inside the training container.
    output_dir='/opt/ml/checkpoints',
)

# 'local' (or 'local_gpu') runs the training container on this machine through
# Docker (SageMaker local mode); an ML instance type such as 'ml.p3.2xlarge'
# launches a managed training job instead.
INSTANCE_TYPE = 'local'
LOCAL_MODE = INSTANCE_TYPE in ('local', 'local_gpu')

# Local mode requires a LocalSession; managed jobs use a regular Session.
sess = LocalSession() if LOCAL_MODE else sagemaker.Session()

# Managed spot training and S3 checkpoint syncing only make sense for managed
# jobs: a container running on the local machine has no spot capacity to
# request. For managed jobs SageMaker syncs the container's checkpoint
# directory (/opt/ml/checkpoints by default) to checkpoint_s3_uri; presumably
# train.py saves there via its output_dir hyperparameter -- confirm.
if LOCAL_MODE:
    spot_kwargs = {}
else:
    spot_kwargs = {
        'checkpoint_s3_uri': f's3://{sess.default_bucket()}/checkpoints',
        'use_spot_instances': True,
        # Upper bound (seconds) including time spent waiting for spot
        # capacity; SageMaker requires it to be >= max_run.
        'max_wait': 3600,
    }

huggingface_estimator = HuggingFace(
    entry_point='train.py',
    source_dir='.',
    instance_type=INSTANCE_TYPE,
    instance_count=1,
    max_run=1000,
    role=ROLE,
    transformers_version='4.4',
    pytorch_version='1.6',
    py_version='py36',
    hyperparameters=hyperparameters,
    sagemaker_session=sess,
    **spot_kwargs,
)

# Each key becomes a SageMaker input channel pointing at the matching IMDb
# split under this S3 prefix.
IMDB_S3_PREFIX = 's3://sagemaker-us-east-1-558105141721/samples/datasets/imdb'
training_channels = {
    split: f'{IMDB_S3_PREFIX}/{split}' for split in ('train', 'test')
}

huggingface_estimator.fit(training_channels)