#!/usr/bin/env python
# /// script
# requires-python = ">=3.9"
# dependencies = [
#     "click>=8.1.0",
#     "diffusers>=0.28.0",
#     "transformers>=4.36.0",
#     "torch>=2.1.0",
#     "pillow>=10.0.0",
#     "accelerate>=0.24.0",
#     "safetensors>=0.4.0",
# ]
# ///
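#
# Example invocation: the block above is PEP 723 inline script metadata, so a
# PEP 723-aware runner such as uv can resolve the dependencies on the fly. The
# file name and the arguments below are illustrative assumptions, not part of
# the original gist:
#
#   uv run style_transfer.py photo.jpg "watercolor portrait, soft colors" \
#       --strength 0.6 --steps 40 --count 4
#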
import os
import sys
import time
from pathlib import Path

import click
import torch
from diffusers import DPMSolverMultistepScheduler, StableDiffusionImg2ImgPipeline
from PIL import Image


def get_device():
    """Determine the best available device for Mac (MPS), NVIDIA (CUDA), or CPU."""
    if torch.backends.mps.is_available():
        return "mps"  # Apple Silicon GPU
    elif torch.cuda.is_available():
        return "cuda"  # NVIDIA GPU
    return "cpu"  # Fallback to CPU


@click.command()
@click.argument("input_image", type=click.Path(exists=True))
@click.argument("prompt", type=str, required=False)
@click.option("--output", "-o", type=click.Path(), default=None,
              help="Output image path (default: output_<input name>.png)")
@click.option("--negative-prompt", type=str, default="",
              help="Elements to avoid in the generated image")
@click.option("--model", "-m", type=str, default="black-forest-labs/FLUX.1-dev",
              help="Model ID from Hugging Face")
@click.option("--strength", "-s", type=float, default=0.65,
              help="Transformation strength (0.0-1.0, higher = more stylized)")
@click.option("--steps", type=int, default=30,
              help="Number of inference steps (higher = more detail but slower)")
@click.option("--guidance-scale", "-g", type=float, default=7.5,
              help="Guidance scale (how closely to follow the prompt)")
@click.option("--seed", type=int, default=None,
              help="Random seed for reproducibility")
@click.option("--no-half", is_flag=True,
              help="Disable half-precision (use if encountering MPS errors)")
@click.option("--count", "-c", type=int, default=1,
              help="Number of images to generate")
def style_transfer(input_image, prompt, output, negative_prompt, model,
                   strength, steps, guidance_scale, seed, no_half, count):
    """
    Apply artistic style transfer to an image using Flux AI.

    INPUT_IMAGE: Path to the source image file

    PROMPT: Text prompt describing the desired style (optional; defaults to a professional digital drawing)
    """
    # Use the default professional-portrait prompt if none was provided
    if prompt is None:
        prompt = ("A professional digital drawing portrait of a person with clean, elegant line work "
                  "and subtle, refined shading. The image is stylized yet true to the subject's features, "
                  "featuring a minimalist background and soft, modern colors, evoking the quality of a "
                  "hand-drawn illustration by a top digital artist.")

    try:
        start_time = time.time()

        # Derive the output path from the input filename if not specified
        if output is None:
            input_stem = Path(input_image).stem
            output = f"output_{input_stem}.png"

        # Ensure the output directory exists
        output_dir = os.path.dirname(output) or "."
        os.makedirs(output_dir, exist_ok=True)

        # Build the base output path for multiple images
        base_path = Path(output)
        output_stem = base_path.stem
        output_suffix = base_path.suffix
        output_dir = base_path.parent

        # Display initial information
        click.echo(click.style("Flux Style Transfer", fg="green", bold=True))
        click.echo(f"Input: {click.style(input_image, fg='cyan')}")
        click.echo(f"Prompt: {click.style(prompt, fg='yellow')}")
        if negative_prompt:
            click.echo(f"Negative: {click.style(negative_prompt, fg='red')}")

        # Determine the best device
        device = get_device()
        click.echo(f"Using device: {click.style(device, fg='green', bold=True)}")

        # Set or generate a random seed
        if seed is None:
            seed = int(torch.randint(0, 2**32 - 1, (1,)).item())
        click.echo(f"Seed: {seed}")
        torch.manual_seed(seed)

        # Determine precision based on device and flags
        torch_dtype = torch.float32 if no_half or device == "cpu" else torch.float16
        precision = "float32" if torch_dtype == torch.float32 else "float16"
        click.echo(f"Using {precision} precision")

        # Load the model with progress indication
        with click.progressbar(length=100, label='Loading model', show_eta=False) as bar:
            bar.update(10)
            try:
                # Configure the scheduler for better quality.
                # Note: this loader expects a Stable Diffusion-style checkpoint; the default
                # FLUX model ID is unlikely to load here (see the note after the script).
                scheduler = DPMSolverMultistepScheduler.from_pretrained(
                    model, subfolder="scheduler", solver_order=2, prediction_type="epsilon"
                )
                # Load the img2img pipeline
                pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
                    model,
                    scheduler=scheduler,
                    torch_dtype=torch_dtype,
                    safety_checker=None,
                )
                bar.update(70)
                # Move the model to the appropriate device
                pipe = pipe.to(device)
                bar.update(20)
            except Exception as e:
                click.echo(click.style(f"\nError loading model: {str(e)}", fg="red", bold=True))
                if "MPS" in str(e) and not no_half:
                    click.echo("Tip: try again with the --no-half flag")
                return 1

        # Apply memory optimizations based on device
        if device in ["mps", "cuda"]:
            pipe.enable_attention_slicing()
        if device == "cuda" and hasattr(pipe, "enable_xformers_memory_efficient_attention"):
            pipe.enable_xformers_memory_efficient_attention()

        # Load the source image
        try:
            init_image = Image.open(input_image).convert("RGB")
            click.echo(f"Image dimensions: {init_image.width}x{init_image.height}")
        except Exception as e:
            click.echo(click.style(f"Error loading image: {str(e)}", fg="red", bold=True))
            return 1

        # Generate multiple images if requested
        successful = 0
        for i in range(count):
            current_seed = seed + i
            torch.manual_seed(current_seed)

            # Set the output path for this image
            if count > 1:
                this_output = str(output_dir / f"{output_stem}_{i+1}{output_suffix}")
            else:
                this_output = output

            # Generate the image with a progress bar.
            # Img2img runs roughly int(steps * strength) denoising steps, so size the bar to match.
            click.echo(f"\nGenerating image {i+1}/{count} (seed: {current_seed})")
            try:
                with click.progressbar(length=max(1, int(steps * strength)),
                                       label='Processing image', show_eta=False) as bar:
                    # Per-step callback to advance the progress bar
                    # (the older callback/callback_steps arguments are deprecated in recent diffusers)
                    def step_callback(pipeline, step, timestep, callback_kwargs):
                        bar.update(1)
                        return callback_kwargs

                    # Run the pipeline
                    result = pipe(
                        prompt=prompt,
                        negative_prompt=negative_prompt,
                        image=init_image,
                        strength=strength,
                        num_inference_steps=steps,
                        guidance_scale=guidance_scale,
                        callback_on_step_end=step_callback,
                    )

                # Save the image
                result.images[0].save(this_output)
                click.echo(f"Saved to: {click.style(this_output, fg='green')}")
                successful += 1
            except RuntimeError as e:
                if "MPS" in str(e) and not no_half:
                    click.echo(click.style(f"MPS error: {str(e)}", fg="red", bold=True))
                    click.echo("Tip: try again with the --no-half flag for better compatibility")
                else:
                    click.echo(click.style(f"Generation error: {str(e)}", fg="red", bold=True))
                continue
            except Exception as e:
                click.echo(click.style(f"Unexpected error: {str(e)}", fg="red", bold=True))
                continue

        # Display a summary
        elapsed = time.time() - start_time
        click.echo(f"\nGenerated {successful}/{count} images in {elapsed:.2f} seconds")
        return 0 if successful > 0 else 1
    except KeyboardInterrupt:
        click.echo(click.style("\nOperation cancelled by user", fg="yellow", bold=True))
        return 1
    except Exception as e:
        click.echo(click.style(f"\nUnexpected error: {str(e)}", fg="red", bold=True))
        return 1


if __name__ == "__main__":
    sys.exit(style_transfer())
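A note on the default model: black-forest-labs/FLUX.1-dev is a FLUX checkpoint whose repository is configured for the Flux pipeline classes, so StableDiffusionImg2ImgPipeline is only expected to work with Stable Diffusion-style checkpoints passed via --model. For genuine FLUX image-to-image, newer diffusers releases (roughly 0.30 and later) provide FluxImg2ImgPipeline. The following is a minimal sketch under that assumption; the input path, dtype, and settings are illustrative, not tested values from this gist:

# Minimal FLUX img2img sketch (assumes diffusers >= 0.30 with FluxImg2ImgPipeline)
import torch
from diffusers import FluxImg2ImgPipeline
from PIL import Image

pipe = FluxImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")  # pick the device available to you

init_image = Image.open("photo.jpg").convert("RGB")  # hypothetical input path
result = pipe(
    prompt="A professional digital drawing portrait, clean line work, soft colors",
    image=init_image,
    strength=0.65,
    num_inference_steps=30,
    guidance_scale=3.5,  # FLUX.1-dev uses distilled guidance; there is no negative_prompt
)
result.images[0].save("output_flux.png")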