From c39653163d77161b2df2d57419129a4d6d081aa1 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 16 Feb 2026 21:29:20 -0800 Subject: [PATCH] Fix anima preprocess text embeds not using right inference dtype. (#12501) --- comfy/model_base.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 4a74cb1ce..9dcef8741 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -178,10 +178,7 @@ class BaseModel(torch.nn.Module): xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1) context = c_crossattn - dtype = self.get_dtype() - - if self.manual_cast_dtype is not None: - dtype = self.manual_cast_dtype + dtype = self.get_dtype_inference() xc = xc.to(dtype) device = xc.device @@ -218,6 +215,13 @@ class BaseModel(torch.nn.Module): def get_dtype(self): return self.diffusion_model.dtype + def get_dtype_inference(self): + dtype = self.get_dtype() + + if self.manual_cast_dtype is not None: + dtype = self.manual_cast_dtype + return dtype + def encode_adm(self, **kwargs): return None @@ -372,9 +376,7 @@ class BaseModel(torch.nn.Module): input_shapes += shape if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention(): - dtype = self.get_dtype() - if self.manual_cast_dtype is not None: - dtype = self.manual_cast_dtype + dtype = self.get_dtype_inference() #TODO: this needs to be tweaked area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes)) return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024) @@ -1165,7 +1167,7 @@ class Anima(BaseModel): t5xxl_ids = t5xxl_ids.unsqueeze(0) if torch.is_inference_mode_enabled(): # if not we are training - cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype())) + cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype_inference())) else: out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids) out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)