simonpick commited on
Commit
767ffb0
Β·
verified Β·
1 Parent(s): ca86a1c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +380 -270
app.py CHANGED
@@ -29,11 +29,11 @@ MAX_SEED = np.iinfo(np.int32).max
29
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
30
  MODES = [
31
  {"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"},
32
- {"name": "Clay render", "icon": "assets/app/clay.png", "render_key": "clay"},
33
- {"name": "Base color", "icon": "assets/app/basecolor.png", "render_key": "base_color"},
34
- {"name": "HDRI forest", "icon": "assets/app/hdri_forest.png", "render_key": "shaded_forest"},
35
- {"name": "HDRI sunset", "icon": "assets/app/hdri_sunset.png", "render_key": "shaded_sunset"},
36
- {"name": "HDRI courtyard", "icon": "assets/app/hdri_courtyard.png", "render_key": "shaded_courtyard"},
37
  ]
38
  STEPS = 8
39
  DEFAULT_MODE = 3
@@ -41,206 +41,349 @@ DEFAULT_STEP = 3
41
 
42
 
43
  css = """
44
- /* Overwrite Gradio Default Style */
45
- .stepper-wrapper {
46
- padding: 0;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  }
48
 
49
- .stepper-container {
50
- padding: 0;
51
- align-items: center;
 
52
  }
53
 
54
- .step-button {
55
- flex-direction: row;
 
 
 
56
  }
57
 
58
- .step-connector {
59
- transform: none;
 
 
 
 
60
  }
61
 
62
- .step-number {
63
- width: 16px;
64
- height: 16px;
 
 
 
 
 
 
65
  }
66
 
67
- .step-label {
68
- position: relative;
69
- bottom: 0;
 
 
70
  }
71
 
72
- .wrap.center.full {
73
- inset: 0;
74
- height: 100%;
 
 
 
75
  }
76
 
77
- .wrap.center.full.translucent {
78
- background: var(--block-background-fill);
 
 
 
 
 
79
  }
80
 
81
- .meta-text-center {
82
- display: block !important;
83
- position: absolute !important;
84
- top: unset !important;
85
- bottom: 0 !important;
86
- right: 0 !important;
87
- transform: unset !important;
88
  }
89
 
90
- /* Previewer */
91
- .previewer-container {
92
- position: relative;
93
- font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
94
- width: 100%;
95
- height: 722px;
96
- margin: 0 auto;
97
- padding: 20px;
98
- display: flex;
99
- flex-direction: column;
100
- align-items: center;
101
- justify-content: center;
102
  }
103
 
104
- .previewer-container .tips-icon {
105
- position: absolute;
106
- right: 10px;
107
- top: 10px;
108
- z-index: 10;
109
- border-radius: 10px;
110
- color: #fff;
111
- background-color: var(--color-accent);
112
- padding: 3px 6px;
113
- user-select: none;
 
114
  }
115
 
116
- .previewer-container .tips-text {
117
- position: absolute;
118
- right: 10px;
119
- top: 50px;
120
- color: #fff;
121
- background-color: var(--color-accent);
122
- border-radius: 10px;
123
- padding: 6px;
124
- text-align: left;
125
- max-width: 300px;
126
- z-index: 10;
127
- transition: all 0.3s;
128
- opacity: 0%;
129
- user-select: none;
130
- }
131
-
132
- .previewer-container .tips-text p {
133
- font-size: 14px;
134
- line-height: 1.2;
135
- }
136
-
137
- .tips-icon:hover + .tips-text {
138
- display: block;
139
- opacity: 100%;
140
  }
141
 
142
- /* Row 1: Display Modes */
143
- .previewer-container .mode-row {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  width: 100%;
 
145
  display: flex;
146
- gap: 8px;
 
147
  justify-content: center;
148
- margin-bottom: 20px;
149
- flex-wrap: wrap;
150
- }
151
- .previewer-container .mode-btn {
152
- width: 24px;
153
- height: 24px;
154
- border-radius: 50%;
155
- cursor: pointer;
156
- opacity: 0.5;
157
- transition: all 0.2s;
158
- border: 2px solid #ddd;
159
- object-fit: cover;
160
- }
161
- .previewer-container .mode-btn:hover { opacity: 0.9; transform: scale(1.1); }
162
- .previewer-container .mode-btn.active {
163
- opacity: 1;
164
- border-color: var(--color-accent);
165
- transform: scale(1.1);
166
  }
167
 
168
- /* Row 2: Display Image */
169
  .previewer-container .display-row {
170
- margin-bottom: 20px;
171
- min-height: 400px;
172
  width: 100%;
173
- flex-grow: 1;
174
  display: flex;
175
  justify-content: center;
176
  align-items: center;
 
177
  }
 
178
  .previewer-container .previewer-main-image {
179
  max-width: 100%;
180
  max-height: 100%;
181
- flex-grow: 1;
182
  object-fit: contain;
183
  display: none;
 
 
184
  }
 
185
  .previewer-container .previewer-main-image.visible {
186
  display: block;
 
187
  }
188
 
189
- /* Row 3: Custom HTML Slider */
190
- .previewer-container .slider-row {
191
- width: 100%;
 
 
 
 
192
  display: flex;
193
- flex-direction: column;
194
- align-items: center;
195
  gap: 10px;
196
- padding: 0 10px;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  }
198
 
199
  .previewer-container input[type=range] {
200
  -webkit-appearance: none;
201
  width: 100%;
202
- max-width: 400px;
203
  background: transparent;
 
204
  }
 
205
  .previewer-container input[type=range]::-webkit-slider-runnable-track {
206
  width: 100%;
207
- height: 8px;
208
- cursor: pointer;
209
- background: #ddd;
210
- border-radius: 5px;
211
  }
 
212
  .previewer-container input[type=range]::-webkit-slider-thumb {
213
- height: 20px;
214
- width: 20px;
215
- border-radius: 50%;
216
- background: var(--color-accent);
217
- cursor: pointer;
218
  -webkit-appearance: none;
219
- margin-top: -6px;
220
- box-shadow: 0 2px 5px rgba(0,0,0,0.2);
221
- transition: transform 0.1s;
 
 
 
 
222
  }
 
223
  .previewer-container input[type=range]::-webkit-slider-thumb:hover {
224
  transform: scale(1.2);
225
  }
226
 
227
- /* Overwrite Previewer Block Style */
228
- .gradio-container .padded:has(.previewer-container) {
229
- padding: 0 !important;
 
 
 
 
 
 
 
 
230
  }
231
 
 
 
 
 
 
 
 
232
  .gradio-container:has(.previewer-container) [data-testid="block-label"] {
233
  position: absolute;
234
  top: 0;
235
  left: 0;
236
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  """
238
 
