Creating JumpStart SageMaker endpoint with Terraform fails with "Model needs Flash Attention"


I followed this AWS blog post on deploying a GenAI model from SageMaker JumpStart with the CDK: https://aws.amazon.com/blogs/machine-learning/deploy-generative-ai-models-from-amazon-sagemaker-jumpstart-using-the-aws-cdk/

I want to use Terraform instead because it saves us from migrating our existing IaC tooling.

The strange thing is that the model and the endpoint configuration both deploy correctly, but the endpoint then fails to start. The CloudWatch logs below show the traceback from the endpoint starting up; the error is that the Mistral model requires Flash Attention v2.

CloudWatch Logs

2023-12-21T10:42:51.815-05:00	Traceback (most recent call last): File "/opt/conda/bin/text-generation-server", line 8, in <module> sys.exit(app()) File "/opt/conda/lib/python3.9/site-packages/typer/main.py", line 311, in __call__ return get_command(self)(*args, **kwargs) File "/opt/conda/lib/python3.9/site-packages/click/core.py", line 1157, in __call__ return self.main(*args, **kwargs) File "/opt/conda/lib/python3.9/site-packages/typer/core.py", line 778, in main return _main( File "/opt/conda/lib/python3.9/site-packages/typer/core.py", line 216, in _main rv = self.invoke(ctx) File "/opt/conda/lib/python3.9/site-packages/click/core.py", line 1688, in invoke return _process_result(sub_ctx.command.invoke(sub_ctx)) File "/opt/conda/lib/python3.9/site-packages/click/core.py", line 1434, in invoke return ctx.invoke(self.callback, **ctx.params) File "/opt/conda/lib/python3.9/site-packages/click/core.py", line 783, in invoke return __callback(*args, **kwargs) File "/opt/conda/lib/python3.9/site-packages/typer/main.py", line 683, in wrapper return callback(**use_params) # type: ignore File "/opt/conda/lib/python3.9/site-packages/text_generation_server/cli.py", line 83, in serve server.serve( File "/opt/conda/lib/python3.9/site-packages/text_generation_server/server.py", line 207, in serve asyncio.run( File "/opt/conda/lib/python3.9/asyncio/runners.py", line 44, in run return loop.run_until_complete(main) File "/opt/conda/lib/python3.9/asyncio/base_events.py", line 634, in run_until_complete self.run_forever() File "/opt/conda/lib/python3.9/asyncio/base_events.py", line 601, in run_forever self._run_once() File "/opt/conda/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once handle._run() File "/opt/conda/lib/python3.9/asyncio/events.py", line 80, in _run self._context.run(self._callback, *self._args)

2023-12-21T10:42:51.815-05:00	> File "/opt/conda/lib/python3.9/site-packages/text_generation_server/server.py", line 159, in serve_inner model = get_model( File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/__init__.py", line 259, in get_model raise NotImplementedError("Mistral model requires flash attention v2")

2023-12-21T10:42:51.815-05:00	NotImplementedError: Mistral model requires flash attention v2

2023-12-21T10:42:51.815-05:00  2023-12-21T15:42:20.669006Z ERROR shard-manager: text_generation_launcher: Shard complete standard error output:

2023-12-21T10:42:51.815-05:00  /opt/conda/lib/python3.9/site-packages/bitsandbytes/cextension.py:34: UserWarning: The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable. warn("The installed version of bitsandbytes was compiled without GPU support. "

2023-12-21T10:42:51.815-05:00  Traceback (most recent call last): File "/opt/conda/bin/text-generation-server", line 8, in <module> sys.exit(app()) File "/opt/conda/lib/python3.9/site-packages/text_generation_server/cli.py", line 83, in serve server.serve( File "/opt/conda/lib/python3.9/site-packages/text_generation_server/server.py", line 207, in serve asyncio.run( File "/opt/conda/lib/python3.9/asyncio/runners.py", line 44, in run return loop.run_until_complete(main) File "/opt/conda/lib/python3.9/asyncio/base_events.py", line 647, in run_until_complete return future.result() File "/opt/conda/lib/python3.9/site-packages/text_generation_server/server.py", line 159, in serve_inner model = get_model( File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/__init__.py", line 259, in get_model raise NotImplementedError("Mistral model requires flash attention v2")

2023-12-21T10:42:51.815-05:00  NotImplementedError: Mistral model requires flash attention v2 rank=0

2023-12-21T10:42:51.815-05:00  2023-12-21T15:42:20.768126Z ERROR text_generation_launcher: Shard 0 failed to start

2023-12-21T10:42:51.815-05:00  2023-12-21T15:42:20.768153Z INFO text_generation_launcher: Shutting down shards

2023-12-21T10:42:52.066-05:00  Error: ShardCannotStart

2023-12-21T10:42:52.066-05:00  2023-12-21T15:42:51.867970Z INFO text_generation_launcher: Args { model_id: "/opt/ml/model", revision: None, validation_workers: 2, sharded: None, num_shard: Some(1), quantize: None, dtype: None, trust_remote_code: false, max_concurrent_requests: 128, max_best_of: 2, max_stop_sequences: 4, max_top_n_tokens: 5, max_input_length: 8191, max_total_tokens: 8192, waiting_served_ratio: 1.2, max_batch_prefill_tokens: 8191, max_batch_total_tokens: None, max_waiting_tokens: 20, hostname: "container-0.local", port: 8080, shard_uds_path: "/tmp/text-generation-server", master_addr: "localhost", master_port: 29500, huggingface_hub_cache: Some("/tmp"), weights_cache_override: None, disable_custom_kernels: false, cuda_memory_fraction: 1.0, rope_scaling: None, rope_factor: None, json_output: false, otlp_endpoint: None, cors_allow_origin: [], watermark_gamma: None, watermark_delta: None, ngrok: false, ngrok_authtoken: None, ngrok_edge: None, env: false }

2023-12-21T10:42:54.824-05:00  2023-12-21T15:42:51.868084Z INFO download: text_generation_launcher: Starting download process.

2023-12-21T10:42:55.326-05:00  2023-12-21T15:42:54.689160Z INFO text_generation_launcher: Files are already present on the host. Skipping download.

2023-12-21T10:42:55.326-05:00  2023-12-21T15:42:55.170901Z INFO download: text_generation_launcher: Successfully downloaded weights.

2023-12-21T10:42:58.084-05:00  2023-12-21T15:42:55.171098Z INFO shard-manager: text_generation_launcher: Starting shard rank=0

2023-12-21T10:42:58.084-05:00  2023-12-21T15:42:58.075307Z WARN text_generation_launcher: We're not using custom kernels.

2023-12-21T10:42:58.084-05:00  2023-12-21T15:42:58.080243Z WARN text_generation_launcher: Could not import Flash Attention enabled models: CUDA is not available

2023-12-21T10:42:58.335-05:00  2023-12-21T15:42:58.082521Z WARN text_generation_launcher: Could not import Mistral model: CUDA is not available

2023-12-21T10:42:58.335-05:00  2023-12-21T15:42:58.146132Z ERROR text_generation_launcher: Error when initializing model

Terraform

# README
# The environment values and model data source below were found by deploying a
# JumpStart endpoint from SageMaker Studio and then copying the values from that model
# using `aws sagemaker describe-model --model-name your_model_name`.
# Without these, the endpoint will fail to deploy; check the CloudWatch logs for the reason.
# The usual way to deploy an endpoint is with boto3/sagemaker or the Python CDK,
# and there is little documentation online on where to find the env and model data source info.
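# For reference, a sketch of how to pull just those fields (illustrative; assumes the AWS CLI
# is configured for the account/region where the Studio-deployed JumpStart model lives):
#   aws sagemaker describe-model --model-name your_model_name \
#     --query 'PrimaryContainer.[Image, Environment, ModelDataSource]'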
resource "aws_sagemaker_model" "mistral_sagemaker_model" {
  name               = "llm-mistral-7b-instruct-model"
  execution_role_arn = aws_iam_role.sagemaker_trust_role.arn

  primary_container {
    image = var.sagemaker_mistral_public_image
    mode  = "SingleModel"
    environment = {
      SAGEMAKER_MODEL_SERVER_TIMEOUT = "3600"
      ENDPOINT_SERVER_TIMEOUT        = "3600"
      HF_MODEL_ID                    = "/opt/ml/model"
      MAX_BATCH_PREFILL_TOKENS       = "8191"
      MAX_INPUT_LENGTH               = "8191"
      MAX_TOTAL_TOKENS               = "8192"
      MODEL_CACHE_ROOT               = "/opt/ml/model"
      SAGEMAKER_SUBMIT_DIRECTORY     = "/opt/ml/model/code/"
      SAGEMAKER_ENV                  = "1"
      SAGEMAKER_MODEL_SERVER_WORKERS = "1"
      SAGEMAKER_PROGRAM              = "inference.py"
      SM_NUM_GPUS                    = "1"
    }

    model_data_source {
      s3_data_source {
        s3_uri           = "s3://jumpstart-cache-prod-us-east-1/huggingface-llm/huggingface-llm-mistral-7b-instruct/artifacts/inference-prepack/v1.0.0/"
        s3_data_type     = "S3Prefix"
        compression_type = "None"
      }
    }
  }

  tags = {
    Application = var.app_name
  }
}

resource "aws_sagemaker_endpoint_configuration" "config" {
  name = "chat-bot-sagemaker-config"

  production_variants {
    variant_name           = "mistral-7b-variant"
    model_name             = aws_sagemaker_model.mistral_sagemaker_model.name
    initial_instance_count = 1
    instance_type          = var.sagemaker_inference_compute_size
  }

  tags = {
    Application = var.app_name
  }
}

resource "aws_sagemaker_endpoint" "endpoint" {
  name                 = "sagemaker-mistral-inference-ep"
  endpoint_config_name = aws_sagemaker_endpoint_configuration.config.name

  tags = {
    Application = var.app_name
  }
}

resource "aws_iam_role" "sagemaker_trust_role" {
  name = "sagemaker_role"

  assume_role_policy = <<-EOF
  {
    "Version": "2012-10-17",
    "Statement": [
      {
        "Action": "sts:AssumeRole",
        "Principal": {
          "Service": "sagemaker.amazonaws.com"
        },
        "Effect": "Allow",
        "Sid": ""
      }
    ]
  }
  EOF

  tags = {
    Application = var.app_name
  }
}

resource "aws_iam_role_policy_attachment" "sagemaker_full_access" {
  role       = aws_iam_role.sagemaker_trust_role.name
  policy_arn = "arn:aws:iam::aws:policy/AmazonSageMakerFullAccess"
}

resource "aws_iam_role_policy_attachment" "s3_read_write_access" {
  role       = aws_iam_role.sagemaker_trust_role.name
  policy_arn = "arn:aws:iam::aws:policy/AmazonS3FullAccess"
}

# This image URI was found by deploying a JumpStart endpoint from SageMaker Studio
# and then copying the value from that model. Without it, the Python boto3/sagemaker SDK
# is the next best option for looking it up.
variable "sagemaker_mistral_public_image" {
  description = "Jumpstart Model mistral public ECR image"
  default     = "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.0.1-tgi1.1.0-gpu-py39-cu118-ubuntu20.04"
  type        = string
}

variable "sagemaker_inference_compute_size" {
  description = "EC2 instance size"
  default     = "ml.m5.2xlarge"
  type        = string
}


1 Answer
Accepted Answer

I see you have sagemaker_inference_compute_size set to "ml.m5.2xlarge". Have you tested with ml.g5.2xlarge?

AWS
Marc
answered 4 months ago
EXPERT reviewed 4 months ago
  • That was a whole day of work trying to figure that out and cross-reference with the CDK and boto3 sagemaker examples. SMH. Thanks!

  • Flash Attention V2 is supported mainly on GPUs, specifically NVIDIA Ada, Ampere, or Hopper GPUs. Trainium also supports flash attention, but your mileage may vary. You are out of luck with CPU instance types; see the sketch below for the corresponding Terraform change.
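
For anyone hitting the same error, a minimal sketch of the fix, assuming ml.g5.2xlarge (the type suggested above) is available in your region and your account has quota for it; any other NVIDIA GPU instance type supported by this model should work the same way:

variable "sagemaker_inference_compute_size" {
  description = "SageMaker inference instance type (must be a GPU instance for Mistral with Flash Attention v2)"
  default     = "ml.g5.2xlarge" # ml.m5.* is CPU-only, so CUDA and Flash Attention v2 are unavailable
  type        = string
}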
