"""Gradio demo app for StereoSpace.

At import time the script clones the StereoSpace repository at a pinned
commit. It then builds a Gradio UI that runs the repository's
``inference.py`` on an uploaded (or example) image and lets the user switch
between the views that the script produced.
"""

import os

# Startup diagnostics, emitted before the heavier imports below.
# NOTE(review): this prints every environment variable to stdout, which may
# include tokens or other secrets -- confirm the logs are not publicly visible.
print("\n".join(f"{k}={v}" for k, v in os.environ.items()))
os.system("pip freeze")

import glob
import gradio as gr
import shutil
import spaces
import subprocess
import sys
import tempfile
from PIL import Image

from gradio_patches.radio import Radio


REPO_URL = "https://github.com/prs-eth/stereospace.git"
COMMIT_SHA = "d7bbae6"
REPO_DIR = "stereospace"
DEVICE = "cuda"  # not referenced in the visible part of this script


def clone_repository():
    """Make sure ``REPO_DIR`` holds the StereoSpace repo at ``COMMIT_SHA``.

    An existing git checkout is fetched; otherwise the repository is cloned
    from ``REPO_URL``. In both cases ``COMMIT_SHA`` is checked out afterwards.

    Raises:
        subprocess.CalledProcessError: If any git command exits non-zero.
    """
    already_cloned = os.path.exists(REPO_DIR) and os.path.isdir(
        os.path.join(REPO_DIR, ".git")
    )
    if already_cloned:
        print(f"Repository {REPO_DIR} already exists, checking out commit...")
        subprocess.run(["git", "fetch"], cwd=REPO_DIR, check=True, capture_output=True)
    else:
        print(f"Cloning repository {REPO_URL} at commit {COMMIT_SHA}...")
        subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)
    subprocess.run(["git", "checkout", COMMIT_SHA], cwd=REPO_DIR, check=True)
    print(f"Repository ready at {REPO_DIR}")


clone_repository()
sys.path.insert(0, REPO_DIR)

# These resolve inside REPO_DIR thanks to the sys.path entry above. They are
# not referenced in the visible part of this script.
# NOTE(review): presumably kept to fail fast on a broken checkout -- confirm.
from inference import generate_novel_view
from src import StereoSpace


def create_placeholder_image():
    """Write a 1x1 black PNG to a new temporary file and return its path.

    The file is created with ``delete=False``; this function does not remove
    it.
    """
    placeholder = Image.new('RGB', (1, 1), color='black')
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
        # Write through the open handle instead of reopening the path.
        placeholder.save(tmp, format="PNG")
    return tmp.name


def find_output_file(output_dir, base_name, output_mode):
    """Return the path of the output file for one display mode.

    Args:
        output_dir: Directory the inference script wrote into.
        base_name: File-name stem shared by all outputs of one run.
        output_mode: One of "Anaglyph", "Side-by-side", "Generated view",
            "Input view".

    Returns:
        Path of the existing output file for ``output_mode``.

    Raises:
        ValueError: If ``output_mode`` is not a known mode.
        FileNotFoundError: If the expected file does not exist.
    """
    fixed_suffixes = {
        "Anaglyph": "_anaglyph.png",
        "Side-by-side": "_sbs.png",
        "Input view": "_source.png",
    }
    if output_mode in fixed_suffixes:
        pattern = os.path.join(output_dir, f"{base_name}{fixed_suffixes[output_mode]}")
    elif output_mode == "Generated view":
        pattern = os.path.join(output_dir, f"{base_name}_generated_*.png")
        matches = glob.glob(pattern)
        if not matches:
            raise FileNotFoundError(f"No generated file found matching {pattern}")
        # glob order is arbitrary; sort so the choice is deterministic.
        pattern = sorted(matches)[0]
    else:
        raise ValueError(f"Unknown output mode: {output_mode}")

    if not os.path.exists(pattern):
        raise FileNotFoundError(f"Output file not found: {pattern}")
    return pattern


def find_all_output_files(output_dir, base_name):
    """Collect the output files that exist for every known display mode.

    Modes whose file is missing are skipped, so the result may be empty.

    Returns:
        Dict mapping mode name to file path, in the fixed mode order below.
    """
    outputs = {}
    modes = ["Anaglyph", "Side-by-side", "Generated view", "Input view"]
    for mode in modes:
        try:
            outputs[mode] = find_output_file(output_dir, base_name, mode)
        except FileNotFoundError:
            # A missing mode is not an error here; callers check the result.
            pass
    return outputs


def _stage_input_image(input_image, input_path):
    """Copy or convert ``input_image`` into the file at ``input_path``."""
    if isinstance(input_image, str):
        shutil.copy(input_image, input_path)
    elif isinstance(input_image, Image.Image):
        # Already decoded: Image.open() would not accept it.
        input_image.convert("RGB").save(input_path)
    else:
        with Image.open(input_image) as source:
            source.convert("RGB").save(input_path)