239
 
240
  head = """
 
 
 
 
241
  <script>
242
  function refreshView(mode, step) {
243
- // 1. Find current mode and step
244
  const allImgs = document.querySelectorAll('.previewer-main-image');
245
  for (let i = 0; i < allImgs.length; i++) {
246
  const img = allImgs[i];
@@ -253,21 +396,11 @@ head = """
253
  }
254
  }
255
 
256
- // 2. Hide ALL images
257
- // We select all elements with class 'previewer-main-image'
258
  allImgs.forEach(img => img.classList.remove('visible'));
259
-
260
- // 3. Construct the specific ID for the current state
261
- // Format: view-m{mode}-s{step}
262
  const targetId = 'view-m' + mode + '-s' + step;
263
  const targetImg = document.getElementById(targetId);
 
264
 
265
- // 4. Show ONLY the target
266
- if (targetImg) {
267
- targetImg.classList.add('visible');
268
- }
269
-
270
- // 5. Update Button Highlights
271
  const allBtns = document.querySelectorAll('.mode-btn');
272
  allBtns.forEach((btn, idx) => {
273
  if (idx === mode) btn.classList.add('active');
@@ -275,23 +408,22 @@ head = """
275
  });
276
  }
277
 
278
- // --- Action: Switch Mode ---
279
- function selectMode(mode) {
280
- refreshView(mode, -1);
281
- }
282
-
283
- // --- Action: Slider Change ---
284
- function onSliderChange(val) {
285
- refreshView(-1, parseInt(val));
286
- }
287
  </script>
288
  """
289
 
290
 
291
- empty_html = f"""
292
  <div class="previewer-container">
293
- <svg style=" opacity: .5; height: var(--size-5); color: var(--body-text-color);"
294
- xmlns="http://www.w3.org/2000/svg" width="100%" height="100%" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather feather-image"><rect x="3" y="3" width="18" height="18" rx="2" ry="2"></rect><circle cx="8.5" cy="8.5" r="1.5"></circle><polyline points="21 15 16 10 5 21"></polyline></svg>
 
 
 
 
 
 
295
  </div>
296
  """
297
 
@@ -324,10 +456,7 @@ def remove_background(input: Image.Image) -> Image.Image:
324
 
325
 
326
  def preprocess_image(input: Image.Image) -> Image.Image:
327
- """
328
- Preprocess the input image.
329
- """
330
- # if has alpha channel, use it directly; otherwise, remove background
331
  has_alpha = False
332
  if input.mode == 'RGBA':
333
  alpha = np.array(input)[:, :, 3]
@@ -349,7 +478,7 @@ def preprocess_image(input: Image.Image) -> Image.Image:
349
  size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
350
  size = int(size * 1)
351
  bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
352
- output = output.crop(bbox) # type: ignore
353
  output = np.array(output).astype(np.float32) / 255
354
  output = output[:, :, :3] * output[:, :, 3:4]
355
  output = Image.fromarray((output * 255).astype(np.uint8))
@@ -376,9 +505,6 @@ def unpack_state(state: dict) -> Tuple[SparseTensor, SparseTensor, int]:
376
 
377
 
378
  def get_seed(randomize_seed: bool, seed: int) -> int:
379
- """
380
- Get the random seed.
381
- """
382
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
383
 
384
 
@@ -402,7 +528,6 @@ def image_to_3d(
402
  req: gr.Request,
403
  progress=gr.Progress(track_tqdm=True),
404
  ) -> str:
405
- # --- Sampling ---
406
  outputs, latents = pipeline.run(
407
  image,
408
  seed=seed,
@@ -433,70 +558,32 @@ def image_to_3d(
433
  return_latent=True,
434
  )
435
  mesh = outputs[0]
436
- mesh.simplify(16777216) # nvdiffrast limit
437
  images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
438
  state = pack_state(latents)
439
  torch.cuda.empty_cache()
440
 
441
- # --- HTML Construction ---
442
- # The Stack of 48 Images
443
  images_html = ""
444
  for m_idx, mode in enumerate(MODES):
445
  for s_idx in range(STEPS):
446
- # ID Naming Convention: view-m{mode}-s{step}
447
  unique_id = f"view-m{m_idx}-s{s_idx}"
448
-
449
- # Logic: Only Mode 0, Step 0 is visible initially
450
  is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP)
451
  vis_class = "visible" if is_visible else ""
452
-
453
- # Image Source
454
  img_base64 = image_to_base64(Image.fromarray(images[mode['render_key']][s_idx]))
455
-
456
- # Render the Tag
457
- images_html += f"""
458
- <img id="{unique_id}"
459
- class="previewer-main-image {vis_class}"
460
- src="{img_base64}"
461
- loading="eager">
462
- """
463
 
464
- # Button Row HTML
465
  btns_html = ""
466
  for idx, mode in enumerate(MODES):
467
  active_class = "active" if idx == DEFAULT_MODE else ""
468
- # Note: onclick calls the JS function defined in Head
469
- btns_html += f"""
470
- <img src="{mode['icon_base64']}"
471
- class="mode-btn {active_class}"
472
- onclick="selectMode({idx})"
473
- title="{mode['name']}">
474
- """
475
 
476
- # Assemble the full component
477
  full_html = f"""
478
  <div class="previewer-container">
479
- <div class="tips-wrapper">
480
- <div class="tips-icon">πŸ’‘Tips</div>
481
- <div class="tips-text">
482
- <p>● <b>Render Mode</b> - Click on the circular buttons to switch between different render modes.</p>
483
- <p>● <b>View Angle</b> - Drag the slider to change the view angle.</p>
484
- </div>
485
- </div>
486
-
487
- <!-- Row 1: Viewport containing 48 static <img> tags -->
488
- <div class="display-row">
489
- {images_html}
490
- </div>
491
-
492
- <!-- Row 2 -->
493
- <div class="mode-row" id="btn-group">
494
- {btns_html}
495
- </div>
496
-
497
- <!-- Row 3: Slider -->
498
  <div class="slider-row">
499
- <input type="range" id="custom-slider" min="0" max="{STEPS - 1}" value="{DEFAULT_STEP}" step="1" oninput="onSliderChange(this.value)">
500
  </div>
501
  </div>
502
  """
@@ -512,17 +599,6 @@ def extract_glb(
512
  req: gr.Request,
513
  progress=gr.Progress(track_tqdm=True),
514
  ) -> Tuple[str, str]:
515
- """
516
- Extract a GLB file from the 3D model.
517
-
518
- Args:
519
- state (dict): The state of the generated 3D model.
520
- decimation_target (int): The target face count for decimation.
521
- texture_size (int): The texture resolution.
522
-
523
- Returns:
524
- str: The path to the extracted GLB file.
525
- """
526
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
527
  shape_slat, tex_slat, res = unpack_state(state)
528
  mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
@@ -551,72 +627,104 @@ def extract_glb(
551
  return glb_path, glb_path
552
 
553
 
554
- with gr.Blocks(delete_cache=(600, 600)) as demo:
555
- gr.Markdown("""
556
- ## Image to 3D Asset with [TRELLIS.2](https://microsoft.github.io/TRELLIS.2)
557
- * Upload an image (preferably with an alpha-masked foreground object) and click Generate to create a 3D asset.
558
- * Click Extract GLB to export and download the generated GLB file if you're satisfied with the result. Otherwise, try another time.
 
 
 
 
 
 
 
559
  """)
560
 
561
- with gr.Row():
562
- with gr.Column(scale=1, min_width=360):
563
- image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=400)
564
 
565
- resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
566
- seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
567
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
568
- decimation_target = gr.Slider(100000, 500000, label="Decimation Target", value=300000, step=10000)
569
- texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)
 
 
 
 
570
 
571
- generate_btn = gr.Button("Generate")
 
 
 
 
 
 
 
572
 
573
- with gr.Accordion(label="Advanced Settings", open=False):
574
- gr.Markdown("Stage 1: Sparse Structure Generation")
575
  with gr.Row():
576
- ss_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
577
- ss_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.7, step=0.01)
578
- ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
579
  ss_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=5.0, step=0.1)
