By using AWS re:Post, you agree to the Terms of Use

ImportError: cannot import name 'dataclass_transform' from 'typing_extensions' (/home/ec2-user/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/typing_extensions.py)

0

Hi AWS, I am running DALL·E mini code to generate an image from a text prompt. Here is the code:

import os
import shutil
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from dalle_mini import DalleBart, DalleBartProcessor
from flax.training.common_utils import shard_prng_key
# NOTE(review): hf_hub_url / cached_download are unused below; cached_download
# has been removed from newer huggingface_hub releases, so drop them if this
# import fails.
from huggingface_hub import hf_hub_url, cached_download, hf_hub_download
from IPython.display import display
from PIL import Image
from tqdm.auto import trange
from transformers import CLIPProcessor, FlaxCLIPModel
# NOTE: dataclass_transform was added in typing_extensions 4.1. flax.struct also
# imports it, so an older typing_extensions breaks `import flax` regardless of
# this line. Upgrade typing_extensions rather than deleting imports.
from typing_extensions import dataclass_transform
from vqgan_jax.modeling_flax_vqgan import VQModel

print(jax.local_device_count())
print(jax.devices())

# Every model file lives under one root directory. Previously the files were
# copied into .../SageMaker/huggingface-sagemaker/content/dalle-mini but then
# loaded from two *different* directories, so from_pretrained could not find
# the downloaded files.
MODEL_ROOT = '/home/ec2-user/SageMaker/huggingface-sagemaker/content'
DALLE_MODEL_LOCATION = os.path.join(MODEL_ROOT, 'dalle-mini')
VQGAN_LOCAL_REPO = os.path.join(DALLE_MODEL_LOCATION, 'vqgan')

dalle_mini_files_list = [
    'config.json',
    'tokenizer.json',
    'tokenizer_config.json',
    'merges.txt',
    'vocab.json',
    'special_tokens_map.json',
    'enwiki-words-frequency.txt',
    'flax_model.msgpack',
]

vqgan_files_list = ['config.json', 'flax_model.msgpack']


def _copy_hub_files(repo_id, filenames, target_dir):
    """Download ``filenames`` from Hub repo ``repo_id`` and copy them into ``target_dir``.

    ``hf_hub_download`` returns a local file path (inside the Hugging Face
    cache by default). Copying the files gives ``from_pretrained`` a plain
    local directory to load from. The target directory is created first
    because ``shutil.copy`` does not create missing parent directories.
    """
    os.makedirs(target_dir, exist_ok=True)
    for filename in filenames:
        cached_path = hf_hub_download(repo_id, filename=filename)
        shutil.copy(cached_path, os.path.join(target_dir, filename))


_copy_hub_files('dalle-mini/dalle-mini', dalle_mini_files_list, DALLE_MODEL_LOCATION)
_copy_hub_files('dalle-mini/vqgan_imagenet_f16_16384', vqgan_files_list, VQGAN_LOCAL_REPO)

DALLE_COMMIT_ID = None
# NOTE(review): if the process ends with "Killed" (typically the OS
# out-of-memory killer), loading with a smaller dtype such as jnp.float16 may
# reduce memory use. Confirm memory use and image quality before switching.
model, params = DalleBart.from_pretrained(
    DALLE_MODEL_LOCATION, revision=DALLE_COMMIT_ID, dtype=jnp.float32, _do_init=False,
)

VQGAN_COMMIT_ID = None
vqgan, vqgan_params = VQModel.from_pretrained(
    VQGAN_LOCAL_REPO, revision=VQGAN_COMMIT_ID, _do_init=False
)

print(model.config)
print(vqgan.config)

# The tokenizer files were copied into the same directory as the DALL·E weights.
processor = DalleBartProcessor.from_pretrained(
    DALLE_MODEL_LOCATION,
    revision=DALLE_COMMIT_ID,
)

print(processor)

# Replicate the DALL·E and VQGAN parameter trees onto every local device, so
# that the pmapped p_generate / p_decode functions get one copy per device.
import random
from flax.jax_utils import replicate

params, vqgan_params = map(replicate, (params, vqgan_params))

@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
def p_generate(
    tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale
):
    """Run ``model.generate`` on every device in parallel.

    ``tokenized_prompt`` is unpacked as keyword arguments to ``generate``.
    Positional arguments 3-6 (top_k, top_p, temperature, condition_scale) are
    declared in ``static_broadcasted_argnums``. They are therefore passed
    unsharded as plain Python values to every device, rather than being
    split across devices.
    """
    generation_kwargs = dict(
        prng_key=key,
        params=params,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        condition_scale=condition_scale,
    )
    return model.generate(**tokenized_prompt, **generation_kwargs)

