# Source code for generate_background

"""Generate AI-enhanced images for each image in the output directory."""
import torch
from diffusers import AutoPipelineForInpainting
from PIL import Image, ImageFilter
import os
import json
import argparse

def initalize_diffuser(device: str) -> AutoPipelineForInpainting:
    """Load the SDXL inpainting pipeline and move it to ``device``.

    When ``device`` is ``"cuda"`` the pipeline is requested with
    ``torch.float16`` weights and the ``"fp16"`` variant; for any other
    device the loader's defaults are used.

    Args:
        device: Target device string, e.g. ``"cpu"`` or ``"cuda"``.

    Returns:
        The inpainting pipeline, already moved to ``device``.
    """
    model_id = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"

    # Half-precision options are only passed for the CUDA case.
    load_options = {}
    if device == "cuda":
        load_options = {"torch_dtype": torch.float16, "variant": "fp16"}

    pipeline = AutoPipelineForInpainting.from_pretrained(model_id, **load_options)
    return pipeline.to(device)
def generate_background(image_file: str, mask_file: str, output_path: str, prompt: str, negative_prompt: str, device: str, pipe: AutoPipelineForInpainting, object_names: tuple[str, str] = None) -> None:
    """Inpaint one image with a diffusion pipeline and save the result.

    The result is written to ``output_path`` under the same base filename as
    ``image_file``.

    Args:
        image_file: Path to the input image; it is converted to RGB.
        mask_file: Path to the mask for that image; it is converted to
            8-bit grayscale, blurred, and thresholded to pure 0/255.
        output_path: Directory in which the enhanced image is saved.
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative prompt passed to the pipeline.
        device: Device string used to create the random generator.
        pipe: The inpainting pipeline to call.
        object_names: Optional names of the main objects. Names that are
            ``None`` or empty are ignored; if at least one remains, a
            "do not modify" sentence listing them is appended to the prompt.
    """
    # Callers may pass names read from optional JSON fields, so individual
    # entries can be None. Drop those instead of writing "None" into the prompt.
    if object_names:
        known_names = [name for name in object_names if name]
        if known_names:
            listed = " and ".join(known_names)
            prompt = f"{prompt}. Do not modify the main objects in the scene: {listed}."

    # Open both image and mask
    image = Image.open(image_file).convert("RGB")
    mask = Image.open(mask_file).convert("L")

    # Blur the mask slightly, then snap every pixel back to 0 or 255
    # (values below 200 become 0) so the mask stays strictly binary.
    mask = mask.filter(ImageFilter.GaussianBlur(radius=2))
    mask = mask.point(lambda x: 0 if x < 200 else 255)

    # Fixed seed so repeated runs use the same random generator state.
    generator = torch.Generator(device=device).manual_seed(42)

    # Perform inpainting
    output_image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=image,
        mask_image=mask,
        guidance_scale=5,
        num_inference_steps=20,
        strength=0.9,
        generator=generator,
    ).images[0]

    filename = os.path.basename(image_file)
    output_image.save(os.path.join(output_path, filename))
def _scene_filename(image_filename: str) -> str:
    """Return the scene JSON filename for an image by swapping its extension."""
    # splitext only touches the final extension, unlike str.replace which
    # rewrites every ".png" occurrence and leaves other extensions untouched.
    return os.path.splitext(image_filename)[0] + '.json'


def _load_object_names(scene_file: str) -> tuple:
    """Read the (ground, figure) object names from one scene JSON file."""
    with open(scene_file, 'r') as f:
        scene_data = json.load(f)
    ground_name = scene_data.get('ground_object', {}).get('name')
    figure_name = scene_data.get('figure_object', {}).get('name')
    return ground_name, figure_name


def _iter_image_jobs(output_dir: str, scene_dir: str, enhanced_image_dir: str):
    """Yield (image_dir, filename, scene_subdir, enhanced_subdir) per image.

    If ``output_dir`` directly contains files, only those are yielded.
    Otherwise the directory tree is walked and each file is paired with the
    matching sub-directory under ``scene_dir`` and ``enhanced_image_dir``.
    """
    direct_files = [
        f for f in os.listdir(output_dir)
        if os.path.isfile(os.path.join(output_dir, f))
    ]
    if direct_files:
        for filename in direct_files:
            yield output_dir, filename, scene_dir, enhanced_image_dir
        return

    for root, _, files in os.walk(output_dir):
        relative_path = os.path.relpath(root, output_dir)
        for filename in files:
            yield (
                root,
                filename,
                os.path.join(scene_dir, relative_path),
                os.path.join(enhanced_image_dir, relative_path),
            )


def main(prompt: str, negative_prompt: str, device: str) -> None:
    """Generate enhanced images for each image in the output directory.

    Directory locations are read from ``config.json`` in the current working
    directory (keys ``output_image_dir``, ``masks_dir``,
    ``enhanced_image_dir`` and ``output_scene_dir``). For every image, the
    object names are read from the matching scene JSON file and passed along
    to ``generate_background``.

    Args:
        prompt: The prompt to generate the enhanced image.
        negative_prompt: The negative prompt to avoid certain features in the
            generated image.
        device: cpu or cuda (if available).

    Raises:
        KeyError: If a required key is missing from ``config.json``.
        OSError: If the config file, a scene file, or the image directory
            cannot be read.
    """
    pipe = initalize_diffuser(device)

    # Load config information
    with open('config.json', 'r') as f:
        config = json.load(f)
    output_dir = config['output_image_dir']
    masks_dir = config['masks_dir']
    enhanced_image_dir = config['enhanced_image_dir']
    scene_dir = config['output_scene_dir']

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(enhanced_image_dir, exist_ok=True)

    jobs = _iter_image_jobs(output_dir, scene_dir, enhanced_image_dir)
    for image_dir, filename, scene_subdir, enhanced_path_dir in jobs:
        image_file = os.path.join(image_dir, filename)
        # NOTE(review): masks are looked up by bare filename even for images
        # found in sub-directories — confirm masks_dir is meant to be flat.
        mask_file = os.path.join(masks_dir, filename)

        # Extract object names from scene data
        scene_file = os.path.join(scene_subdir, _scene_filename(filename))
        object_names = _load_object_names(scene_file)

        os.makedirs(enhanced_path_dir, exist_ok=True)
        generate_background(image_file, mask_file, enhanced_path_dir, prompt,
                            negative_prompt, device, pipe, object_names)
if __name__ == '__main__':
    # Command-line interface: a required prompt plus optional negative
    # prompt and device selection.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--prompt',
        type=str,
        required=True,
        help='The prompt to generate the enhanced image',
    )
    cli.add_argument(
        '--negative_prompt',
        type=str,
        default="",
        help='The negative prompt to avoid certain features in the generated image',
    )
    cli.add_argument(
        '--device',
        type=str,
        choices=['cpu', 'cuda', 'auto'],
        default='auto',
        help='Device to use for inference (cpu, cuda, or auto)',
    )
    args = cli.parse_args()

    # An explicit device is used as given; 'auto' picks CUDA when torch
    # reports it as available and falls back to CPU otherwise.
    device = args.device
    if device == 'auto':
        if torch.cuda.is_available():
            device = "cuda"
            print("CUDA is available. Using GPU.")
        else:
            device = "cpu"
            print("CUDA is not available. Using CPU.")

    # Run the pipeline with the parsed arguments.
    main(args.prompt, args.negative_prompt, device)