580
- gr.Markdown("Stage 2: Shape Generation")
 
 
 
 
581
  with gr.Row():
582
- shape_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
583
- shape_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.5, step=0.01)
584
- shape_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
585
  shape_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
586
- gr.Markdown("Stage 3: Material Generation")
 
 
 
 
587
  with gr.Row():
588
- tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=1.0, step=0.1)
589
- tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.0, step=0.01)
590
- tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
591
- tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
592
 
593
- with gr.Column(scale=10):
 
594
  with gr.Walkthrough(selected=0) as walkthrough:
595
  with gr.Step("Preview", id=0):
596
- preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True)
597
- extract_btn = gr.Button("Extract GLB")
598
- with gr.Step("Extract", id=1):
599
- glb_output = gr.Model3D(label="Extracted GLB", height=724, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
600
- download_btn = gr.DownloadButton(label="Download GLB")
601
- gr.Markdown("*We are actively working on improving the speed of GLB extraction. Currently, it may take half a minute or more and face count is limited.*")
602
 
603
- with gr.Column(scale=1, min_width=172):
604
- examples = gr.Examples(
605
- examples=[
606
- f'assets/example_image/{image}'
607
- for image in os.listdir("assets/example_image")
608
- ],
609
- inputs=[image_prompt],
610
- fn=preprocess_image,
611
- outputs=[image_prompt],
612
- run_on_click=True,
613
- examples_per_page=18,
614
- )
615
 
616
  output_buf = gr.State()
617
-
618
 
619
- # Handlers
620
  demo.load(start_session)
621
  demo.unload(end_session)
622
 
@@ -650,14 +758,16 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
650
  inputs=[output_buf, decimation_target, texture_size],
651
  outputs=[glb_output, download_btn],
652
  )
