From 12bd24f4af3ac33676fdc2b41e121a39eb6c9452 Mon Sep 17 00:00:00 2001 From: Aishwarya0811 Date: Wed, 8 Oct 2025 00:13:15 +0500 Subject: [PATCH 1/2] Fix MPS compatibility in get_1d_sincos_pos_embed_from_grid #12432 --- src/diffusers/models/embeddings.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index b51f5d7aec25..e2b5f2063aac 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -319,13 +319,18 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"): return emb -def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin_to_cos=False): +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin_to_cos=False, dtype=None): """ This function generates 1D positional embeddings from a grid. Args: embed_dim (`int`): The embedding dimension `D` pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)` + output_type (`str`, *optional*, defaults to `"np"`): Output type. Use `"pt"` for PyTorch tensors. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): Whether to flip sine and cosine embeddings. + dtype (`torch.dtype`, *optional*): Data type for frequency calculations. If `None`, defaults to + `torch.float32` on MPS devices (which don't support `torch.float64`) and `torch.float64` + on other devices. Returns: `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`. @@ -341,7 +346,11 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin if embed_dim % 2 != 0: raise ValueError("embed_dim must be divisible by 2") - omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64) + # Auto-detect appropriate dtype if not specified + if dtype is None: + dtype = torch.float32 if pos.device.type == "mps" else torch.float64 + + omega = torch.arange(embed_dim // 2, device=pos.device, dtype=dtype) omega /= embed_dim / 2.0 omega = 1.0 / 10000**omega # (D/2,) From 25dd2def229b93d30a78037b03c34fd4f2c40c85 Mon Sep 17 00:00:00 2001 From: Aishwarya0811 Date: Wed, 8 Oct 2025 15:04:28 +0500 Subject: [PATCH 2/2] Fix trailing whitespace in docstring --- src/diffusers/models/embeddings.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index e2b5f2063aac..52740f6dc64b 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -328,8 +328,8 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)` output_type (`str`, *optional*, defaults to `"np"`): Output type. Use `"pt"` for PyTorch tensors. flip_sin_to_cos (`bool`, *optional*, defaults to `False`): Whether to flip sine and cosine embeddings. - dtype (`torch.dtype`, *optional*): Data type for frequency calculations. If `None`, defaults to - `torch.float32` on MPS devices (which don't support `torch.float64`) and `torch.float64` + dtype (`torch.dtype`, *optional*): Data type for frequency calculations. If `None`, defaults to + `torch.float32` on MPS devices (which don't support `torch.float64`) and `torch.float64` on other devices. Returns: