mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-26 11:09:50 +00:00
Only enable fp16 on z image models that actually support it. (#12065)
This commit is contained in:
@@ -451,6 +451,7 @@ class NextDiT(nn.Module):
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
|
||||
@@ -444,6 +444,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["ffn_dim_multiplier"] = (8.0 / 3.0)
|
||||
dit_config["z_image_modulation"] = True
|
||||
dit_config["time_scale"] = 1000.0
|
||||
try:
|
||||
dit_config["allow_fp16"] = torch.std(state_dict['{}layers.{}.ffn_norm1.weight'.format(key_prefix, dit_config["n_layers"] - 2)], unbiased=False).item() < 0.42
|
||||
except Exception:
|
||||
pass
|
||||
if '{}cap_pad_token'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["pad_tokens_multiple"] = 32
|
||||
sig_weight = state_dict.get('{}siglip_embedder.0.weight'.format(key_prefix), None)
|
||||
|
||||
@@ -1093,7 +1093,7 @@ class ZImage(Lumina2):
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
if comfy.model_management.extended_fp16_support():
|
||||
if comfy.model_management.extended_fp16_support() and unet_config.get("allow_fp16", False):
|
||||
self.supported_inference_dtypes = self.supported_inference_dtypes.copy()
|
||||
self.supported_inference_dtypes.insert(1, torch.float16)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user