Oysiyl Claude Sonnet 4.5 commited on
Commit
6cbd0a8
·
1 Parent(s): ecce6e7

Disable torch.compile for timestep_embedding function

Browse files

The timesteps.device attribute causes ConstantVariable errors with
torch.compile. Disabling compilation for this function prevents the
error while keeping main UNet compilation benefits intact.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <[email protected]>

comfy/ldm/modules/diffusionmodules/util.py CHANGED
@@ -256,6 +256,7 @@ class CheckpointFunction(torch.autograd.Function):
256
  return (None, None) + input_grads
257
 
258
 
 
259
  def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
260
  """
261
  Create sinusoidal timestep embeddings.
@@ -267,15 +268,18 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
267
  """
268
  if not repeat_only:
269
  half = dim // 2
 
270
  freqs = torch.exp(
271
  -math.log(max_period)
272
  * torch.arange(start=0, end=half, dtype=torch.float32)
273
  / half
274
- ).to(timesteps.device)
275
  args = timesteps[:, None].float() * freqs[None]
276
  embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
277
  if dim % 2:
278
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
 
 
279
  else:
280
  embedding = repeat(timesteps, 'b -> b d', d=dim)
281
  return embedding
 
256
  return (None, None) + input_grads
257
 
258
 
259
+ @torch.compiler.disable
260
  def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
261
  """
262
  Create sinusoidal timestep embeddings.
 
268
  """
269
  if not repeat_only:
270
  half = dim // 2
271
+ # Create on CPU then move to same device as timesteps (torch.compile compatible)
272
  freqs = torch.exp(
273
  -math.log(max_period)
274
  * torch.arange(start=0, end=half, dtype=torch.float32)
275
  / half
276
+ ).to(timesteps)
277
  args = timesteps[:, None].float() * freqs[None]
278
  embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
279
  if dim % 2:
280
+ embedding = torch.cat(
281
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
282
+ )
283
  else:
284
  embedding = repeat(timesteps, 'b -> b d', d=dim)
285
  return embedding