# TRELLIS_2_4K / app.py
import gradio as gr
from gradio_client import Client, handle_file
import spaces
import os
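# Environment flags: enable OpenEXR reads in OpenCV (needed for the .exr HDRIs loaded below),
# let the CUDA caching allocator use expandable segments to reduce fragmentation, select the
# FlashAttention-3 attention backend, and give the flex-gemm autotuner a persistent cache file.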
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["ATTN_BACKEND"] = "flash_attn_3"
os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json')
os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'
from datetime import datetime
import shutil
import cv2
from typing import Tuple
import torch
import numpy as np
from PIL import Image
import base64
import io
import tempfile
from trellis2.modules.sparse import SparseTensor
from trellis2.pipelines import Trellis2ImageTo3DPipeline
from trellis2.renderers import EnvMap
from trellis2.utils import render_utils
import o_voxel
MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
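# Render modes offered by the custom preview widget; `render_key` selects the matching image
# set in the render_snapshot() output, and `icon` is the toolbar thumbnail for that mode.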
MODES = [
{"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"},
{"name": "Clay", "icon": "assets/app/clay.png", "render_key": "clay"},
{"name": "Color", "icon": "assets/app/basecolor.png", "render_key": "base_color"},
{"name": "Forest", "icon": "assets/app/hdri_forest.png", "render_key": "shaded_forest"},
{"name": "Sunset", "icon": "assets/app/hdri_sunset.png", "render_key": "shaded_sunset"},
{"name": "Courtyard", "icon": "assets/app/hdri_courtyard.png", "render_key": "shaded_courtyard"},
]
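# Viewpoints rendered per mode, and the mode / viewpoint shown when a preview first appears.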
STEPS = 8
DEFAULT_MODE = 3
DEFAULT_STEP = 3
css = """
/* ═══════════════════════════════════════════════════════════════
   TRELLIS.2 - Modern Dark Theme
═══════════════════════════════════════════════════════════════ */
:root {
--accent: #6366f1;
--accent-hover: #818cf8;
--accent-glow: rgba(99, 102, 241, 0.3);
--surface-0: #0a0a0b;
--surface-1: #111113;
--surface-2: #1a1a1d;
--surface-3: #242428;
--border: rgba(255, 255, 255, 0.06);
--text-primary: #fafafa;
--text-secondary: rgba(255, 255, 255, 0.5);
--radius: 16px;
--radius-sm: 10px;
}
/* Global Overrides */
.gradio-container {
background: var(--surface-0) !important;
width: 100% !important;
min-width: 800px !important;
max-width: 1800px !important;
margin: 0 auto !important;
padding: 0 40px !important;
box-sizing: border-box !important;
}
.gradio-container > .main {
gap: 0 !important;
width: 100% !important;
max-width: none !important;
}
.contain {
display: flex !important;
flex-direction: column !important;
max-width: none !important;
}
.dark {
--block-background-fill: var(--surface-1) !important;
--block-border-color: var(--border) !important;
--body-background-fill: var(--surface-0) !important;
--color-accent: var(--accent) !important;
}
/* Header */
.app-header {
text-align: center;
padding: 48px 20px 36px;
border-bottom: 1px solid var(--border);
margin-bottom: 32px;
width: 100%;
}
.app-header h1 {
font-family: 'SF Pro Display', -apple-system, BlinkMacSystemFont, sans-serif;
font-size: 2.5rem;
font-weight: 600;
letter-spacing: -0.03em;
background: linear-gradient(135deg, #fff 0%, rgba(255,255,255,0.7) 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
margin: 0 0 8px 0;
}
.app-header p {
color: var(--text-secondary);
font-size: 1rem;
margin: 0;
font-weight: 400;
}
/* Panels */
.panel {
background: var(--surface-1) !important;
border: 1px solid var(--border) !important;
border-radius: var(--radius) !important;
overflow: hidden;
}
.panel-title {
font-size: 0.7rem;
text-transform: uppercase;
letter-spacing: 0.1em;
color: var(--text-secondary);
padding: 16px 20px 8px;
font-weight: 600;
}
/* Upload Area */
.upload-zone {
min-height: 280px !important;
border: 2px dashed var(--border) !important;
border-radius: var(--radius) !important;
background: var(--surface-2) !important;
transition: all 0.3s ease;
}
.upload-zone:hover {
border-color: var(--accent) !important;
background: rgba(99, 102, 241, 0.05) !important;
}
/* Buttons */
.primary-btn {
background: var(--accent) !important;
border: none !important;
border-radius: var(--radius-sm) !important;
color: white !important;
font-weight: 600 !important;
padding: 14px 28px !important;
font-size: 0.95rem !important;
transition: all 0.2s ease !important;
box-shadow: 0 4px 20px var(--accent-glow) !important;
}
.primary-btn:hover {
background: var(--accent-hover) !important;
transform: translateY(-1px);
box-shadow: 0 6px 30px var(--accent-glow) !important;
}
.secondary-btn {
background: var(--surface-3) !important;
border: 1px solid var(--border) !important;
border-radius: var(--radius-sm) !important;
color: var(--text-primary) !important;
font-weight: 500 !important;
transition: all 0.2s ease !important;
}
.secondary-btn:hover {
background: var(--surface-2) !important;
border-color: var(--accent) !important;
}
/* Sliders & Inputs */
input[type="range"] {
accent-color: var(--accent) !important;
}
.wrap input, .wrap textarea {
background: var(--surface-2) !important;
border: 1px solid var(--border) !important;
border-radius: var(--radius-sm) !important;
color: var(--text-primary) !important;
}
/* Radio Buttons */
.gr-radio-row {
gap: 8px !important;
}
.gr-radio-row label {
background: var(--surface-2) !important;
border: 1px solid var(--border) !important;
border-radius: var(--radius-sm) !important;
padding: 10px 18px !important;
transition: all 0.2s ease !important;
}
.gr-radio-row label:hover {
border-color: var(--accent) !important;
}
.gr-radio-row label.selected {
background: var(--accent) !important;
border-color: var(--accent) !important;
}
/* Accordion */
.gr-accordion {
border: 1px solid var(--border) !important;
border-radius: var(--radius-sm) !important;
background: var(--surface-2) !important;
}
/* Walkthrough/Stepper */
.stepper-wrapper { padding: 0; }
.stepper-container { padding: 0; align-items: center; }
.step-button { flex-direction: row; }
.step-connector { transform: none; }
.step-number { width: 16px; height: 16px; }
.step-label { position: relative; bottom: 0; }
/* Loading States */
.wrap.center.full { inset: 0; height: 100%; }
.wrap.center.full.translucent { background: var(--surface-1); }
/* ═══════════════════════════════════════════════════════════════
3D PREVIEWER COMPONENT
═══════════════════════════════════════════════════════════════ */
.previewer-container {
position: relative;
width: 100%;
height: 720px;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
padding: 24px;
background: radial-gradient(ellipse at center, var(--surface-2) 0%, var(--surface-1) 100%);
border-radius: var(--radius);
}
/* Viewport */
.previewer-container .display-row {
flex: 1;
width: 100%;
display: flex;
justify-content: center;
align-items: center;
min-height: 0;
}
.previewer-container .previewer-main-image {
max-width: 100%;
max-height: 100%;
object-fit: contain;
display: none;
border-radius: var(--radius-sm);
box-shadow: 0 20px 60px rgba(0, 0, 0, 0.4);
}
.previewer-container .previewer-main-image.visible {
display: block;
animation: fadeIn 0.3s ease;
}
@keyframes fadeIn {
from { opacity: 0; transform: scale(0.98); }
to { opacity: 1; transform: scale(1); }
}
/* Mode Selector */
.previewer-container .mode-row {
display: flex;
gap: 10px;
margin-top: 20px;
padding: 8px;
background: var(--surface-0);
border-radius: 50px;
border: 1px solid var(--border);
}
.previewer-container .mode-btn {
width: 32px;
height: 32px;
border-radius: 50%;
cursor: pointer;
opacity: 0.4;
transition: all 0.25s cubic-bezier(0.4, 0, 0.2, 1);
border: 2px solid transparent;
object-fit: cover;
}
.previewer-container .mode-btn:hover {
opacity: 0.8;
transform: scale(1.1);
}
.previewer-container .mode-btn.active {
opacity: 1;
border-color: var(--accent);
transform: scale(1.15);
box-shadow: 0 0 20px var(--accent-glow);
}
/* Rotation Slider */
.previewer-container .slider-row {
width: 100%;
max-width: 320px;
margin-top: 16px;
}
.previewer-container input[type=range] {
-webkit-appearance: none;
width: 100%;
background: transparent;
cursor: pointer;
}
.previewer-container input[type=range]::-webkit-slider-runnable-track {
width: 100%;
height: 6px;
background: var(--surface-0);
border-radius: 3px;
border: 1px solid var(--border);
}
.previewer-container input[type=range]::-webkit-slider-thumb {
-webkit-appearance: none;
height: 18px;
width: 18px;
border-radius: 50%;
background: var(--accent);
margin-top: -7px;
box-shadow: 0 2px 10px var(--accent-glow);
transition: transform 0.15s ease;
}
.previewer-container input[type=range]::-webkit-slider-thumb:hover {
transform: scale(1.2);
}
/* Empty State */
.empty-state {
display: flex;
flex-direction: column;
align-items: center;
gap: 16px;
color: var(--text-secondary);
}
.empty-state svg {
opacity: 0.3;
}
.empty-state p {
font-size: 0.9rem;
margin: 0;
}
/* Block Label Override */
.gradio-container .padded:has(.previewer-container) { padding: 0 !important; }
.gradio-container:has(.previewer-container) [data-testid="block-label"] {
position: absolute;
top: 0;
left: 0;
}
/* GLB Viewer */
.model3d-container {
background: var(--surface-2) !important;
border-radius: var(--radius) !important;
}
/* Footer Note */
.footer-note {
text-align: center;
color: var(--text-secondary);
font-size: 0.8rem;
padding: 20px;
border-top: 1px solid var(--border);
margin-top: 32px;
width: 100%;
}
/* Main Layout - Force side by side */
#main-row {
width: 100% !important;
max-width: none !important;
margin: 0 !important;
display: flex !important;
flex-direction: row !important;
flex-wrap: nowrap !important;
gap: 32px !important;
align-items: flex-start !important;
}
#main-row.row {
flex-wrap: nowrap !important;
max-width: none !important;
}
#input-col {
flex: 0 0 400px !important;
width: 400px !important;
min-width: 350px !important;
max-width: 450px !important;
}
#preview-col {
flex: 1 1 auto !important;
min-width: 500px !important;
}
@media (max-width: 900px) {
#main-row {
flex-direction: column !important;
}
#input-col,
#preview-col {
flex: 1 1 auto !important;
width: 100% !important;
max-width: 100% !important;
min-width: 0 !important;
}
}
"""
head = """
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap" rel="stylesheet">
<script>
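// Show the preview image for the selected render mode / viewpoint step.
// Passing -1 for either argument keeps the value of the currently visible image.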
function refreshView(mode, step) {
const allImgs = document.querySelectorAll('.previewer-main-image');
for (let i = 0; i < allImgs.length; i++) {
const img = allImgs[i];
if (img.classList.contains('visible')) {
const id = img.id;
const [_, m, s] = id.split('-');
if (mode === -1) mode = parseInt(m.slice(1));
if (step === -1) step = parseInt(s.slice(1));
break;
}
}
allImgs.forEach(img => img.classList.remove('visible'));
const targetId = 'view-m' + mode + '-s' + step;
const targetImg = document.getElementById(targetId);
if (targetImg) targetImg.classList.add('visible');
const allBtns = document.querySelectorAll('.mode-btn');
allBtns.forEach((btn, idx) => {
if (idx === mode) btn.classList.add('active');
else btn.classList.remove('active');
});
}
function selectMode(mode) { refreshView(mode, -1); }
function onSliderChange(val) { refreshView(-1, parseInt(val)); }
</script>
"""
empty_html = """
<div class="previewer-container">
<div class="empty-state">
<svg width="64" height="64" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round">
<path d="M21 16V8a2 2 0 0 0-1-1.73l-7-4a2 2 0 0 0-2 0l-7 4A2 2 0 0 0 3 8v8a2 2 0 0 0 1 1.73l7 4a2 2 0 0 0 2 0l7-4A2 2 0 0 0 21 16z"></path>
<polyline points="3.27 6.96 12 12.01 20.73 6.96"></polyline>
<line x1="12" y1="22.08" x2="12" y2="12"></line>
</svg>
<p>Upload an image to generate 3D</p>
</div>
</div>
"""
def image_to_base64(image):
buffered = io.BytesIO()
image = image.convert("RGB")
image.save(buffered, format="jpeg", quality=85)
img_str = base64.b64encode(buffered.getvalue()).decode()
return f"data:image/jpeg;base64,{img_str}"
def start_session(req: gr.Request):
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
os.makedirs(user_dir, exist_ok=True)
def end_session(req: gr.Request):
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    shutil.rmtree(user_dir, ignore_errors=True)
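# Background removal is delegated to the hosted briaai/BRIA-RMBG-2.0 Space via gradio_client.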
def remove_background(input: Image.Image) -> Image.Image:
with tempfile.NamedTemporaryFile(suffix='.png') as f:
input = input.convert('RGB')
input.save(f.name)
output = rmbg_client.predict(handle_file(f.name), api_name="/image")[0][0]
output = Image.open(output)
return output
def preprocess_image(input: Image.Image) -> Image.Image:
"""Preprocess the input image."""
has_alpha = False
if input.mode == 'RGBA':
alpha = np.array(input)[:, :, 3]
if not np.all(alpha == 255):
has_alpha = True
max_size = max(input.size)
scale = min(1, 1024 / max_size)
if scale < 1:
input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
if has_alpha:
output = input
else:
output = remove_background(input)
output_np = np.array(output)
alpha = output_np[:, :, 3]
bbox = np.argwhere(alpha > 0.8 * 255)
bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
size = int(size * 1)
bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
output = output.crop(bbox)
output = np.array(output).astype(np.float32) / 255
output = output[:, :, :3] * output[:, :, 3:4]
output = Image.fromarray((output * 255).astype(np.uint8))
return output
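# (De)serialize the sparse latents as plain numpy arrays, e.g. for stashing them in gradio
# session state between calls.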
def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict:
shape_slat, tex_slat, res = latents
return {
'shape_slat_feats': shape_slat.feats.cpu().numpy(),
'tex_slat_feats': tex_slat.feats.cpu().numpy(),
'coords': shape_slat.coords.cpu().numpy(),
'res': res,
}
def unpack_state(state: dict) -> Tuple[SparseTensor, SparseTensor, int]:
shape_slat = SparseTensor(
feats=torch.from_numpy(state['shape_slat_feats']).cuda(),
coords=torch.from_numpy(state['coords']).cuda(),
)
tex_slat = shape_slat.replace(torch.from_numpy(state['tex_slat_feats']).cuda())
return shape_slat, tex_slat, state['res']
@spaces.GPU(duration=180)
def generate_and_extract(
image: Image.Image,
req: gr.Request,
progress=gr.Progress(track_tqdm=True),
) -> Tuple[str, str, str]:
"""
Combined function: Generate 3D from image AND extract GLB in one GPU session.
This avoids issues with chaining multiple @spaces.GPU functions.
"""
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
os.makedirs(user_dir, exist_ok=True)
# Hardcoded values
seed = np.random.randint(0, MAX_SEED)
decimation_target = 300000
texture_size = 4096
# === STAGE 1: Generate 3D ===
outputs, latents = pipeline.run(
image,
seed=seed,
preprocess_image=False,
sparse_structure_sampler_params={
"steps": 12,
"guidance_strength": 7.5,
"guidance_rescale": 0.7,
"rescale_t": 5.0,
},
shape_slat_sampler_params={
"steps": 12,
"guidance_strength": 7.5,
"guidance_rescale": 0.5,
"rescale_t": 3.0,
},
tex_slat_sampler_params={
"steps": 12,
"guidance_strength": 1.0,
"guidance_rescale": 0.0,
"rescale_t": 3.0,
},
pipeline_type="1024_cascade",
return_latent=True,
)
mesh = outputs[0]
mesh.simplify(16777216)
# Render preview images
images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
# Build preview HTML
images_html = ""
for m_idx, mode in enumerate(MODES):
for s_idx in range(STEPS):
unique_id = f"view-m{m_idx}-s{s_idx}"
is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP)
vis_class = "visible" if is_visible else ""
img_base64 = image_to_base64(Image.fromarray(images[mode['render_key']][s_idx]))
images_html += f'<img id="{unique_id}" class="previewer-main-image {vis_class}" src="{img_base64}" loading="eager">'
btns_html = ""
for idx, mode in enumerate(MODES):
active_class = "active" if idx == DEFAULT_MODE else ""
btns_html += f'<img src="{mode["icon_base64"]}" class="mode-btn {active_class}" onclick="selectMode({idx})" title="{mode["name"]}">'
preview_html = f"""
<div class="previewer-container">
<div class="display-row">{images_html}</div>
<div class="mode-row">{btns_html}</div>
<div class="slider-row">
<input type="range" min="0" max="{STEPS - 1}" value="{DEFAULT_STEP}" step="1" oninput="onSliderChange(this.value)">
</div>
</div>
"""
# === STAGE 2: Extract GLB ===
shape_slat, tex_slat, res = latents
mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
mesh.simplify(16777216)
glb = o_voxel.postprocess.to_glb(
vertices=mesh.vertices,
faces=mesh.faces,
attr_volume=mesh.attrs,
coords=mesh.coords,
attr_layout=pipeline.pbr_attr_layout,
grid_size=res,
aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
decimation_target=decimation_target,
texture_size=texture_size,
remesh=True,
remesh_band=1,
remesh_project=0,
use_tqdm=True,
)
now = datetime.now()
timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
glb.export(glb_path, extension_webp=True)
torch.cuda.empty_cache()
# Return: preview_html, glb_path (for viewer), glb_path (for download)
return preview_html, glb_path, glb_path
# ═══════════════════════════════════════════════════════════════
# GRADIO INTERFACE
# ═══════════════════════════════════════════════════════════════
with gr.Blocks(theme=gr.themes.Base(primary_hue="indigo"), css=css, head=head, delete_cache=(600, 600)) as demo:
# Header
gr.HTML("""
<div class="app-header">
<h1>TRELLIS.2</h1>
<p>Transform any image into a high-quality 3D asset</p>
</div>
""")
with gr.Row(equal_height=False, elem_id="main-row"):
        # Left Panel - Input (span 1)
with gr.Column(scale=1, min_width=320, elem_id="input-col"):
# Image Upload
image_prompt = gr.Image(
label="Input Image",
format="png",
image_mode="RGBA",
type="pil",
height=400,
elem_classes=["upload-zone"]
)
# Generate Button
generate_btn = gr.Button("Generate 3D", variant="primary", elem_classes=["primary-btn"], size="lg")
        # Right Panel - Preview (span 2)
with gr.Column(scale=2, elem_id="preview-col"):
with gr.Walkthrough(selected=0) as walkthrough:
with gr.Step("Preview", id=0):
preview_output = gr.HTML(empty_html, label="3D Preview", show_label=False)
with gr.Step("Export", id=1):
glb_output = gr.Model3D(
label="GLB Model",
height=640,
show_label=False,
display_mode="solid",
clear_color=(0.06, 0.06, 0.07, 0.0) # Alpha = 0 for transparent background
)
download_btn = gr.DownloadButton("Download GLB", elem_classes=["primary-btn"], size="lg")
# Footer
gr.HTML('<div class="footer-note">Generation includes automatic GLB extraction. This may take 90+ seconds total.</div>')
# Event Handlers
demo.load(start_session)
demo.unload(end_session)
image_prompt.upload(
preprocess_image,
inputs=[image_prompt],
outputs=[image_prompt],
)
# Single GPU call: Generate 3D + Extract GLB
generate_btn.click(
generate_and_extract,
inputs=[image_prompt],
outputs=[preview_output, glb_output, download_btn],
).then(
lambda: gr.Walkthrough(selected=1), outputs=walkthrough
)
# ═══════════════════════════════════════════════════════════════
# LAUNCH
# ═══════════════════════════════════════════════════════════════
if __name__ == "__main__":
os.makedirs(TMP_DIR, exist_ok=True)
# Load mode icons
for i in range(len(MODES)):
icon = Image.open(MODES[i]['icon'])
MODES[i]['icon_base64'] = image_to_base64(icon)
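    # Remote background-removal client and the TRELLIS.2 image-to-3D pipeline (kept on GPU, low-VRAM mode disabled).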
rmbg_client = Client("briaai/BRIA-RMBG-2.0")
pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
pipeline.rembg_model = None
pipeline.low_vram = False
pipeline.cuda()
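    # HDRI environment maps for the shaded preview renders; reading the .exr files relies on
    # OPENCV_IO_ENABLE_OPENEXR being set at the top of the file.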
envmap = {
'forest': EnvMap(torch.tensor(
cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
dtype=torch.float32, device='cuda'
)),
'sunset': EnvMap(torch.tensor(
cv2.cvtColor(cv2.imread('assets/hdri/sunset.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
dtype=torch.float32, device='cuda'
)),
'courtyard': EnvMap(torch.tensor(
cv2.cvtColor(cv2.imread('assets/hdri/courtyard.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
dtype=torch.float32, device='cuda'
)),
}
    demo.launch()