Disable torch.compile for timestep_embedding function
The timesteps.device attribute causes ConstantVariable errors with
torch.compile. Disabling compilation for this function prevents the
error while keeping the main UNet compilation benefits intact.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Sonnet 4.5 <[email protected]>
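
For context, the mechanism this commit relies on is PyTorch's `torch.compiler.disable` decorator (available since PyTorch 2.1): the decorated function always runs eagerly, even when called from inside a `torch.compile`d caller, so Dynamo never traces its device-dependent logic. A minimal standalone sketch of the pattern, not the ComfyUI code itself (`make_freqs` and `embed` are hypothetical names):

```python
import math
import torch

# Hypothetical helper, not the ComfyUI function: frequencies are built on CPU,
# then moved to the caller's device. Excluding the whole helper from
# compilation is the same workaround the commit applies.
@torch.compiler.disable
def make_freqs(t: torch.Tensor, half: int) -> torch.Tensor:
    freqs = torch.exp(
        -math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half
    )
    return freqs.to(t.device)  # runs eagerly, so no Dynamo tracing issue here

@torch.compile
def embed(t: torch.Tensor) -> torch.Tensor:
    freqs = make_freqs(t, 8)  # graph break: executed outside the compiled graph
    args = t[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

print(embed(torch.arange(4)).shape)  # torch.Size([4, 16])
```

The trade-off is a graph break at each call site, which is acceptable here because the embedding is tiny compared to the UNet the commit keeps compiled.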
comfy/ldm/modules/diffusionmodules/util.py (CHANGED)

```diff
@@ -256,6 +256,7 @@ class CheckpointFunction(torch.autograd.Function):
         return (None, None) + input_grads
 
 
+@torch.compiler.disable
 def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
     """
     Create sinusoidal timestep embeddings.
@@ -267,15 +268,18 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
     """
     if not repeat_only:
         half = dim // 2
+        # Create on CPU then move to same device as timesteps (torch.compile compatible)
         freqs = torch.exp(
             -math.log(max_period)
             * torch.arange(start=0, end=half, dtype=torch.float32)
             / half
-        ).to(timesteps.device)
+        ).to(timesteps)
         args = timesteps[:, None].float() * freqs[None]
         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
         if dim % 2:
-            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+            embedding = torch.cat(
+                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
+            )
     else:
         embedding = repeat(timesteps, 'b -> b d', d=dim)
     return embedding
```
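
One subtlety in the replacement line, worth keeping in mind when reviewing: `Tensor.to(other)` adopts both the dtype and the device of `other`, while the removed `.to(timesteps.device)` changed only the device. A standalone illustration of that difference, assuming float timesteps as in typical sampling:

```python
import torch

timesteps = torch.tensor([0.0, 250.0, 500.0])  # float32 on CPU here; CUDA in practice
freqs = torch.exp(-torch.arange(8, dtype=torch.float32))

# .to(tensor) matches BOTH the dtype and the device of the argument, and it
# avoids the explicit .device attribute access that tripped torch.compile.
moved = freqs.to(timesteps)
print(moved.dtype, moved.device)  # torch.float32 cpu
```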