Depending on the version of the PyTorch framework container that you're using (since they go back quite a long way now), there may be some differences in what formats the default model loader expects & supports.
In general, if you're not finding that the default loader works for you, you should also be able to provide a `model_fn()` implementation (usually in `code/inference.py` inside your `model.tar.gz`, although again the exact specifics here may vary if you're using older versions) to override the behaviour: loading the model however you want and returning it from the function.
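As a hedged sketch of what such an override can look like (assuming your archive contains a TorchScript file named `model.pt` — the filename and loading call should be adjusted to match your actual artifact):

```python
# code/inference.py -- a minimal sketch of a custom model_fn override.
# Assumes the model artifact is a TorchScript file called "model.pt";
# swap in torch.load / your own deserialization as appropriate.
import os

import torch


def model_fn(model_dir):
    """Called once by the serving stack at startup; model_dir is the
    folder where SageMaker extracted your model.tar.gz."""
    model_path = os.path.join(model_dir, "model.pt")
    model = torch.jit.load(model_path, map_location="cpu")
    model.eval()  # inference mode: disable dropout, batchnorm updates, etc.
    return model
```

Whatever this function returns is what gets passed to the (default or custom) `predict_fn` later.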
If you have issues getting your `model_fn` recognised:
I believe there was an issue affecting specific framework images `1.6.0`, `1.7.0` and `1.8.0` where a custom `model_fn` may be ignored. While the major and minor versions of these framework containers correspond to PyTorch versions, the final build number is for tracking AWS DLC changes... So if you're currently pinning to one of these specific versions (e.g. `1.6.0`) and finding this issue, you may find upgrading to a later build, or just specifying `1.6`, would help.
Also, when deploying via `PyTorchModel` you can specify a new `source_dir` (containing your `inference.py`), and a new `model.tar.gz` will be created to add in this code... But when deploying directly from `estimator.deploy()`, you may need to make sure your training job sets everything up by:
- Copying/creating the required code file(s) into `${SM_MODEL_DIR}/code` (to add the inference code to `model.tar.gz`)
- (In some versions) making sure your original training script name (e.g. `train.py`) is also a valid entrypoint under `${SM_MODEL_DIR}/code` and exposes your required functions. This is because in some cases the entrypoint seems to get carried over from training, instead of defaulting to `inference.py`.
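A sketch of the first step, run at the end of a training entrypoint (the function name and the exact file list are illustrative assumptions, not a SageMaker API):

```python
# At the end of train.py -- stage serving code into the model artifact.
# Hypothetical helper; adjust the file list to whatever your
# inference code actually needs.
import os
import shutil


def stage_inference_code(model_dir, src_dir="."):
    """Copy inference.py (and optionally requirements.txt) into
    {model_dir}/code so SageMaker packs them into model.tar.gz
    alongside the trained weights."""
    code_dir = os.path.join(model_dir, "code")
    os.makedirs(code_dir, exist_ok=True)
    for name in ("inference.py", "requirements.txt"):
        src = os.path.join(src_dir, name)
        if os.path.exists(src):  # requirements.txt is optional
            shutil.copy2(src, code_dir)
    return code_dir


# In a real training job you would call something like:
# stage_inference_code(os.environ["SM_MODEL_DIR"])
```

The `SM_MODEL_DIR` environment variable is set by SageMaker inside the training container, so the copy only needs to happen before the job exits.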
My usual (lazy) solution to this is to make my PyTorch training jobs recursively copy the whole `src` folder contents from training into `${SM_MODEL_DIR}/code` (as shown here), and to `from inference import *` in my training entrypoint script (as shown here). I think I've seen this approach work for at least v1.4-1.8, but it's been a while now since I used the older end of that range.
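The "copy everything" part of that approach could be sketched like this (the function name is illustrative; the `from inference import *` line would go at the top of your training entrypoint so the script also exposes the serving functions):

```python
# Sketch of recursively copying the whole source folder into the model
# artifact at the end of training. Falls back to the conventional
# /opt/ml/model path when SM_MODEL_DIR isn't set, so it can be tried
# locally too. Requires Python 3.8+ for dirs_exist_ok.
import os
import shutil


def copy_src_to_model_dir(src_dir, model_dir=None):
    """Recursively copy src_dir into {model_dir}/code, so everything
    available at training time is also packed for inference."""
    model_dir = model_dir or os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
    dest = os.path.join(model_dir, "code")
    shutil.copytree(src_dir, dest, dirs_exist_ok=True)
    return dest
```

This is heavier than copying individual files, but it means you never forget a helper module that `inference.py` imports.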
@Alex_T - thank you. One follow-up question: I am overriding the `model_fn` function and providing the path to the model like this -> `model_path = '{}/pytorch_model.bin'.format(model_dir)`, based on a sample provided in a book. I guess one difference is that I'm trying to deploy it as a serverless endpoint, because I tried it with real-time (without the serverless option) and it worked.
@Alex_T - thanks, I wanted to provide a few more details. I am trying to load a pretrained PyTorch model based on the documentation here: https://huggingface.co/transformers/v1.2.0/serialization.html. For SageMaker purposes, I have overridden the `model_fn` function and provided the path to the model like this -> `model_path = '{}/pytorch_model.bin'.format(model_dir)`. But I'm not sure if this will work, as in the logs I can see a warning: "at least one file with .pt or .pth extension must exist...