# NOTE: This file may contain hidden or bidirectional Unicode text that can be interpreted or compiled differently from how it appears below.
# To review it, open the file in an editor that reveals hidden Unicode characters (see documentation on bidirectional Unicode characters).
import jax
import jax.numpy as jnp
import numpy as np

# Environment check: this script requires a TPU runtime (see the error
# message below). Query JAX for the visible devices and the kind of the first.
num_devices = jax.device_count()  # not used in this snippet; presumably used later for sharding — TODO confirm
device_type = jax.devices()[0].device_kind
# Explicit raise instead of `assert`: assertions are stripped under
# `python -O`, which would silently skip this environment check.
if "TPU" not in device_type:
    raise RuntimeError(
        "Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator"
    )

# Imported only after the TPU check, so a misconfigured runtime fails
# before these imports execute.
from flax.jax_utils import replicate
from diffusers import FlaxStableDiffusionXLPipeline
# NOTE: This file may contain hidden or bidirectional Unicode text that can be interpreted or compiled differently from how it appears below.
# To review it, open the file in an editor that reveals hidden Unicode characters (see documentation on bidirectional Unicode characters).
# Dependencies for the Gradio front end of the Flax SDXL pipeline:
# gradio for the UI, JAX/NumPy for arrays, and Flax helpers
# (`replicate`, `shard`) for distributing parameters and inputs across devices.
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from diffusers import FlaxStableDiffusionXLPipeline
from flax.training.common_utils import shard