Make built in lora training work on anima. (#12402)

This commit is contained in:
comfyanonymous
2026-02-10 19:04:32 -08:00
committed by GitHub
parent cdcf4119b3
commit 76a7fa96db
2 changed files with 22 additions and 6 deletions

View File

@@ -1160,12 +1160,16 @@ class Anima(BaseModel):
device = kwargs["device"]
if cross_attn is not None:
if t5xxl_ids is not None:
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.unsqueeze(0).to(device=device))
if t5xxl_weights is not None:
cross_attn *= t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
t5xxl_weights = t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
t5xxl_ids = t5xxl_ids.unsqueeze(0)
if torch.is_inference_mode_enabled(): # if not we are training
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype()))
else:
out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
if cross_attn.shape[1] < 512:
cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, 0, 512 - cross_attn.shape[1]))
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out