@spaces.GPU
def process_all_modes(input_image):
    """Run the repository's inference script and load every produced view.

    Args:
        input_image: Path of the input image (a ``str``), a PIL image, or
            anything ``PIL.Image.open`` accepts.

    Returns:
        Dict mapping mode name to an RGB ``PIL.Image.Image``.

    Raises:
        gr.Error: If no image was given, the inference process exits
            non-zero, or no output file could be found.
    """
    if input_image is None:
        raise gr.Error("Please upload an image or select an example.")

    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_input:
        input_path = tmp_input.name

    try:
        # Inside the try so the temp file is removed even if staging fails.
        _stage_input_image(input_image, input_path)

        with tempfile.TemporaryDirectory() as tmp_output:
            base_name = os.path.splitext(os.path.basename(input_path))[0]
            inference_script = "inference.py"
            cmd = [
                # Use the running interpreter rather than whatever "python"
                # happens to be first on PATH.
                sys.executable or "python",
                inference_script,
                "--input", input_path,
                "--output", tmp_output,
            ]
            print(f"Running command: {' '.join(cmd)}")
            print(f"Working directory: {REPO_DIR}")

            result = subprocess.run(
                cmd,
                cwd=REPO_DIR,
                capture_output=True,
                text=True,
            )
            if result.returncode != 0:
                error_msg = f"Inference failed:\nSTDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}"
                print(error_msg)
                raise gr.Error(error_msg)

            try:
                output_files = find_all_output_files(tmp_output, base_name)
                if not output_files:
                    # find_all_output_files skips missing modes silently, so
                    # an empty result has to be turned into an error here.
                    raise FileNotFoundError(
                        f"No output files produced for base name {base_name!r}"
                    )
            except Exception as e:
                all_files = os.listdir(tmp_output)
                error_msg = f"Output files not found. Available files: {all_files}\nError: {str(e)}"
                print(error_msg)
                raise gr.Error(error_msg) from e

            outputs = {}
            for mode, output_file in output_files.items():
                # convert() yields an in-memory image, so the file handle can
                # be closed before the temporary directory is removed.
                with Image.open(output_file) as opened:
                    outputs[mode] = opened.convert("RGB")
            return outputs
    finally:
        if os.path.exists(input_path):
            os.unlink(input_path)


def get_example_images():
    """Return the sorted image paths found in ``REPO_DIR/example_images``.

    Returns an empty list when the directory does not exist.
    """
    example_dir = os.path.join(REPO_DIR, "example_images")
    if not os.path.exists(example_dir):
        return []

    image_extensions = ["*.png", "*.jpg", "*.jpeg", "*.PNG", "*.JPG", "*.JPEG"]
    examples = set()  # a set avoids duplicates on case-insensitive filesystems
    for ext in image_extensions:
        examples.update(glob.glob(os.path.join(example_dir, ext)))
    return sorted(examples)


def _compute_outputs(img, preferred_mode):
    """Run inference on ``img`` and package the results for the UI.

    Returns:
        ``(gallery_data, placeholder_path, radio_update)``. ``gallery_data``
        is a list of ``[image, mode_name]`` pairs; the radio update selects
        ``preferred_mode`` when it is available, else the first available
        mode.
    """
    if img is None:
        placeholder = create_placeholder_image()
        return [], placeholder, gr.update()

    all_outputs = process_all_modes(img)
    gallery_data = [[mode_image, mode_name] for mode_name, mode_image in all_outputs.items()]
    available_modes = list(all_outputs)

    placeholder = create_placeholder_image()
    if preferred_mode in available_modes:
        radio_value = preferred_mode
    elif available_modes:
        radio_value = available_modes[0]
    else:
        radio_value = "Anaglyph"
    return gallery_data, placeholder, gr.update(choices=available_modes, value=radio_value)