653
-
654
 
655
- # Launch the Gradio app
 
 
 
 
656
  if __name__ == "__main__":
657
  os.makedirs(TMP_DIR, exist_ok=True)
658
 
659
- # Construct ui components
660
- btn_img_base64_strs = {}
661
  for i in range(len(MODES)):
662
  icon = Image.open(MODES[i]['icon'])
663
  MODES[i]['icon_base64'] = image_to_base64(icon)
@@ -683,4 +793,4 @@ if __name__ == "__main__":
683
  )),
684
  }
685
 
686
- demo.launch(css=css, head=head)
 
29
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
30
  MODES = [
31
  {"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"},
32
+ {"name": "Clay", "icon": "assets/app/clay.png", "render_key": "clay"},
33
+ {"name": "Color", "icon": "assets/app/basecolor.png", "render_key": "base_color"},
34
+ {"name": "Forest", "icon": "assets/app/hdri_forest.png", "render_key": "shaded_forest"},
35
+ {"name": "Sunset", "icon": "assets/app/hdri_sunset.png", "render_key": "shaded_sunset"},
36
+ {"name": "Courtyard", "icon": "assets/app/hdri_courtyard.png", "render_key": "shaded_courtyard"},
37
  ]
38
  STEPS = 8
39
  DEFAULT_MODE = 3
 
41
 
42
 
