Fine-tune a private Hugging Face model with a private dataset


I'm migrating a train.py script to SageMaker. It currently fine-tunes a classifier that is hosted privately on the Hugging Face Hub, and the dataset I'm using is also private. All the examples I've found copy the dataset to S3 and use a public foundation model.

I figured it should work if I just used my HF access token. In the train.py script run by the sagemaker.huggingface.HuggingFace estimator, I simply added my token (it's hard-coded for now):

import datasets

dataset_path = "my-repo/my-dataset"
dataset_revision = "---"
access_token = "---"
dataset = datasets.load_dataset(dataset_path, revision=dataset_revision, token=access_token)

If I run this code in a notebook cell it works, but it fails when it's run by the estimator:

Traceback (most recent call last):
  File "/opt/ml/code/train.py", line 65, in <module>
    dataset = datasets.load_dataset(args.dataset_path, revision=args.dataset_revision, token=access_token)
  File "/opt/conda/lib/python3.10/site-packages/datasets/load.py", line 1773, in load_dataset
    builder_instance = load_dataset_builder(
  File "/opt/conda/lib/python3.10/site-packages/datasets/load.py", line 1502, in load_dataset_builder
    dataset_module = dataset_module_factory(
  File "/opt/conda/lib/python3.10/site-packages/datasets/load.py", line 1215, in dataset_module_factory
    raise FileNotFoundError(
FileNotFoundError: Couldn't find a dataset script at /opt/ml/code/--- or any data file in the same directory. Couldn't find 'my-repo/my-dataset' on the Hugging Face Hub either: FileNotFoundError: Dataset 'my-repo/my-dataset' doesn't exist on the Hub at revision '---'. If the repo is private or gated, make sure to log in with `huggingface-cli login`.

Is it possible to train a model with sagemaker.huggingface.HuggingFace when the training script downloads its assets from the Hugging Face Hub, or do they have to be copied to S3 first?

Dave
Asked 8 months ago, 344 views
1 Answer

Accepted answer

I figured it out myself: I call the login function from the huggingface_hub package in the train.py script.

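from huggingface_hub import login

token = ...  # your Hugging Face access token
login(token=token)  # saves the token so later Hub downloads in the job are authenticated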
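A minimal sketch of how train.py could receive the token instead of hard-coding it, assuming it arrives as a command-line hyperparameter (SageMaker script mode forwards estimator hyperparameters to the entry point as `--name value` arguments). The names `hf_token`, `parse_args`, `load_private_dataset`, and the `"main"` default revision are hypothetical choices for illustration, not something from the thread.

```python
import argparse


def parse_args(argv=None):
    """Parse hyperparameters that SageMaker passes to train.py as --name value pairs."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", type=str, required=True)
    # Hypothetical default; the thread uses a specific (redacted) revision.
    parser.add_argument("--dataset_revision", type=str, default="main")
    # Hypothetical hyperparameter name; it must match the key set in the notebook.
    parser.add_argument("--hf_token", type=str, default=None)
    # Ignore any arguments not declared above instead of failing on them.
    args, _unknown = parser.parse_known_args(argv)
    return args


def load_private_dataset(args):
    # Imported lazily so parse_args works even where these packages are absent.
    import datasets
    from huggingface_hub import login

    if args.hf_token:
        # Saves the token locally; later Hub downloads in this job pick it up.
        login(token=args.hf_token)
    return datasets.load_dataset(args.dataset_path, revision=args.dataset_revision)
```

In train.py this would be called as `dataset = load_private_dataset(parse_args())`. Logging in once, rather than passing the token to each call, also covers the private model download done later by transformers.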

I pass the token to the script as a hyperparameter from the notebook, and it works fine.
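On the notebook side, passing the token as a hyperparameter might look like the sketch below. The hyperparameter keys, `entry_point`/`source_dir`, instance type, and framework versions are placeholder assumptions (pick a transformers/PyTorch/Python combination your SageMaker SDK supports); `build_hyperparameters` and `launch_training` are hypothetical helper names.

```python
def build_hyperparameters(dataset_path, dataset_revision, hf_token):
    """Hyperparameters SageMaker forwards to train.py as --key value arguments."""
    return {
        "dataset_path": dataset_path,
        "dataset_revision": dataset_revision,
        # Hypothetical key; must match the argument name train.py parses.
        "hf_token": hf_token,
    }


def launch_training(role, hf_token, dataset_path, dataset_revision="main"):
    # Imported lazily so build_hyperparameters can be used without the SDK.
    from sagemaker.huggingface import HuggingFace

    estimator = HuggingFace(
        entry_point="train.py",
        source_dir=".",
        role=role,
        instance_type="ml.g4dn.xlarge",  # placeholder instance type
        instance_count=1,
        # Placeholder framework versions.
        transformers_version="4.28",
        pytorch_version="2.0",
        py_version="py310",
        hyperparameters=build_hyperparameters(dataset_path, dataset_revision, hf_token),
    )
    # No S3 channels: train.py downloads the model and dataset from the Hub itself.
    estimator.fit()
    return estimator
```

This assumes the training job has outbound internet access (for example, network isolation is not enabled). Hyperparameters are stored with the training job and shown in its description, so a read-only token is a sensible choice here.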

Dave
Answered 8 months ago
Reviewed by an Expert 8 months ago
Reviewed by an AWS Expert 8 months ago
