Fix dtype issue with hidream o1 (#13849)

This commit is contained in:
comfyanonymous
2026-05-11 20:53:13 -07:00
committed by GitHub
parent 8e53f001a4
commit 0155ddcbe3

View File

@@ -451,9 +451,8 @@ class Qwen35VisionPatchEmbed(nn.Module):
self.proj = ops.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True, device=device, dtype=dtype)
def forward(self, x):
target_dtype = self.proj.weight.dtype
x = x.view(-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size)
return self.proj(x.to(target_dtype)).view(-1, self.embed_dim)
return self.proj(x).view(-1, self.embed_dim)
class Qwen35VisionMLP(nn.Module):