# Flax SDXL inference
#
# GitHub gist entrpn/687774b554ed29956f43f05a204adb65, created September 21, 2023 17:45.
# Save it to your computer, or use it in GitHub Desktop.
#
# GitHub flagged this file as possibly containing hidden or bidirectional Unicode
# text that may be interpreted or compiled differently than it appears. To review
# it, open the file in an editor that reveals hidden Unicode characters.
import jax
import jax.numpy as jnp
import numpy as np

# Fail fast, before the heavier Flax/diffusers imports, if no TPU is attached.
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind
if "TPU" not in device_type:
    # Explicit raise instead of ``assert`` so the check is not stripped by ``python -O``.
    raise RuntimeError(
        "Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator"
    )

from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionXLPipeline

# Dtype passed to the pipeline; the loaded weights are also cast to bfloat16 below.
dtype = jnp.bfloat16
# Model repository id handed to ``FlaxStableDiffusionXLPipeline.from_pretrained``.
model_id = "pcuenq/stable-diffusion-xl-base-1.0-flax"
def to_bf16(t):
    """Cast every floating-point leaf of the pytree ``t`` to bfloat16.

    Leaves that are already bfloat16 are returned unchanged. Non-floating leaves
    (e.g. integer arrays) are left alone, so index-like data is never silently
    rounded to bfloat16.

    Args:
        t: A pytree (e.g. a nested dict of model parameters) whose leaves are arrays.

    Returns:
        A pytree with the same structure as ``t``.
    """
    def _cast(x):
        if x.dtype != jnp.bfloat16 and jnp.issubdtype(x.dtype, jnp.floating):
            return x.astype(jnp.bfloat16)
        return x

    # ``jax.tree_map`` is deprecated (and removed in recent JAX releases);
    # ``jax.tree_util.tree_map`` is the stable spelling.
    return jax.tree_util.tree_map(_cast, t)
# Load the Flax SDXL pipeline; its weights come back separately in ``params``.
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    model_id,
    use_safetensors=True,
    dtype=dtype,
)

# Cast the weights of each listed sub-model to bfloat16; any other entries of
# ``params`` are left exactly as loaded.
params.update(
    {
        component: to_bf16(params[component])
        for component in ("vae", "text_encoder", "text_encoder_2", "unet")
    }
)
# Images generated on each device per call.
imgs_per_device = 1


def _prepare_sharded(texts):
    """Turn ``texts`` into prompt ids via ``pipeline.prepare_inputs`` and shard them across local devices."""
    return shard(pipeline.prepare_inputs(texts))


# ``shard`` and ``replicate`` work on the *local* devices, so the batch is sized
# with ``jax.local_device_count()``. On a single host this equals
# ``jax.device_count()``. On multi-host setups the global count would hand each
# device the wrong number of rows.
prompt = "a colorful photo of a castle in the middle of a forest with trees and bushes, by Ismail Inceoglu, shadows, high contrast, dynamic shading, hdr, detailed vegetation, digital painting, digital drawing, detailed painting, a detailed digital painting, gothic art, featured on deviantart"
prompt = [prompt] * jax.local_device_count() * imgs_per_device
prompt_ids = _prepare_sharded(prompt)

neg_prompt = "fog, grainy, purple"
neg_prompt = [neg_prompt] * jax.local_device_count() * imgs_per_device
neg_prompt_ids = _prepare_sharded(neg_prompt)

# Copy the parameters to every local device for the ``jit=True`` generation path.
p_params = replicate(params)
def create_key(seed=0):
    """Build a JAX PRNG key from the integer ``seed`` (0 when omitted)."""
    key = jax.random.PRNGKey(seed)
    return key
# One PRNG key per local device. The keys are presumably mapped over together
# with the (locally) sharded prompt ids, so the counts must match; use the
# local device count, which equals ``jax.device_count()`` on a single host.
rng = create_key(0)
rng = jax.random.split(rng, jax.local_device_count())

# True: call the pipeline with jit=True on replicated params and sharded inputs.
# False: call it un-jitted on the first shard of every input only.
do_jit = True


def generate(prompt_ids, neg_prompt_ids, *, num_inference_steps=40, guidance_scale=9.0):
    """Run the SDXL pipeline on the given prompt ids and return the generated images.

    Args:
        prompt_ids: Sharded prompt ids (``pipeline.prepare_inputs`` + ``shard``).
        neg_prompt_ids: Sharded negative-prompt ids, same layout as ``prompt_ids``.
        num_inference_steps: Number of inference steps forwarded to the pipeline
            (default 40, unchanged from before).
        guidance_scale: Guidance scale forwarded to the pipeline (default 9.0,
            unchanged from before).

    Returns:
        The pipeline output's ``images``. When ``do_jit`` is False, only shard 0 of
        each input (and ``rng[0]``) is used.
    """
    return pipeline(
        prompt_ids if do_jit else prompt_ids[0],
        p_params if do_jit else params,
        rng if do_jit else rng[0],
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=neg_prompt_ids if do_jit else neg_prompt_ids[0],
        guidance_scale=guidance_scale,
        jit=do_jit,
    ).images
import time

# JAX dispatches device work asynchronously. If the pipeline returns device
# arrays, reading the clock before they are ready may measure only dispatch.
# ``jax.block_until_ready`` waits for the result; it is a no-op on host arrays.

# Warm-up call: the first jit call compiles, so this time covers compilation
# plus one generation.
start = time.perf_counter()
_ = jax.block_until_ready(generate(prompt_ids, neg_prompt_ids))
print(f"Compiled in {time.perf_counter() - start}")

n_runs = 5
start = time.perf_counter()
for _ in range(n_runs):
    images = jax.block_until_ready(generate(prompt_ids, neg_prompt_ids))
print(f"Inference in {(time.perf_counter() - start)/n_runs}")
print("images.shape:",images.shape)
# Record a JAX profiler trace of one generation into ``trace_path``.
trace_path = "/tmp/tensorboard"
with jax.profiler.trace(trace_path):
    # Wait for the result inside the context. Otherwise asynchronously
    # dispatched device work may finish after the trace has already stopped.
    images = jax.block_until_ready(generate(prompt_ids, neg_prompt_ids))
print("images.shape:",images.shape)
print("images.dtype:",images.dtype)
# Collapse every leading batch axis into one, keeping the last three axes
# (presumably height, width, channels). With do_jit this merges the device and
# per-device axes; without jit there is a single batch axis. The old
# ``shape[0] * shape[1]`` form broke in that case.
images = np.asarray(images)
images = images.reshape((-1,) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
for i, image in enumerate(images):
    image.save(f"castle_{i}.png")
# (GitHub gist page footer: sign up for free, or sign in if you already have an
# account, to join this conversation on GitHub and comment.)