diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py
index b114d9e31..77d1abc97 100644
--- a/comfy/ldm/lumina/model.py
+++ b/comfy/ldm/lumina/model.py
@@ -451,6 +451,7 @@ class NextDiT(nn.Module):
         device=None,
         dtype=None,
         operations=None,
+        **kwargs,
     ) -> None:
         super().__init__()
         self.dtype = dtype
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index b29a033cc..8cea16e50 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -444,6 +444,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["ffn_dim_multiplier"] = (8.0 / 3.0)
         dit_config["z_image_modulation"] = True
         dit_config["time_scale"] = 1000.0
+        try:
+            dit_config["allow_fp16"] = torch.std(state_dict['{}layers.{}.ffn_norm1.weight'.format(key_prefix, dit_config["n_layers"] - 2)], unbiased=False).item() < 0.42
+        except Exception:
+            pass
         if '{}cap_pad_token'.format(key_prefix) in state_dict_keys:
             dit_config["pad_tokens_multiple"] = 32
         sig_weight = state_dict.get('{}siglip_embedder.0.weight'.format(key_prefix), None)
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 45d913fa6..d25271d6e 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -1093,7 +1093,7 @@ class ZImage(Lumina2):
 
     def __init__(self, unet_config):
         super().__init__(unet_config)
-        if comfy.model_management.extended_fp16_support():
+        if comfy.model_management.extended_fp16_support() and unet_config.get("allow_fp16", False):
             self.supported_inference_dtypes = self.supported_inference_dtypes.copy()
             self.supported_inference_dtypes.insert(1, torch.float16)
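
For reference, a minimal standalone sketch of the heuristic this diff adds: model detection now sets "allow_fp16" by checking the population standard deviation of one late-layer ffn_norm1 weight against a 0.42 threshold, and ZImage only offers fp16 inference when that flag came back True. The helper name zimage_allows_fp16 and the synthetic state dict below are illustrative, not part of the ComfyUI codebase; the key pattern, the unbiased=False std, and the 0.42 threshold are taken verbatim from the diff above.

import torch

def zimage_allows_fp16(state_dict, key_prefix, n_layers, threshold=0.42):
    # Hypothetical helper, not a ComfyUI API: mirrors the detection logic
    # above. It reads the ffn_norm1 weight of the second-to-last layer and
    # compares its population (unbiased=False) standard deviation against
    # the threshold; checkpoints below it are treated as fp16-safe.
    key = '{}layers.{}.ffn_norm1.weight'.format(key_prefix, n_layers - 2)
    try:
        return torch.std(state_dict[key], unbiased=False).item() < threshold
    except Exception:
        # As in the diff, any failure leaves fp16 disabled: ZImage checks
        # unet_config.get("allow_fp16", False), so an unset flag means no
        # torch.float16 in supported_inference_dtypes.
        return False

# Illustrative usage with a synthetic 36-layer checkpoint:
sd = {'model.layers.34.ffn_norm1.weight': torch.full((2304,), 0.5)}
print(zimage_allows_fp16(sd, 'model.', 36))  # True: a constant tensor has std 0.0

Note that unbiased=False matters here: the check is a fixed statistical fingerprint of the shipped weights, so it uses the population std rather than the sample estimator, and the try/except keeps detection from failing on checkpoints that lack the expected key.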