43
  css = """
44
+ /* ═══════════════════════════════════════════════════════════════
45
+ TRELLIS.2 β€” Modern Dark Theme
46
+ ═══════════════════════════════════════════════════════════════ */
47
+
48
+ :root {
49
+ --accent: #6366f1;
50
+ --accent-hover: #818cf8;
51
+ --accent-glow: rgba(99, 102, 241, 0.3);
52
+ --surface-0: #0a0a0b;
53
+ --surface-1: #111113;
54
+ --surface-2: #1a1a1d;
55
+ --surface-3: #242428;
56
+ --border: rgba(255, 255, 255, 0.06);
57
+ --text-primary: #fafafa;
58
+ --text-secondary: rgba(255, 255, 255, 0.5);
59
+ --radius: 16px;
60
+ --radius-sm: 10px;
61
  }
62
 
63
+ /* Global Overrides */
64
+ .gradio-container {
65
+ background: var(--surface-0) !important;
66
+ max-width: 1400px !important;
67
  }
68
 
69
+ .dark {
70
+ --block-background-fill: var(--surface-1) !important;
71
+ --block-border-color: var(--border) !important;
72
+ --body-background-fill: var(--surface-0) !important;
73
+ --color-accent: var(--accent) !important;
74
  }
75
 
76
+ /* Header */
77
+ .app-header {
78
+ text-align: center;
79
+ padding: 40px 20px 30px;
80
+ border-bottom: 1px solid var(--border);
81
+ margin-bottom: 24px;
82
  }
83
 
84
+ .app-header h1 {
85
+ font-family: 'SF Pro Display', -apple-system, BlinkMacSystemFont, sans-serif;
86
+ font-size: 2.5rem;
87
+ font-weight: 600;
88
+ letter-spacing: -0.03em;
89
+ background: linear-gradient(135deg, #fff 0%, rgba(255,255,255,0.7) 100%);
90
+ -webkit-background-clip: text;
91
+ -webkit-text-fill-color: transparent;
92
+ margin: 0 0 8px 0;
93
  }
94
 
95
+ .app-header p {
96
+ color: var(--text-secondary);
97
+ font-size: 1rem;
98
+ margin: 0;
99
+ font-weight: 400;
100
  }
101
 
102
+ /* Panels */
103
+ .panel {
104
+ background: var(--surface-1) !important;
105
+ border: 1px solid var(--border) !important;
106
+ border-radius: var(--radius) !important;
107
+ overflow: hidden;
108
  }
109
 
110
+ .panel-title {
111
+ font-size: 0.7rem;
112
+ text-transform: uppercase;
113
+ letter-spacing: 0.1em;
114
+ color: var(--text-secondary);
115
+ padding: 16px 20px 8px;
116
+ font-weight: 600;
117
  }
118
 
119
+ /* Upload Area */
120
+ .upload-zone {
121
+ min-height: 280px !important;
122
+ border: 2px dashed var(--border) !important;
123
+ border-radius: var(--radius) !important;
124
+ background: var(--surface-2) !important;
125
+ transition: all 0.3s ease;
126
  }
127
 
128
+ .upload-zone:hover {
129
+ border-color: var(--accent) !important;
130
+ background: rgba(99, 102, 241, 0.05) !important;
 
 
 
 
 
 
 
 
 
131
  }
132
 
133
+ /* Buttons */
134
+ .primary-btn {
135
+ background: var(--accent) !important;
136
+ border: none !important;
137
+ border-radius: var(--radius-sm) !important;
138
+ color: white !important;
139
+ font-weight: 600 !important;
140
+ padding: 14px 28px !important;
141
+ font-size: 0.95rem !important;
142
+ transition: all 0.2s ease !important;
143
+ box-shadow: 0 4px 20px var(--accent-glow) !important;
144
  }
145
 
146
+ .primary-btn:hover {
147
+ background: var(--accent-hover) !important;
148
+ transform: translateY(-1px);
149
+ box-shadow: 0 6px 30px var(--accent-glow) !important;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  }
151
 
152
+ .secondary-btn {
153
+ background: var(--surface-3) !important;
154
+ border: 1px solid var(--border) !important;
155
+ border-radius: var(--radius-sm) !important;
156
+ color: var(--text-primary) !important;
157
+ font-weight: 500 !important;
158
+ transition: all 0.2s ease !important;
159
+ }
160
+
161
+ .secondary-btn:hover {
162
+ background: var(--surface-2) !important;
163
+ border-color: var(--accent) !important;
164
+ }
165
+
166
+ /* Sliders & Inputs */
167
+ input[type="range"] {
168
+ accent-color: var(--accent) !important;
169
+ }
170
+
171
+ .wrap input, .wrap textarea {
172
+ background: var(--surface-2) !important;
173
+ border: 1px solid var(--border) !important;
174
+ border-radius: var(--radius-sm) !important;
175
+ color: var(--text-primary) !important;
176
+ }
177
+
178
+ /* Radio Buttons */
179
+ .gr-radio-row {
180
+ gap: 8px !important;
181
+ }
182
+
183
+ .gr-radio-row label {
184
+ background: var(--surface-2) !important;
185
+ border: 1px solid var(--border) !important;
186
+ border-radius: var(--radius-sm) !important;
187
+ padding: 10px 18px !important;
188
+ transition: all 0.2s ease !important;
189
+ }
190
+
191
+ .gr-radio-row label:hover {
192
+ border-color: var(--accent) !important;
193
+ }
194
+
195
+ .gr-radio-row label.selected {
196
+ background: var(--accent) !important;
197
+ border-color: var(--accent) !important;
198
+ }
199
+
200
+ /* Accordion */
201
+ .gr-accordion {
202
+ border: 1px solid var(--border) !important;
203
+ border-radius: var(--radius-sm) !important;
204
+ background: var(--surface-2) !important;
205
+ }
206
+
207
+ /* Walkthrough/Stepper */
208
+ .stepper-wrapper { padding: 0; }
209
+ .stepper-container { padding: 0; align-items: center; }
210
+ .step-button { flex-direction: row; }
211
+ .step-connector { transform: none; }
212
+ .step-number { width: 16px; height: 16px; }
213
+ .step-label { position: relative; bottom: 0; }
214
+
215
+ /* Loading States */
216
+ .wrap.center.full { inset: 0; height: 100%; }
217
+ .wrap.center.full.translucent { background: var(--surface-1); }
218
+
219
+ /* ═══════════════════════════════════════════════════════════════
220
+ 3D PREVIEWER COMPONENT
221
+ ═══════════════════════════════════════════════════════════════ */
222
+
223
+ .previewer-container {
224
+ position: relative;
225
  width: 100%;
226
+ height: 650px;
227
  display: flex;
228
+ flex-direction: column;
229
+ align-items: center;
230
  justify-content: center;
231
+ padding: 20px;
232
+ background: radial-gradient(ellipse at center, var(--surface-2) 0%, var(--surface-1) 100%);
233
+ border-radius: var(--radius);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  }
235
 
236
+ /* Viewport */
237
  .previewer-container .display-row {
238
+ flex: 1;
 
239
  width: 100%;
 
240
  display: flex;
241
  justify-content: center;
242
  align-items: center;
243
+ min-height: 0;
244
  }
245
+
246
  .previewer-container .previewer-main-image {
247
  max-width: 100%;
248
  max-height: 100%;
 
249
  object-fit: contain;
250
  display: none;
251
+ border-radius: var(--radius-sm);
252
+ box-shadow: 0 20px 60px rgba(0, 0, 0, 0.4);
253
  }
254
+
255
  .previewer-container .previewer-main-image.visible {
256
  display: block;
257
+ animation: fadeIn 0.3s ease;
258
  }
259
 
260
+ @keyframes fadeIn {
261
+ from { opacity: 0; transform: scale(0.98); }
262
+ to { opacity: 1; transform: scale(1); }
263
+ }
264
+
265
+ /* Mode Selector */
266
+ .previewer-container .mode-row {
267
  display: flex;
 
 
268
  gap: 10px;
269
+ margin-top: 20px;
270
+ padding: 8px;
271
+ background: var(--surface-0);
272
+ border-radius: 50px;
273
+ border: 1px solid var(--border);
274
+ }
275
+
276
+ .previewer-container .mode-btn {
277
+ width: 32px;
278
+ height: 32px;
279
+ border-radius: 50%;
280
+ cursor: pointer;
281
+ opacity: 0.4;
282
+ transition: all 0.25s cubic-bezier(0.4, 0, 0.2, 1);
283
+ border: 2px solid transparent;
284
+ object-fit: cover;
285
+ }
286
+
287
+ .previewer-container .mode-btn:hover {
288
+ opacity: 0.8;
289
+ transform: scale(1.1);
290
+ }
291
+
292
+ .previewer-container .mode-btn.active {
293
+ opacity: 1;
294
+ border-color: var(--accent);
295
+ transform: scale(1.15);
296
+ box-shadow: 0 0 20px var(--accent-glow);
297
+ }
298
+
299
+ /* Rotation Slider */
300
+ .previewer-container .slider-row {
301
+ width: 100%;
302
+ max-width: 320px;
303
+ margin-top: 16px;
304
  }
305
 
306
  .previewer-container input[type=range] {
307
  -webkit-appearance: none;
308
  width: 100%;
 
309
  background: transparent;
310
+ cursor: pointer;
311
  }
312
+
313
  .previewer-container input[type=range]::-webkit-slider-runnable-track {
314
  width: 100%;
315
+ height: 6px;
316
+ background: var(--surface-0);
317
+ border-radius: 3px;
318
+ border: 1px solid var(--border);
319
  }
320
+
321
  .previewer-container input[type=range]::-webkit-slider-thumb {
 
 
 
 
 
322
  -webkit-appearance: none;
323
+ height: 18px;
324
+ width: 18px;
325
+ border-radius: 50%;
326
+ background: var(--accent);
327
+ margin-top: -7px;
328
+ box-shadow: 0 2px 10px var(--accent-glow);
329
+ transition: transform 0.15s ease;
330
  }
331
+
332
  .previewer-container input[type=range]::-webkit-slider-thumb:hover {
333
  transform: scale(1.2);
334
  }
335
 
336
+ /* Empty State */
337
+ .empty-state {
338
+ display: flex;
339
+ flex-direction: column;
340
+ align-items: center;
341
+ gap: 16px;
342
+ color: var(--text-secondary);
343
+ }
344
+
345
+ .empty-state svg {
346
+ opacity: 0.3;
347
  }
348
 
349
+ .empty-state p {
350
+ font-size: 0.9rem;
351
+ margin: 0;
352
+ }
353
+
354
+ /* Block Label Override */
355
+ .gradio-container .padded:has(.previewer-container) { padding: 0 !important; }
356
  .gradio-container:has(.previewer-container) [data-testid="block-label"] {
357
  position: absolute;
358
  top: 0;
359
  left: 0;
360
  }
361
+
362
+ /* GLB Viewer */
363
+ .model3d-container {
364
+ background: var(--surface-2) !important;
365
+ border-radius: var(--radius) !important;
366
+ }
367
+
368
+ /* Footer Note */
369
+ .footer-note {
370
+ text-align: center;
371
+ color: var(--text-secondary);
372
+ font-size: 0.8rem;
373
+ padding: 16px;
374
+ border-top: 1px solid var(--border);
375
+ margin-top: 24px;
376
+ }
377
  """
