Fine-tune a private Hugging Face model with a private dataset


I'm migrating a train.py script to SageMaker. At the moment it fine-tunes a classifier that is hosted privately on the Hugging Face Hub, and the dataset I'm using is also privately hosted. All the examples I've found seem to copy the dataset to S3 and use a public foundation model.

I figured it should work if I just use my HF access token, i.e. in the train.py script used by the sagemaker.huggingface.HuggingFace estimator I simply added my token (it's just hard-coded for now):

import datasets

dataset_path = "my-repo/my-dataset"
dataset_revision = "---"
access_token = "---"  # hard-coded for now
dataset = datasets.load_dataset(dataset_path, revision=dataset_revision, token=access_token)

If I run this code in a notebook cell it works, but it fails when run by the estimator:

Traceback (most recent call last):
  File "/opt/ml/code/train.py", line 65, in <module>
    dataset = datasets.load_dataset(args.dataset_path, revision=args.dataset_revision, token=access_token)
  File "/opt/conda/lib/python3.10/site-packages/datasets/load.py", line 1773, in load_dataset
    builder_instance = load_dataset_builder(
  File "/opt/conda/lib/python3.10/site-packages/datasets/load.py", line 1502, in load_dataset_builder
    dataset_module = dataset_module_factory(
  File "/opt/conda/lib/python3.10/site-packages/datasets/load.py", line 1215, in dataset_module_factory
    raise FileNotFoundError(
FileNotFoundError: Couldn't find a dataset script at /opt/ml/code/--- or any data file in the same directory. Couldn't find my-repo/my-dataset' on the Hugging Face Hub either: FileNotFoundError: Dataset 'my-repo/my-dataset' doesn't exist on the Hub at revision '---'. If the repo is private or gated, make sure to log in with `huggingface-cli login`.

Is it possible to train a model using sagemaker.huggingface.HuggingFace that downloads assets from the Hugging Face Hub, or do they have to be copied to S3 first?

Dave
Asked 8 months ago, 346 views
1 answer
Accepted answer

I figured it out myself: I used the login function of the huggingface_hub package in the train.py script.

from huggingface_hub import login

token = ...
# Authenticate once; later Hub downloads in this process reuse the stored token.
login(token=token)

I added the token as a hyperparameter so I can pass it in from a notebook, and it works fine.
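For anyone wanting to see the hyperparameter approach spelled out, here is a rough sketch of the train.py side. It relies on SageMaker script mode forwarding each hyperparameter to the script as a `--name value` command-line argument. The hyperparameter name `hf_token`, the `HF_TOKEN` environment-variable fallback, and the helper names `resolve_hf_token` and `authenticate` are my own assumptions for illustration, not something taken from my actual script.

```python
import argparse
import os


def resolve_hf_token(argv=None, env=None):
    """Return the Hub token from --hf_token, falling back to an HF_TOKEN variable.

    Both names are hypothetical; use whatever key the estimator is given.
    """
    env = os.environ if env is None else env
    parser = argparse.ArgumentParser()
    # Script mode passes estimator hyperparameters as "--name value" arguments.
    parser.add_argument("--hf_token", default=env.get("HF_TOKEN"))
    # parse_known_args skips the other hyperparameters (epochs, dataset_path, ...).
    args, _ = parser.parse_known_args(argv)
    if not args.hf_token:
        raise ValueError("No Hugging Face token supplied")
    return args.hf_token


def authenticate(argv=None):
    """Log in to the Hub before any private dataset or model is requested."""
    # Imported lazily so the token parsing above can be checked without the package.
    from huggingface_hub import login

    login(token=resolve_hf_token(argv))
```

In train.py, `authenticate()` would be called once before `datasets.load_dataset(...)`. On the notebook side the value would go into the estimator's `hyperparameters` dict, for example `hyperparameters={"hf_token": my_token, ...}`, where `my_token` is a placeholder for however the token is loaded in the notebook.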

Dave
Answered 8 months ago
