Only enable fp16 on z image models that actually support it. (#12065)

comfyanonymous
2026-01-24 19:32:28 -08:00
committed by GitHub
parent ed6002cb60
commit 635406e283
3 changed files with 6 additions and 1 deletion


@@ -451,6 +451,7 @@ class NextDiT(nn.Module):
         device=None,
         dtype=None,
         operations=None,
+        **kwargs,
     ) -> None:
         super().__init__()
         self.dtype = dtype

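The new `**kwargs` parameter lets `NextDiT.__init__` silently accept config keys it does not consume. A plausible reason is that the detection code below now writes an extra `allow_fp16` key into the unet config, and that dict is expanded into the model constructor; a minimal sketch of the pattern, with hypothetical class and field names rather than ComfyUI's own:

# Illustrative only: extra config keys such as "allow_fp16" end up in **kwargs
# and are ignored, instead of raising "unexpected keyword argument".
class TinyDiT:
    def __init__(self, n_layers, dim, dtype=None, operations=None, **kwargs):
        self.n_layers = n_layers
        self.dim = dim
        self.dtype = dtype

config = {"n_layers": 30, "dim": 2304, "allow_fp16": True}
model = TinyDiT(**config)  # without **kwargs this call would raise TypeError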

@@ -444,6 +444,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["ffn_dim_multiplier"] = (8.0 / 3.0) dit_config["ffn_dim_multiplier"] = (8.0 / 3.0)
dit_config["z_image_modulation"] = True dit_config["z_image_modulation"] = True
dit_config["time_scale"] = 1000.0 dit_config["time_scale"] = 1000.0
try:
dit_config["allow_fp16"] = torch.std(state_dict['{}layers.{}.ffn_norm1.weight'.format(key_prefix, dit_config["n_layers"] - 2)], unbiased=False).item() < 0.42
except Exception:
pass
if '{}cap_pad_token'.format(key_prefix) in state_dict_keys: if '{}cap_pad_token'.format(key_prefix) in state_dict_keys:
dit_config["pad_tokens_multiple"] = 32 dit_config["pad_tokens_multiple"] = 32
sig_weight = state_dict.get('{}siglip_embedder.0.weight'.format(key_prefix), None) sig_weight = state_dict.get('{}siglip_embedder.0.weight'.format(key_prefix), None)

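The heuristic above inspects the weights themselves: it takes the population standard deviation of `ffn_norm1.weight` in layer `n_layers - 2` and treats a value under 0.42 as a sign the checkpoint tolerates fp16 inference; if the key is missing, the flag is simply left unset. A sketch of applying the same test to a checkpoint on disk (the loader helper name, prefix, and layer count are assumptions; only the std comparison mirrors the diff):

import torch
from safetensors.torch import load_file

def z_image_allows_fp16(path, key_prefix, n_layers):
    # Hypothetical helper: re-runs the detection heuristic against a checkpoint
    # file. key_prefix and n_layers must match the model being inspected.
    sd = load_file(path)
    key = "{}layers.{}.ffn_norm1.weight".format(key_prefix, n_layers - 2)
    if key not in sd:
        return False  # analogous to the diff's except-clause: leave fp16 disabled
    return torch.std(sd[key].float(), unbiased=False).item() < 0.42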

@@ -1093,7 +1093,7 @@ class ZImage(Lumina2):
     def __init__(self, unet_config):
         super().__init__(unet_config)
-        if comfy.model_management.extended_fp16_support():
+        if comfy.model_management.extended_fp16_support() and unet_config.get("allow_fp16", False):
             self.supported_inference_dtypes = self.supported_inference_dtypes.copy()
             self.supported_inference_dtypes.insert(1, torch.float16)
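
This is the gating change itself: extended fp16 support on the platform is no longer sufficient on its own; the checkpoint must also have been flagged as fp16-safe by the detection heuristic. A minimal sketch of the resulting dtype list (the function and base dtypes here are illustrative, not ComfyUI's API):

import torch

def supported_dtypes(extended_fp16, allow_fp16, base=(torch.bfloat16, torch.float32)):
    # Mirrors the gating in the diff: fp16 is offered only when the platform
    # supports it AND the checkpoint was detected as fp16-safe.
    dtypes = list(base)
    if extended_fp16 and allow_fp16:
        dtypes.insert(1, torch.float16)
    return dtypes

print(supported_dtypes(True, False))  # [torch.bfloat16, torch.float32]
print(supported_dtypes(True, True))   # [torch.bfloat16, torch.float16, torch.float32]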