with gr.Blocks(
    title="StereoSpace Demo",
) as demo:
    gr.Markdown(
        """
Upload a photo or pick an example below to create stereo space, wait for the result, then watch it in anaglyph, side-by-side, or generated view. If a quota limit appears, duplicate the space to continue.
"""
    )

    with gr.Row():
        output_mode = Radio(
            choices=["Anaglyph", "Side-by-side", "Input view", "Generated view"],
            value="Anaglyph",
            label=None,
            container=False,
            scale=1,
            elem_classes="horizontal-radio",
        )

    with gr.Row():
        image = gr.Image(
            type="filepath",
            label="Input/Output Image",
            elem_classes="result-image",
            height=480,
        )

    # Hidden gallery used as storage for the computed views of the current
    # input; the visible image shows one of them at a time.
    outputs_gallery = gr.Gallery(
        visible=False,
        label="Computed Outputs",
        show_label=False,
        value=[],
    )

    def update_image_from_gallery(gallery_data, current_mode, current_image=None):
        """Pick the gallery image whose label equals ``current_mode``.

        Items that are lists/tuples with at least two elements are read as
        ``(image, label)``. If no label matches, the first gallery item is
        used. With an empty gallery, or when nothing usable is found,
        ``current_image`` is returned unchanged.

        Returns:
            A file path (PIL images are written to a temporary PNG first),
            or ``current_image``.
        """
        if not gallery_data:
            return current_image

        result = None
        for item in gallery_data:
            if isinstance(item, (list, tuple)) and len(item) >= 2:
                img, label = item[0], item[1]
                if label == current_mode:
                    result = img
                    break

        if result is None:
            # No label matched: fall back to the first gallery entry.
            first_item = gallery_data[0]
            if isinstance(first_item, (list, tuple)) and len(first_item) >= 1:
                result = first_item[0]
            elif isinstance(first_item, str):
                result = first_item

        if result is None:
            return current_image

        if isinstance(result, Image.Image):
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
                result.save(tmp, format="PNG")
            return tmp.name
        return result

    def process_new_image(img, current_mode):
        """Process an uploaded image, keeping ``current_mode`` if available."""
        return _compute_outputs(img, current_mode)

    def process_example_simple(img):
        """Process an example image and select the first available mode."""
        return _compute_outputs(img, None)

    def clear_image():
        """Reset values for the image and the gallery."""
        return None, []

    examples_list = get_example_images()
    if examples_list:
        def process_example_wrapper(img):
            """Entry point used by ``gr.Examples``."""
            gallery_data, placeholder_image, radio_update = process_example_simple(img)
            return gallery_data, placeholder_image, radio_update

        examples_component = gr.Examples(
            examples=examples_list,
            inputs=[image],
            outputs=[outputs_gallery, image, output_mode],
            fn=process_example_wrapper,
            cache_examples=True,
            cache_mode="lazy",
            label="Example Images",
            elem_id="example-images-gallery",
        )

    def process_upload_wrapper(img, current_mode):
        """Entry point for the image upload event."""
        gallery_data, blocked_image, radio_update = process_new_image(img, current_mode)
        return gallery_data, blocked_image, radio_update

    upload_event = image.upload(
        fn=process_upload_wrapper,
        inputs=[image, output_mode],
        outputs=[outputs_gallery, image, output_mode],
    )

    def on_gallery_change(gallery_data, current_mode, current_image):
        """Show the view for the current mode and re-enable the image."""
        # update_image_from_gallery returns current_image for an empty gallery.
        updated_image = update_image_from_gallery(gallery_data, current_mode, current_image)
        return updated_image, gr.update(interactive=True)

    # `image` is listed twice on purpose: the first return value sets its
    # value, the second applies the interactive update.
    gallery_change_event = outputs_gallery.change(
        fn=on_gallery_change,
        inputs=[outputs_gallery, output_mode, image],
        outputs=[image, image],
    )

    def switch_mode_handler(current_mode, gallery_data, current_image):
        """Swap the displayed image when the user picks another mode."""
        return update_image_from_gallery(gallery_data, current_mode, current_image)

    output_mode.change(
        fn=switch_mode_handler,
        inputs=[output_mode, outputs_gallery, image],
        outputs=image,
    )

    image.clear(
        fn=clear_image,
        outputs=[image, outputs_gallery],
    )


if __name__ == "__main__":
    demo.queue().launch(
        server_name="0.0.0.0",
        ssr_mode=False,
        # Single-line CSS; the pieces are concatenated without adding or
        # removing any whitespace. All three rules force the example
        # thumbnails to max(96px, 1/8 of the viewport width).
        css=(
            ' #example-images-gallery button[class*="gallery-item"][class*="svelte-"] {'
            " min-width: max(96px, calc(100vw / 8));"
            " min-height: max(96px, calc(100vw / 8));"
            " width: max(96px, calc(100vw / 8));"
            " height: max(96px, calc(100vw / 8));"
            " }"
            ' #example-images-gallery button[class*="gallery-item"] div[class*="container"] {'
            " min-width: max(96px, calc(100vw / 8));"
            " min-height: max(96px, calc(100vw / 8));"
            " width: max(96px, calc(100vw / 8));"
            " height: max(96px, calc(100vw / 8));"
            " }"
            ' #example-images-gallery button[class*="gallery-item"] img {'
            " min-width: max(96px, calc(100vw / 8));"
            " min-height: max(96px, calc(100vw / 8));"
            " width: max(96px, calc(100vw / 8));"
            " height: max(96px, calc(100vw / 8));"
            " object-fit: cover;"
            " }"
            " "
        ),
        head=" ",
    )