How to generate multiple images using a Stable Diffusion XL model deployed with SageMaker JumpStart


Hi, I have deployed a Stable Diffusion XL 1.0 model on SageMaker using SageMaker JumpStart. I have also created a Lambda function that invokes the endpoint. When testing the Lambda function, the following payload from the sdxl-aws-documentation generates a single image for the given prompt:

{
  "cfg_scale": 7,
  "height": 1024,
  "width": 1024,
  "steps": 50,
  "seed": 42,
  "sampler": "K_DPMPP_2M",
  "text_prompts": [
    {
      "text": "A photograph of a cute puppy",
      "weight": 1
    }
  ]
} 

I have use cases that require generating multiple images, say 5 or 10, for a single prompt. In the Stability AI documentation, I see a parameter called 'samples' that allows this. But using the 'samples' parameter when querying the SDXL SageMaker endpoint gives me the following error:

An error occurred (ModelError) when calling the InvokeEndpoint operation: Received client error (400) from primary with message "generation_error
sagemaker_inference.errors.BaseInferenceToolkitError: (400, 'generation_error', 'batch size 5 is too large (max 1)')
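If the Lambda function should detect this limit programmatically, one option is to pull the requested and maximum batch sizes out of the error text. This is a minimal sketch that assumes the message keeps the wording quoted above ("batch size N is too large (max M)"); the function name parse_batch_size_error is hypothetical, and the container could word the message differently in other versions.

```python
import re

# Matches the SDXL container error quoted above, e.g.
# "batch size 5 is too large (max 1)". The wording is an assumption based on
# that single message and may differ between container versions.
_BATCH_ERROR = re.compile(r"batch size (\d+) is too large \(max (\d+)\)")


def parse_batch_size_error(message: str):
    """Return (requested, maximum) if the message is a batch-size error, else None."""
    match = _BATCH_ERROR.search(message)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


error_text = (
    "Received client error (400) from primary with message "
    "\"generation_error\nsagemaker_inference.errors.BaseInferenceToolkitError: "
    "(400, 'generation_error', 'batch size 5 is too large (max 1)')"
)
print(parse_batch_size_error(error_text))  # (5, 1)
```

In boto3 this error is raised as the sagemaker-runtime client's ModelError, a subclass of botocore.exceptions.ClientError, so you would pass str(err) from the caught exception to the parser.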

I would like to know how to generate multiple images for a single prompt using the SageMaker Endpoint. Is there any batch_size setting that I can increase to avoid the error I am facing?

2 Answers

Hello Sreenithi,

The error message indicates that you've exceeded the maximum batch size when invoking the Stable Diffusion XL (SDXL) model deployed on SageMaker. The 'samples' value is treated as the batch size, and the inference container behind this endpoint reports a maximum of 1 ('max 1'), so a request for 5 samples in a single call is rejected.
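Put differently, the number of requests you need is the number of images divided by the per-request limit, rounded up. The sketch below shows that arithmetic; plan_requests and max_per_request are hypothetical names, and with the limit of 1 reported by this endpoint it simply yields one request per image.

```python
def plan_requests(total_images: int, max_per_request: int) -> list[int]:
    """Split total_images into per-request 'samples' values of at most max_per_request."""
    if total_images < 0 or max_per_request < 1:
        raise ValueError("total_images must be >= 0 and max_per_request >= 1")
    full, remainder = divmod(total_images, max_per_request)
    return [max_per_request] * full + ([remainder] if remainder else [])


print(plan_requests(5, 1))  # [1, 1, 1, 1, 1]
print(plan_requests(5, 2))  # [2, 2, 1]
```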

To generate multiple images for a single prompt using the SageMaker endpoint for SDXL, you'll need to make multiple requests, each with a batch size of 1. Here's how you can modify your approach:

  1. Loop Over Requests: In your Lambda function or wherever you are making predictions, loop over the number of images you want to generate (e.g., 5 or 10 times). For each iteration, send a request with "samples": 1 and a different "seed"; with the same seed and settings, every request returns essentially the same image.

  2. Process Responses: After making each individual request, you'll receive a response containing one generated image. Store or aggregate these images as needed.

Here's a Python code snippet demonstrating this approach using the boto3 library for making requests to the SageMaker endpoint:

import boto3
import json

# Initialize the SageMaker runtime client
sm_client = boto3.client('sagemaker-runtime')

# Define the payload for a single request
payload = {
    "cfg_scale": 7,
    "height": 1024,
    "width": 1024,
    "steps": 50,
    "seed": 42,
    "sampler": "K_DPMPP_2M",
    "text_prompts": [
        {
            "text": "A photograph of a cute puppy",
            "weight": 1
        }
    ],
    "samples": 1  # The endpoint accepts at most 1 sample per request
}

# Define the number of images to generate
num_images_to_generate = 5

# Initialize a list to store the generated images (base64-encoded strings)
generated_images = []

# Loop over the number of images to generate
for i in range(num_images_to_generate):
    # Use a different seed per request; with the same seed and settings,
    # every request would return essentially the same image
    payload["seed"] = 42 + i

    # Make an individual prediction request
    response = sm_client.invoke_endpoint(
        EndpointName='your-sagemaker-endpoint-name',
        Body=json.dumps(payload),
        ContentType='application/json',
    )

    # Parse the response; the image is returned base64-encoded under 'artifacts'
    result = json.loads(response['Body'].read())
    generated_image = result['artifacts'][0]['base64']
    generated_images.append(generated_image)

# At this point, generated_images contains the base64-encoded images

By making individual requests with a batch size of 1, you can generate multiple images for a single prompt without exceeding the batch size limit set for the model.
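Because each request takes a while to run, sending them one after another can add up inside a Lambda timeout. A minimal sketch of a variant that overlaps the requests with a thread pool, gives each request its own seed, and reads the image the way the later answer in this thread does (result['artifacts'][0]['base64']). The names generate_images, make_invoker, base_seed, and max_workers are hypothetical; boto3 low-level clients can be shared across threads, but whether parallel calls actually finish sooner depends on how many instances back your endpoint (on a single instance they may simply queue).

```python
import copy
import json
from concurrent.futures import ThreadPoolExecutor
from typing import Callable


def generate_images(invoke: Callable[[dict], dict], base_payload: dict,
                    num_images: int, base_seed: int = 42,
                    max_workers: int = 4) -> list[str]:
    """Send one samples=1 request per image, each with its own seed.

    invoke sends one payload and returns the parsed JSON response.
    Returns the base64 image strings in seed order.
    """
    payloads = []
    for i in range(num_images):
        payload = copy.deepcopy(base_payload)  # leave the caller's dict untouched
        payload["samples"] = 1
        payload["seed"] = base_seed + i  # distinct seed -> distinct image
        payloads.append(payload)

    # Threads overlap the network waits; executor.map keeps the input order.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        results = list(executor.map(invoke, payloads))
    return [result["artifacts"][0]["base64"] for result in results]


def make_invoker(endpoint_name: str) -> Callable[[dict], dict]:
    """Build an invoke function backed by the SageMaker runtime client."""
    import boto3  # imported here so generate_images can be tested without AWS

    client = boto3.client("sagemaker-runtime")

    def invoke(payload: dict) -> dict:
        response = client.invoke_endpoint(
            EndpointName=endpoint_name,
            ContentType="application/json",
            Body=json.dumps(payload),
        )
        return json.loads(response["Body"].read())

    return invoke
```

Usage, with the payload dict from the snippet above: images = generate_images(make_invoker("your-sagemaker-endpoint-name"), payload, 5). Passing invoke in as a function keeps the looping logic testable with a fake endpoint.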

answered 5 months ago
reviewed by an AWS expert 5 months ago

I tried the method above, but the key 'image' is not present in the result. As of 7 Nov 2023, the following code works for me:

import base64
import json
from io import BytesIO

import boto3
from PIL import Image

client = boto3.client("sagemaker-runtime")

endpoint_name = 'sdxl-1-0-jumpstart-2023-11-07-12-45-12-328'
content_type = "application/json"

payload = {
    "cfg_scale": 7,
    "height": 1024,
    "width": 1024,
    "steps": 50,
    "seed": 42,
    "sampler": "K_DPMPP_2M",
    "text_prompts": [
        {
            "text": "A photograph of a cute puppy",
            "weight": 1
        }
    ],
    "samples": 1  # Set samples to 1 for a single image
}

# Endpoint invocation
response = client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType=content_type,
    Body=json.dumps(payload),
)

result = json.loads(response['Body'].read())

# The image is returned base64-encoded under 'artifacts'
base64_data = result['artifacts'][0]['base64']
image_data = base64.b64decode(base64_data)
image = Image.open(BytesIO(image_data))
image.show()
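image.show() needs a display, which a Lambda function does not have. A minimal sketch that instead decodes every entry in result['artifacts'] and writes it to disk; save_artifacts and the image_N.png naming are hypothetical, and the .png suffix assumes the endpoint returns PNG data.

```python
import base64
from pathlib import Path


def save_artifacts(result: dict, out_dir: str, prefix: str = "image") -> list[Path]:
    """Decode each base64 entry in result['artifacts'] and write it to out_dir.

    The .png suffix is an assumption about the endpoint's output format.
    """
    directory = Path(out_dir)
    directory.mkdir(parents=True, exist_ok=True)
    paths = []
    for index, artifact in enumerate(result.get("artifacts", [])):
        path = directory / f"{prefix}_{index}.png"
        path.write_bytes(base64.b64decode(artifact["base64"]))
        paths.append(path)
    return paths
```

In Lambda, point out_dir at a folder under /tmp (for example "/tmp/sdxl"), since /tmp is the function's writable storage; from there you could upload the files to S3.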

answered 4 months ago
