What I did
I am trying to create a SageMaker Serverless Inference endpoint using the following CloudFormation template:
```yaml
AWSTemplateFormatVersion: "2010-09-09"
Parameters:
  ModelDataUrl:
    Type: "String"
    Default: "s3://path/to/model/model.tar.gz"
Resources:
  SageMakerExecutionRole:
    Type: "AWS::IAM::Role"
    Properties:
      RoleName: "SageMakerExecutionRole"
      AssumeRolePolicyDocument:
        Version: "2012-10-17"
        Statement:
          - Effect: "Allow"
            Principal:
              Service:
                - "sagemaker.amazonaws.com"
            Action:
              - "sts:AssumeRole"
      Policies:
        - PolicyName: "S3GetObjectPolicy"
          PolicyDocument:
            Version: "2012-10-17"
            Statement:
              - Effect: "Allow"
                Action:
                  - "s3:GetObject"
                Resource:
                  - "arn:aws:s3:::path/to/model/*"
        - PolicyName: "ECRBatchGetItemPolicy"
          PolicyDocument:
            Version: "2012-10-17"
            Statement:
              - Effect: "Allow"
                Action:
                  - "ecr:BatchCheckLayerAvailability"
                  - "ecr:BatchGetImage"
                  - "ecr:GetDownloadUrlForLayer"
                Resource:
                  - "arn:aws:ecr:*:763104351884:repository/*"
        - PolicyName: "ECRAuthorization"
          PolicyDocument:
            Version: "2012-10-17"
            Statement:
              - Effect: "Allow"
                Action:
                  - "ecr:GetAuthorizationToken"
                Resource:
                  - "*"
        - PolicyName: "Logging"
          PolicyDocument:
            Version: "2012-10-17"
            Statement:
              - Effect: "Allow"
                Action:
                  - "logs:CreateLogGroup"
                  - "logs:CreateLogStream"
                  - "logs:GetLogEvents"
                  - "logs:PutLogEvents"
                Resource:
                  - "*"
  SageMakerModel:
    Type: "AWS::SageMaker::Model"
    Properties:
      ExecutionRoleArn: !GetAtt "SageMakerExecutionRole.Arn"
      Containers:
        - Image: "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-inference:2.2.0-cpu-py310-ubuntu20.04-sagemaker"
          ModelDataUrl: !Ref ModelDataUrl
  SageMakerEndpointConfig:
    Type: "AWS::SageMaker::EndpointConfig"
    Properties:
      ProductionVariants:
        - ModelName: !GetAtt "SageMakerModel.ModelName"
          VariantName: "MediaClassificationEndpoint"
          ServerlessConfig:
            MaxConcurrency: 5
            MemorySizeInMB: 3072
  SageMakerEndpoint:
    Type: "AWS::SageMaker::Endpoint"
    Properties:
      EndpointConfigName: !GetAtt "SageMakerEndpointConfig.EndpointConfigName"
      EndpointName: "SageMakerEndpoint"
```
I have uploaded model.tar.gz to S3, which has the following structure:

```
.
├── model.pth
└── code/
    ├── inference.py
    └── requirements.txt
```
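For reference, the archive keeps model.pth and code/ at the root of the tarball (no wrapping top-level directory), which is the layout the SageMaker PyTorch container expects. A stdlib-only sketch that reproduces the packaging step with placeholder file contents:

```python
import tarfile
import tempfile
from pathlib import Path

# Build a model.tar.gz whose entries sit at the archive root.
# File contents here are placeholders; only the layout matters.
workdir = Path(tempfile.mkdtemp())
(workdir / "code").mkdir()
(workdir / "model.pth").write_bytes(b"dummy weights")
(workdir / "code" / "inference.py").write_text("# handler code")
(workdir / "code" / "requirements.txt").write_text("timm\n")

archive = workdir / "model.tar.gz"
with tarfile.open(archive, "w:gz") as tar:
    # arcname keeps each entry relative to the archive root
    tar.add(workdir / "model.pth", arcname="model.pth")
    tar.add(workdir / "code", arcname="code")

with tarfile.open(archive) as tar:
    print(sorted(tar.getnames()))
# → ['code', 'code/inference.py', 'code/requirements.txt', 'model.pth']
```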
And here is inference.py:
```python
import base64
import json
from logging import getLogger
from pathlib import Path

import timm
import torch
from PIL import Image

WEIGHT_KEY = 'classifier.weight'

logger = getLogger(__name__)


def model_fn(
    model_dir: str,
) -> torch.nn.Module:
    model_path = Path(model_dir) / 'model.pth'
    assert model_path.exists(), f'{model_path} does not exist'
    state_dict = torch.load(model_path)
    num_classes = state_dict[WEIGHT_KEY].shape[0]
    model = timm.create_model(
        'tf_efficientnetv2_s.in21k',
        num_classes=num_classes,
    )
    logger.info('Model created')
    model.load_state_dict(state_dict)
    logger.info(f'Model loaded from {model_path}')
    model = model.eval()
    return model

...
```
Problem
I tested the endpoint using SageMaker Studio, but I encountered the following error:
```
Received server error (500) from model with message "{
  "code": 500,
  "type": "InternalServerException",
  "message": "Worker died."
}"
```
Here are the CloudWatch logs:

```
2024-05-13T07:01:21.771Z 2024-05-13T07:01:21,768 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - Traceback (most recent call last):
2024-05-13T07:01:21.771Z 2024-05-13T07:01:21,768 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - File "/opt/conda/lib/python3.10/site-packages/ts/model_service_worker.py", line 253, in <module>
2024-05-13T07:01:21.771Z 2024-05-13T07:01:21,768 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - worker.run_server()
2024-05-13T07:01:21.771Z 2024-05-13T07:01:21,768 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - File "/opt/conda/lib/python3.10/site-packages/ts/model_service_worker.py", line 221, in run_server
2024-05-13T07:01:21.771Z 2024-05-13T07:01:21,769 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - self.handle_connection(cl_socket)
2024-05-13T07:01:21.771Z 2024-05-13T07:01:21,769 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - File "/opt/conda/lib/python3.10/site-packages/ts/model_service_worker.py", line 184, in handle_connection
2024-05-13T07:01:21.771Z 2024-05-13T07:01:21,769 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - service, result, code = self.load_model(msg)
2024-05-13T07:01:21.771Z 2024-05-13T07:01:21,769 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - File "/opt/conda/lib/python3.10/site-packages/ts/model_service_worker.py", line 131, in load_model
2024-05-13T07:01:21.771Z 2024-05-13T07:01:21,769 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - service = model_loader.load(
2024-05-13T07:01:21.771Z 2024-05-13T07:01:21,769 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - File "/opt/conda/lib/python3.10/site-packages/ts/model_loader.py", line 135, in load
2024-05-13T07:01:21.771Z 2024-05-13T07:01:21,769 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - initialize_fn(service.context)
2024-05-13T07:01:21.771Z 2024-05-13T07:01:21,769 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - File "/opt/conda/lib/python3.10/site-packages/sagemaker_pytorch_serving_container/handler_service.py", line 51, in initialize
2024-05-13T07:01:21.771Z 2024-05-13T07:01:21,769 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - super().initialize(context)
2024-05-13T07:01:21.771Z 2024-05-13T07:01:21,769 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - File "/opt/conda/lib/python3.10/site-packages/sagemaker_inference/default_handler_service.py", line 66, in initialize
2024-05-13T07:01:21.771Z 2024-05-13T07:01:21,769 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - self._service.validate_and_initialize(model_dir=model_dir, context=context)
2024-05-13T07:01:21.771Z 2024-05-13T07:01:21,769 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - File "/opt/conda/lib/python3.10/site-packages/sagemaker_inference/transformer.py", line 184, in validate_and_initialize
2024-05-13T07:01:21.771Z 2024-05-13T07:01:21,770 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - self._model = self._run_handler_function(self._model_fn, *(model_dir,))
2024-05-13T07:01:21.771Z 2024-05-13T07:01:21,770 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - File "/opt/conda/lib/python3.10/site-packages/sagemaker_inference/transformer.py", line 272, in _run_handler_function
2024-05-13T07:01:21.771Z 2024-05-13T07:01:21,770 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - result = func(*argv)
2024-05-13T07:01:21.771Z 2024-05-13T07:01:21,771 [INFO ] epollEventLoopGroup-5-1 org.pytorch.serve.wlm.WorkerThread - 9000 Worker disconnected. WORKER_STARTED
2024-05-13T07:01:21.773Z 2024-05-13T07:01:21,772 [WARN ] W-9000-model_1.0 org.pytorch.serve.wlm.BatchAggregator - Load model failed: model, error: Worker died.
```
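For context, the request I send is a JSON body containing a base64-encoded image. The exact schema my input_fn expects is elided above, so the field name `image` below is illustrative; only the base64/json round-trip mirrors the imports in inference.py:

```python
import base64
import json

# Illustrative payload construction: "image" is a placeholder field name,
# and the bytes stand in for a real image file read from disk.
image_bytes = b"\x89PNG\r\n\x1a\n"
payload = json.dumps({"image": base64.b64encode(image_bytes).decode("ascii")})

# Sanity check: the handler side can recover the original bytes.
decoded = base64.b64decode(json.loads(payload)["image"])
assert decoded == image_bytes
```

The invocation itself goes through the SageMaker runtime (boto3 `sagemaker-runtime` client, `invoke_endpoint` with `ContentType='application/json'`).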
Could you tell me the cause of the error and how to solve the problem?
Thank you in advance.