378
 
379
 
380
  head = """
381
+ <link rel="preconnect" href="https://fonts.googleapis.com">
382
+ <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
383
+ <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap" rel="stylesheet">
384
+
385
  <script>
386
  function refreshView(mode, step) {
 
387
  const allImgs = document.querySelectorAll('.previewer-main-image');
388
  for (let i = 0; i < allImgs.length; i++) {
389
  const img = allImgs[i];
 
396
  }
397
  }
398
 
 
 
399
  allImgs.forEach(img => img.classList.remove('visible'));
 
 
 
400
  const targetId = 'view-m' + mode + '-s' + step;
401
  const targetImg = document.getElementById(targetId);
402
+ if (targetImg) targetImg.classList.add('visible');
403
 
 
 
 
 
 
 
404
  const allBtns = document.querySelectorAll('.mode-btn');
405
  allBtns.forEach((btn, idx) => {
406
  if (idx === mode) btn.classList.add('active');
 
408
  });
409
  }
410
 
411
+ function selectMode(mode) { refreshView(mode, -1); }
412
+ function onSliderChange(val) { refreshView(-1, parseInt(val)); }
 
 
 
 
 
 
 
413
  </script>
414
  """
415
 
416
 
417
+ empty_html = """
418
  <div class="previewer-container">
419
+ <div class="empty-state">
420
+ <svg width="64" height="64" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round">
421
+ <path d="M21 16V8a2 2 0 0 0-1-1.73l-7-4a2 2 0 0 0-2 0l-7 4A2 2 0 0 0 3 8v8a2 2 0 0 0 1 1.73l7 4a2 2 0 0 0 2 0l7-4A2 2 0 0 0 21 16z"></path>
422
+ <polyline points="3.27 6.96 12 12.01 20.73 6.96"></polyline>
423
+ <line x1="12" y1="22.08" x2="12" y2="12"></line>
424
+ </svg>
425
+ <p>Upload an image to generate 3D</p>
426
+ </div>
427
  </div>
428
  """
