Skip to content

Instantly share code, notes, and snippets.

@zachmayer
Created March 11, 2025 20:27
Show Gist options
  • Save zachmayer/1b0afdb46820bce19651f8b10cb92af8 to your computer and use it in GitHub Desktop.
Save zachmayer/1b0afdb46820bce19651f8b10cb92af8 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
# /// script
# requires-python = ">=3.9"
# dependencies = [
# "click>=8.1.0",
# "diffusers>=0.28.0",
# "transformers>=4.36.0",
# "torch>=2.1.0",
# "pillow>=10.0.0",
# "accelerate>=0.24.0",
# "safetensors>=0.4.0",
# ]
# ///
import click
import torch
from diffusers import StableDiffusionImg2ImgPipeline, DPMSolverMultistepScheduler
from PIL import Image
import os
import time
import sys
from pathlib import Path
def get_device():
    """Pick the torch device string to run the pipeline on.

    Preference order is Apple Silicon's Metal backend ("mps"), then NVIDIA
    CUDA ("cuda"); when neither accelerator reports itself available the
    function falls back to "cpu". Availability checks run lazily in that
    order, so CUDA is only queried when MPS is unavailable.
    """
    accelerators = (
        ("mps", torch.backends.mps.is_available),  # Apple Silicon GPU
        ("cuda", torch.cuda.is_available),  # NVIDIA GPU
    )
    for device_name, is_usable in accelerators:
        if is_usable():
            return device_name
    # No accelerator found.
    return "cpu"
def _img2img_step_count(steps, strength):
    """Return the (approximate) number of denoising steps an img2img run performs.

    Img2img starts from a partially-noised input and only denoises the final
    ``strength`` fraction of the schedule, so sizing a progress bar to the full
    ``steps`` leaves it unfinished. This presumably mirrors diffusers'
    ``min(int(steps * strength), steps)`` for a first-order scheduler --
    TODO confirm against the installed diffusers version. Never returns < 1.
    """
    return max(1, min(int(steps * strength), steps))


def _enable_optimizations(pipe, device):
    """Enable memory-saving attention options for accelerator devices (best effort)."""
    if device not in ("mps", "cuda"):
        return
    pipe.enable_attention_slicing()
    if device == "cuda" and hasattr(pipe, "enable_xformers_memory_efficient_attention"):
        # xformers is not among this script's dependencies and diffusers raises
        # when it cannot be used; treat it as optional instead of aborting the run.
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception as e:  # best-effort optimisation: any failure -> default attention
            click.echo(f"ℹ️ xformers unavailable, using default attention ({e})")


# NOTE(review): the default model id is a FLUX checkpoint, which the Stable
# Diffusion img2img pipeline used below presumably cannot load (it expects an
# SD-style repo with unet/ and scheduler/ subfolders). Loading errors are
# reported by the handler below; pass an SD-compatible --model to run.
@click.command()
@click.argument("input_image", type=click.Path(exists=True))
@click.argument("prompt", type=str, required=False)
@click.option("--output", "-o", type=click.Path(), default=None,
              help="Output image path (default: output_[input name].png)")
@click.option("--negative-prompt", type=str, default="",
              help="Elements to avoid in the generated image")
@click.option("--model", "-m", type=str, default="black-forest-labs/FLUX.1-dev",
              help="Model ID from HuggingFace")
@click.option("--strength", "-s", type=click.FloatRange(0.0, 1.0), default=0.65,
              help="Transformation strength (0.0-1.0, higher = more stylized)")
@click.option("--steps", type=click.IntRange(min=1), default=30,
              help="Number of inference steps (higher = more detail but slower)")
@click.option("--guidance-scale", "-g", type=float, default=7.5,
              help="Guidance scale (how closely to follow the prompt)")
@click.option("--seed", type=int, default=None,
              help="Random seed for reproducibility")
@click.option("--no-half", is_flag=True,
              help="Disable half-precision (use if encountering MPS errors)")
@click.option("--count", "-c", type=click.IntRange(min=1), default=1,
              help="Number of images to generate")
