Merge branch 'master' into fix/api-nodes/vidu-pricing

This commit is contained in:
Alexander Piskun
2026-02-17 07:40:16 +02:00
committed by GitHub

View File

@@ -178,10 +178,7 @@ class BaseModel(torch.nn.Module):
xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)
context = c_crossattn
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
dtype = self.get_dtype_inference()
xc = xc.to(dtype)
device = xc.device
@@ -218,6 +215,13 @@ class BaseModel(torch.nn.Module):
def get_dtype(self):
return self.diffusion_model.dtype
def get_dtype_inference(self):
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
return dtype
def encode_adm(self, **kwargs):
return None
@@ -372,9 +376,7 @@ class BaseModel(torch.nn.Module):
input_shapes += shape
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
dtype = self.get_dtype_inference()
#TODO: this needs to be tweaked
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
@@ -1165,7 +1167,7 @@ class Anima(BaseModel):
t5xxl_ids = t5xxl_ids.unsqueeze(0)
if torch.is_inference_mode_enabled(): # if not we are training
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype()))
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype_inference()))
else:
out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)