429
 
 
456
 
457
 
458
  def preprocess_image(input: Image.Image) -> Image.Image:
459
+ """Preprocess the input image."""
 
 
 
460
  has_alpha = False
461
  if input.mode == 'RGBA':
462
  alpha = np.array(input)[:, :, 3]
 
478
  size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
479
  size = int(size * 1)
480
  bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
481
+ output = output.crop(bbox)
482
  output = np.array(output).astype(np.float32) / 255
483
  output = output[:, :, :3] * output[:, :, 3:4]
484
  output = Image.fromarray((output * 255).astype(np.uint8))
 
505
 
506
 
507
  def get_seed(randomize_seed: bool, seed: int) -> int:
 
 
 
508
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
509
 
510
 
 
528
  req: gr.Request,
529
  progress=gr.Progress(track_tqdm=True),
530
  ) -> str:
 
531
  outputs, latents = pipeline.run(
532
  image,
533
  seed=seed,
 
558
  return_latent=True,
559
  )
560
  mesh = outputs[0]
561
+ mesh.simplify(16777216)
562
  images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
563
  state = pack_state(latents)
564
  torch.cuda.empty_cache()
565
 
566
+ # Build HTML
 
567
  images_html = ""
568
  for m_idx, mode in enumerate(MODES):
569
  for s_idx in range(STEPS):
 
570
  unique_id = f"view-m{m_idx}-s{s_idx}"
 
 
571
  is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP)
572
  vis_class = "visible" if is_visible else ""
 
 
573
  img_base64 = image_to_base64(Image.fromarray(images[mode['render_key']][s_idx]))
574
+ images_html += f'<img id="{unique_id}" class="previewer-main-image {vis_class}" src="{img_base64}" loading="eager">'
 
 
 
 
 
 
 
575
 
 
576
  btns_html = ""
577
  for idx, mode in enumerate(MODES):
578
  active_class = "active" if idx == DEFAULT_MODE else ""
579
+ btns_html += f'<img src="{mode["icon_base64"]}" class="mode-btn {active_class}" onclick="selectMode({idx})" title="{mode["name"]}">'
 
 
 
 
 
 
580
 
 
581
  full_html = f"""
582
  <div class="previewer-container">
583
+ <div class="display-row">{images_html}</div>
584
+ <div class="mode-row">{btns_html}</div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
585
  <div class="slider-row">
586
+ <input type="range" min="0" max="{STEPS - 1}" value="{DEFAULT_STEP}" step="1" oninput="onSliderChange(this.value)">
587
  </div>
588
  </div>
589
  """
 
599
  req: gr.Request,
600
  progress=gr.Progress(track_tqdm=True),
601
  ) -> Tuple[str, str]:
 
 
 
 
 
 
 
 
 
 
 
602
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
603
  shape_slat, tex_slat, res = unpack_state(state)
604
  mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
 
627
  return glb_path, glb_path
628
 
629
 
