mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-26 19:19:53 +00:00
Only enable fp16 on z image models that actually support it. (#12065)
This commit is contained in:
@@ -451,6 +451,7 @@ class NextDiT(nn.Module):
|
|||||||
device=None,
|
device=None,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
operations=None,
|
operations=None,
|
||||||
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
|||||||
@@ -444,6 +444,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["ffn_dim_multiplier"] = (8.0 / 3.0)
|
dit_config["ffn_dim_multiplier"] = (8.0 / 3.0)
|
||||||
dit_config["z_image_modulation"] = True
|
dit_config["z_image_modulation"] = True
|
||||||
dit_config["time_scale"] = 1000.0
|
dit_config["time_scale"] = 1000.0
|
||||||
|
try:
|
||||||
|
dit_config["allow_fp16"] = torch.std(state_dict['{}layers.{}.ffn_norm1.weight'.format(key_prefix, dit_config["n_layers"] - 2)], unbiased=False).item() < 0.42
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
if '{}cap_pad_token'.format(key_prefix) in state_dict_keys:
|
if '{}cap_pad_token'.format(key_prefix) in state_dict_keys:
|
||||||
dit_config["pad_tokens_multiple"] = 32
|
dit_config["pad_tokens_multiple"] = 32
|
||||||
sig_weight = state_dict.get('{}siglip_embedder.0.weight'.format(key_prefix), None)
|
sig_weight = state_dict.get('{}siglip_embedder.0.weight'.format(key_prefix), None)
|
||||||
|
|||||||
@@ -1093,7 +1093,7 @@ class ZImage(Lumina2):
|
|||||||
|
|
||||||
def __init__(self, unet_config):
|
def __init__(self, unet_config):
|
||||||
super().__init__(unet_config)
|
super().__init__(unet_config)
|
||||||
if comfy.model_management.extended_fp16_support():
|
if comfy.model_management.extended_fp16_support() and unet_config.get("allow_fp16", False):
|
||||||
self.supported_inference_dtypes = self.supported_inference_dtypes.copy()
|
self.supported_inference_dtypes = self.supported_inference_dtypes.copy()
|
||||||
self.supported_inference_dtypes.insert(1, torch.float16)
|
self.supported_inference_dtypes.insert(1, torch.float16)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user