# Decode generated image-token indices back to pixels, one shard per device.
@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):
    """Return ``vqgan.decode_code`` applied to this device's ``indices``."""
    pixels = vqgan.decode_code(indices, params=params)
    return pixels


# Text prompts to render.
prompts = [
    "sunset over a lake in the mountains",
    "the Eiffel tower landing on the moon",
]

tokenized_prompts = processor(prompts)
# One copy of the tokenized batch per device, for the pmapped generator.
tokenized_prompt = replicate(tokenized_prompts)

# Seed a fresh PRNG key for this run; it is split again on every round below.
seed = random.randint(0, 2**32 - 1)
key = jax.random.PRNGKey(seed)

n_predictions = 4

# Sampling knobs (see https://huggingface.co/blog/how-to-generate).
gen_top_k = None
gen_top_p = None
temperature = None
cond_scale = 10.0

print(f"Prompts: {prompts}\n")

images = []
n_rounds = max(n_predictions // jax.device_count(), 1)
for _ in trange(n_rounds):
    # A new subkey per round so each round samples different images.
    key, subkey = jax.random.split(key)
    encoded_images = p_generate(
        tokenized_prompt,
        shard_prng_key(subkey),
        params,
        gen_top_k,
        gen_top_p,
        temperature,
        cond_scale,
    )
    # Strip the leading BOS token before handing the codes to the VQGAN.
    encoded_images = encoded_images.sequences[..., 1:]
    decoded_images = p_decode(encoded_images, vqgan_params)
    # Flatten devices/batch into a plain stack of 256x256 RGB images in [0, 1].
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
    for decoded_img in decoded_images:
        pixels_u8 = np.asarray(decoded_img * 255, dtype=np.uint8)
        img = Image.fromarray(pixels_u8)
        images.append(img)
        display(img)

And this is the error I get:


ImportError                               Traceback (most recent call last)
~/SageMaker/huggingface-sagemaker/code/inference.py in <module>
      5 #import DalleBart
      6 #from dalle_mini import DalleBart, DalleBartProcessor
----> 7 from vqgan_jax.modeling_flax_vqgan import VQModel
      8 from typing_extensions import dataclass_transform
      9 #from transformers import CLIPProcessor, FlaxCLIPModel

~/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/vqgan_jax/modeling_flax_vqgan.py in <module>
      8 import jax.numpy as jnp
      9 import numpy as np
---> 10 import flax.linen as nn
     11 from flax.core.frozen_dict import FrozenDict
     12

~/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/flax/__init__.py in <module>
     16 """Flax API."""
     17
---> 18 from . import core as core
     19 from . import linen as linen
     20 from . import optim as optim

~/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/flax/core/__init__.py in <module>
     26 )
     27
---> 28 from .scope import (
     29     Scope as Scope,
     30     Array as Array,

~/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/flax/core/scope.py in <module>
     26 from flax import config
     27 from flax import errors
---> 28 from flax import struct
     29 from flax import traceback_util
     30 from .frozen_dict import freeze

~/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/flax/struct.py in <module>
     23
     24 import jax
---> 25 from typing_extensions import dataclass_transform  # pytype: disable=not-supported-yet
     26
     27

ImportError: cannot import name 'dataclass_transform' from 'typing_extensions' (/home/ec2-user/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/typing_extensions.py)

Please help me ASAP as I need to fix it urgently.

1 Answer
0

I am getting a different error now:

/home/ec2-user/anaconda3/envs/JupyterSystemEnv/lib/python3.7/site-packages/flax/core/lift.py:112: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
  scopes, treedef = jax.tree_flatten(scope_tree)
/home/ec2-user/anaconda3/envs/JupyterSystemEnv/lib/python3.7/site-packages/flax/core/lift.py:729: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
  lengths = set(jax.tree_leaves(lengths))
/home/ec2-user/anaconda3/envs/JupyterSystemEnv/lib/python3.7/site-packages/flax/core/axes_scan.py:134: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
  in_avals, in_tree = jax.tree_flatten(input_avals)
/home/ec2-user/anaconda3/envs/JupyterSystemEnv/lib/python3.7/site-packages/flax/linen/transforms.py:249: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
  jax.tree_leaves(tree)))
/home/ec2-user/anaconda3/envs/JupyterSystemEnv/lib/python3.7/site-packages/flax/core/axes_scan.py:146: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
  broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
Killed

I have imported these functions in the inference.py script:

from jax.tree_util import (tree_flatten, tree_leaves, tree_unflatten)

But I still see these warnings, the code stops running, and the session gets killed.

Any idea why? Please let me know.

Thanks

profile picture
answered 21 days ago

You are not logged in. Log in to post an answer.

A good answer clearly answers the question and provides constructive feedback and encourages professional growth in the question asker.

Guidelines for Answering Questions