mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-29 02:41:27 +00:00
Fix
This commit is contained in:
@@ -19,6 +19,12 @@ def count_blocks(state_dict_keys, prefix_string):
|
|||||||
count += 1
|
count += 1
|
||||||
return count
|
return count
|
||||||
|
|
||||||
|
def any_suffix_in(keys, prefix, main, suffix_list=()):
    """Return True if ``prefix + main + suffix`` is in ``keys`` for any suffix.

    Each candidate key is built by concatenating ``prefix``, ``main`` and one
    entry of ``suffix_list``; the candidates are tested in order and the
    search stops at the first hit.

    Args:
        keys: Container tested with ``in`` (any object supporting membership
            tests on the built key strings).
        prefix: Leading part of the key.
        main: Middle part of the key, placed directly after ``prefix``.
        suffix_list: Iterable of trailing parts to try. Defaults to an empty
            tuple, in which case nothing is tried and the result is False.

    Returns:
        True if at least one built key is contained in ``keys``, else False.
    """
    # The default is an immutable tuple rather than a list so that no mutable
    # object is shared between calls.
    return any(
        "{}{}{}".format(prefix, main, suffix) in keys
        for suffix in suffix_list
    )
|
||||||
|
|
||||||
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
|
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
|
||||||
context_dim = None
|
context_dim = None
|
||||||
use_linear_in_transformer = False
|
use_linear_in_transformer = False
|
||||||
@@ -186,7 +192,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["meanflow_sum"] = False
|
dit_config["meanflow_sum"] = False
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
|
if any_suffix_in(state_dict_keys, key_prefix, 'double_blocks.0.img_attn.norm.key_norm.', ["weight", "scale"]) and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"])): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
|
||||||
dit_config = {}
|
dit_config = {}
|
||||||
if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
|
if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
|
||||||
dit_config["image_model"] = "flux2"
|
dit_config["image_model"] = "flux2"
|
||||||
@@ -241,7 +247,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
|
|
||||||
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
||||||
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
|
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
|
||||||
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
|
|
||||||
|
if any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.0.norms.0.', ["weight", "scale"]) or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"]): #Chroma
|
||||||
dit_config["image_model"] = "chroma"
|
dit_config["image_model"] = "chroma"
|
||||||
dit_config["in_channels"] = 64
|
dit_config["in_channels"] = 64
|
||||||
dit_config["out_channels"] = 64
|
dit_config["out_channels"] = 64
|
||||||
@@ -249,7 +256,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["out_dim"] = 3072
|
dit_config["out_dim"] = 3072
|
||||||
dit_config["hidden_dim"] = 5120
|
dit_config["hidden_dim"] = 5120
|
||||||
dit_config["n_layers"] = 5
|
dit_config["n_layers"] = 5
|
||||||
if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance
|
|
||||||
|
if any_suffix_in(state_dict_keys, key_prefix, 'nerf_blocks.0.norm.', ["weight", "scale"]): #Chroma Radiance
|
||||||
dit_config["image_model"] = "chroma_radiance"
|
dit_config["image_model"] = "chroma_radiance"
|
||||||
dit_config["in_channels"] = 3
|
dit_config["in_channels"] = 3
|
||||||
dit_config["out_channels"] = 3
|
dit_config["out_channels"] = 3
|
||||||
@@ -259,7 +267,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["nerf_depth"] = 4
|
dit_config["nerf_depth"] = 4
|
||||||
dit_config["nerf_max_freqs"] = 8
|
dit_config["nerf_max_freqs"] = 8
|
||||||
dit_config["nerf_tile_size"] = 512
|
dit_config["nerf_tile_size"] = 512
|
||||||
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
|
dit_config["nerf_final_head_type"] = "conv" if any_suffix_in(state_dict_keys, key_prefix, 'nerf_final_layer_conv.norm.', ["weight", "scale"]) else "linear"
|
||||||
dit_config["nerf_embedder_dtype"] = torch.float32
|
dit_config["nerf_embedder_dtype"] = torch.float32
|
||||||
if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
|
if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
|
||||||
dit_config["use_x0"] = True
|
dit_config["use_x0"] = True
|
||||||
@@ -268,7 +276,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
else:
|
else:
|
||||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||||
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
|
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
|
||||||
dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
|
dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"])
|
||||||
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
|
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
|
||||||
dit_config["txt_ids_dims"] = [1, 2]
|
dit_config["txt_ids_dims"] = [1, 2]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user