630
+ # ═══════════════════════════════════════════════════════════════
631
+ # GRADIO INTERFACE
632
+ # ═══════════════════════════════════════════════════════════════
633
+
634
+ with gr.Blocks(css=css, head=head, theme=gr.themes.Base(primary_hue="indigo"), delete_cache=(600, 600)) as demo:
635
+
636
+ # Header
637
+ gr.HTML("""
638
+ <div class="app-header">
639
+ <h1>TRELLIS.2</h1>
640
+ <p>Transform any image into a high-quality 3D asset</p>
641
+ </div>
642
  """)
643
 
644
+ with gr.Row(equal_height=True):
645
+ # Left Panel β€” Controls
646
+ with gr.Column(scale=1, min_width=320):
647
 
648
+ # Image Upload
649
+ image_prompt = gr.Image(
650
+ label="Input Image",
651
+ format="png",
652
+ image_mode="RGBA",
653
+ type="pil",
654
+ height=280,
655
+ elem_classes=["upload-zone"]
656
+ )
657
 
658
+ # Main Controls
659
+ with gr.Group():
660
+ resolution = gr.Radio(
661
+ ["512", "1024", "1536"],
662
+ label="Resolution",
663
+ value="1024",
664
+ interactive=True
665
+ )
666
 
 
 
667
  with gr.Row():
668
+ seed = gr.Number(label="Seed", value=0, precision=0, minimum=0, maximum=MAX_SEED)
669
+ randomize_seed = gr.Checkbox(label="Random", value=True)
670
+
671
+ # Generate Button
672
+ generate_btn = gr.Button("Generate 3D", variant="primary", elem_classes=["primary-btn"], size="lg")
673
+
674
+ # Export Settings
675
+ with gr.Accordion("Export Settings", open=False):
676
+ decimation_target = gr.Slider(100000, 500000, label="Face Count", value=300000, step=10000)
677
+ texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=512)
678
+
679
+ # Advanced Settings
680
+ with gr.Accordion("Advanced Settings", open=False):
681
+ gr.Markdown("**Stage 1: Sparse Structure**")
682
+ with gr.Row():
683
+ ss_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance", value=7.5, step=0.1)
684
+ ss_sampling_steps = gr.Slider(1, 50, label="Steps", value=12, step=1)
685
+ with gr.Row():
686
+ ss_guidance_rescale = gr.Slider(0.0, 1.0, label="Rescale", value=0.7, step=0.01)
687
  ss_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=5.0, step=0.1)
688
+
689
+ gr.Markdown("**Stage 2: Shape**")
690
+ with gr.Row():
691
+ shape_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance", value=7.5, step=0.1)
692
+ shape_slat_sampling_steps = gr.Slider(1, 50, label="Steps", value=12, step=1)
693
  with gr.Row():
694
+ shape_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Rescale", value=0.5, step=0.01)
 
 
695
  shape_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
696
+
697
+ gr.Markdown("**Stage 3: Material**")
698
+ with gr.Row():
699
+ tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance", value=1.0, step=0.1)
700
+ tex_slat_sampling_steps = gr.Slider(1, 50, label="Steps", value=12, step=1)
701
  with gr.Row():
702
+ tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Rescale", value=0.0, step=0.01)
703
+ tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
 
 
704
 
705
+ # Right Panel β€” Preview
706
+ with gr.Column(scale=2):
707
  with gr.Walkthrough(selected=0) as walkthrough:
708
  with gr.Step("Preview", id=0):
709
+ preview_output = gr.HTML(empty_html, label="3D Preview", show_label=False)
710
+ extract_btn = gr.Button("Extract GLB", elem_classes=["secondary-btn"], size="lg")
 
 
 
 
711
 
712
+ with gr.Step("Export", id=1):
713
+ glb_output = gr.Model3D(
714
+ label="GLB Model",
715
+ height=580,
716
+ show_label=False,
717
+ display_mode="solid",
718
+ clear_color=(0.06, 0.06, 0.07, 1.0)
719
+ )
720
+ download_btn = gr.DownloadButton("Download GLB", elem_classes=["primary-btn"], size="lg")
721
+
722
+ # Footer
723
+ gr.HTML('<div class="footer-note">GLB extraction may take 30+ seconds. Face count is limited for performance.</div>')
724
 
725
  output_buf = gr.State()
 
726
 
727
+ # Event Handlers
728
  demo.load(start_session)
729
  demo.unload(end_session)
730
 
 
758
  inputs=[output_buf, decimation_target, texture_size],
759
  outputs=[glb_output, download_btn],
760
  )
 
761
 
762
+
763
+ # ═══════════════════════════════════════════════════════════════
764
+ # LAUNCH
765
+ # ═══════════════════════════════════════════════════════════════
766
+
767
  if __name__ == "__main__":
768
  os.makedirs(TMP_DIR, exist_ok=True)
769
 
770
+ # Load mode icons
 
771
  for i in range(len(MODES)):
772
  icon = Image.open(MODES[i]['icon'])
773
  MODES[i]['icon_base64'] = image_to_base64(icon)
 
793
  )),
794
  }
795
 
796
+ demo.launch()