def style_transfer(input_image, prompt, output, negative_prompt, model,
                   strength, steps, guidance_scale, seed, no_half, count):
    """
    Apply artistic style transfer to an image with a diffusion img2img pipeline.

    INPUT_IMAGE: Path to the source image file

    PROMPT: Text prompt describing the desired style (optional, defaults to professional digital drawing)
    \f
    Returns 0 when at least one image was saved and 1 otherwise. In click's
    default standalone mode this return value is discarded; it only becomes
    the process exit status when invoked with ``standalone_mode=False``.
    Image ``i`` (0-based) is generated with seed ``seed + i``.
    """
    # Use default professional headshot prompt if none provided
    if prompt is None:
        prompt = "A professional digital drawing portrait of a person with clean, elegant line work and subtle, refined shading. The image is stylized yet true to the subject's features, featuring a minimalist background and soft, modern colors, evoking the quality of a hand-drawn illustration by a top digital artist."
    try:
        start_time = time.time()

        # Default output name is derived from the input file's stem.
        if output is None:
            output = f"output_{Path(input_image).stem}.png"
        base_path = Path(output)
        base_path.parent.mkdir(parents=True, exist_ok=True)
        # Pieces used to build numbered file names when count > 1.
        output_dir = base_path.parent
        output_stem = base_path.stem
        output_suffix = base_path.suffix

        click.echo(click.style("🎨 Flux Style Transfer", fg="green", bold=True))
        click.echo(f"📄 Input: {click.style(input_image, fg='cyan')}")
        click.echo(f"🖌️ Prompt: {click.style(prompt, fg='yellow')}")
        if negative_prompt:
            click.echo(f"🚫 Negative: {click.style(negative_prompt, fg='red')}")

        device = get_device()
        click.echo(f"🖥️ Using device: {click.style(device, fg='green', bold=True)}")

        # A random seed is drawn (and printed) so any run can be reproduced.
        if seed is None:
            seed = int(torch.randint(0, 2**32 - 1, (1,)).item())
        click.echo(f"🌱 Seed: {seed}")

        # CPU always runs in float32; --no-half forces float32 on accelerators too.
        torch_dtype = torch.float32 if no_half or device == "cpu" else torch.float16
        precision = "float32" if torch_dtype == torch.float32 else "float16"
        click.echo(f"⚙️ Using {precision} precision")

        with click.progressbar(length=100, label='Loading model', show_eta=False) as bar:
            bar.update(10)
            try:
                # DPM-Solver multistep scheduler. prediction_type is intentionally
                # NOT overridden: forcing "epsilon" breaks v-prediction checkpoints,
                # and the model's own scheduler config already carries the right value.
                scheduler = DPMSolverMultistepScheduler.from_pretrained(
                    model, subfolder="scheduler", solver_order=2
                )
                pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
                    model,
                    scheduler=scheduler,
                    torch_dtype=torch_dtype,
                    safety_checker=None,
                )
                bar.update(70)
                pipe = pipe.to(device)
                bar.update(20)
            except Exception as e:
                click.echo(click.style(f"\n❌ Error loading model: {str(e)}", fg="red", bold=True))
                if "MPS" in str(e) and not no_half:
                    click.echo("💡 Tip: Try again with --no-half flag")
                return 1

        _enable_optimizations(pipe, device)

        # Load the source image; the context manager releases the file handle.
        try:
            with Image.open(input_image) as source:
                init_image = source.convert("RGB")
        except Exception as e:
            click.echo(click.style(f"❌ Error loading image: {str(e)}", fg="red", bold=True))
            return 1
        click.echo(f"📐 Image dimensions: {init_image.width}x{init_image.height}")

        denoise_steps = _img2img_step_count(steps, strength)
        successful = 0
        for i in range(count):
            current_seed = seed + i
            torch.manual_seed(current_seed)
            if count > 1:
                this_output = str(output_dir / f"{output_stem}_{i + 1}{output_suffix}")
            else:
                this_output = output

            click.echo(f"\n✨ Generating image {i+1}/{count} (seed: {current_seed})")
            try:
                with click.progressbar(length=denoise_steps, label='Processing image',
                                       show_eta=False) as bar:
                    def on_step_end(pipeline, step, timestep, callback_kwargs):
                        bar.update(1)
                        # diffusers expects the tensor dict handed back.
                        return callback_kwargs

                    result = pipe(
                        prompt=prompt,
                        negative_prompt=negative_prompt,
                        image=init_image,  # `init_image` is not a pipeline argument
                        strength=strength,
                        num_inference_steps=steps,
                        guidance_scale=guidance_scale,
                        callback_on_step_end=on_step_end,
                    )
                result.images[0].save(this_output)
                click.echo(f"✅ Saved to: {click.style(this_output, fg='green')}")
                successful += 1
            except RuntimeError as e:
                if "MPS" in str(e) and not no_half:
                    click.echo(click.style(f"❌ MPS error: {str(e)}", fg="red", bold=True))
                    click.echo("💡 Try again with --no-half flag for better compatibility")
                else:
                    click.echo(click.style(f"❌ Generation error: {str(e)}", fg="red", bold=True))
                continue
            except Exception as e:
                click.echo(click.style(f"❌ Unexpected error: {str(e)}", fg="red", bold=True))
                continue

        elapsed = time.time() - start_time
        click.echo(f"\n🎉 Generated {successful}/{count} images in {elapsed:.2f} seconds")
        return 0 if successful > 0 else 1
    except KeyboardInterrupt:
        click.echo(click.style("\n❌ Operation cancelled by user", fg="yellow", bold=True))
        return 1
    except Exception as e:
        click.echo(click.style(f"\n❌ Unexpected error: {str(e)}", fg="red", bold=True))
        return 1
if __name__ == "__main__":
    # In click's default standalone mode the command's return value is
    # discarded and the process exits 0 even when style_transfer returns 1.
    # Run non-standalone so that value becomes the exit status, and replicate
    # click's normal reporting of usage errors and aborts.
    try:
        exit_code = style_transfer.main(standalone_mode=False)
    except click.ClickException as e:
        e.show()
        sys.exit(e.exit_code)
    except click.Abort:
        click.echo("Aborted!", err=True)
        sys.exit(1)
    sys.exit(exit_code if isinstance(exit_code, int) else 0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment