mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-30 03:01:15 +00:00
support all flux models
This commit is contained in:
@@ -10,7 +10,7 @@ from diffusers import DiffusionPipeline
|
||||
from transformers import modeling_utils
|
||||
|
||||
from backend import memory_management
|
||||
from backend.utils import read_arbitrary_config
|
||||
from backend.utils import read_arbitrary_config, load_torch_file
|
||||
from backend.state_dict import try_filter_state_dict, load_state_dict
|
||||
from backend.operations import using_forge_operations
|
||||
from backend.nn.vae import IntegratedAutoencoderKL
|
||||
@@ -46,6 +46,8 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
comp._eventual_warn_about_too_long_sequence = lambda *args, **kwargs: None
|
||||
return comp
|
||||
if cls_name in ['AutoencoderKL']:
|
||||
assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have VAE state dict!'
|
||||
|
||||
config = IntegratedAutoencoderKL.load_config(config_path)
|
||||
|
||||
with using_forge_operations(device=memory_management.cpu, dtype=memory_management.vae_dtype()):
|
||||
@@ -54,6 +56,8 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
load_state_dict(model, state_dict, ignore_start='loss.')
|
||||
return model
|
||||
if component_name.startswith('text_encoder') and cls_name in ['CLIPTextModel', 'CLIPTextModelWithProjection']:
|
||||
assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have CLIP state dict!'
|
||||
|
||||
from transformers import CLIPTextConfig, CLIPTextModel
|
||||
config = CLIPTextConfig.from_pretrained(config_path)
|
||||
|
||||
@@ -71,6 +75,8 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
|
||||
return model
|
||||
if cls_name == 'T5EncoderModel':
|
||||
assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have T5 state dict!'
|
||||
|
||||
from backend.nn.t5 import IntegratedT5
|
||||
config = read_arbitrary_config(config_path)
|
||||
|
||||
@@ -78,17 +84,21 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
sd_dtype = memory_management.state_dict_dtype(state_dict)
|
||||
|
||||
if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
||||
print(f'Using Detected T5 Data Type: {sd_dtype}')
|
||||
dtype = sd_dtype
|
||||
print(f'Using Detected T5 Data Type: {dtype}')
|
||||
else:
|
||||
print(f'Using Default T5 Data Type: {dtype}')
|
||||
|
||||
with modeling_utils.no_init_weights():
|
||||
with using_forge_operations(device=memory_management.cpu, dtype=dtype, manual_cast_enabled=True):
|
||||
model = IntegratedT5(config)
|
||||
|
||||
load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight'])
|
||||
load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight', 'logit_scale'])
|
||||
|
||||
return model
|
||||
if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel']:
|
||||
assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have model state dict!'
|
||||
|
||||
model_loader = None
|
||||
if cls_name == 'UNet2DConditionModel':
|
||||
model_loader = lambda c: IntegratedUNet2DConditionModel.from_config(c)
|
||||
@@ -148,16 +158,57 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
return None
|
||||
|
||||
|
||||
def split_state_dict(sd, sd_vae=None):
|
||||
guess = huggingface_guess.guess(sd)
|
||||
guess.clip_target = guess.clip_target(sd)
|
||||
def _swap_component(sd, asd, stale_prefix, new_prefix):
    """Drop every key of *sd* starting with *stale_prefix*, then copy *asd* in.

    Each entry of *asd* is stored in *sd* under ``new_prefix + key``.
    *sd* is mutated in place.
    """
    # Materialize the key list first: deleting while iterating a dict raises.
    for key in [k for k in sd if k.startswith(stale_prefix)]:
        del sd[key]

    for key, value in asd.items():
        sd[new_prefix + key] = value


def replace_state_dict(sd, asd, guess):
    """Overwrite one component of a combined state dict with an external one.

    The kind of component held by *asd* is detected by the presence of a
    marker key, and the matching entries of *sd* are removed before the
    entries of *asd* are inserted under the corresponding prefix:

    * ``decoder.conv_in.weight`` -> entries under the VAE prefix
    * ``text_model.encoder.layers.0.layer_norm1.weight`` -> ``clip_l.`` entries
      under the text-encoder prefix (inserted below ``clip_l.transformer.``)
    * ``encoder.block.0.layer.0.SelfAttention.k.weight`` -> ``t5xxl.`` entries
      under the text-encoder prefix (inserted below ``t5xxl.transformer.``)

    Markers are checked independently and in the order above, so an *asd*
    carrying several markers is applied once per matching marker. If no
    marker is present, *sd* is returned untouched.

    Args:
        sd: Combined state dict (mapping of string keys); mutated in place.
        asd: Additional state dict whose entries replace part of *sd*.
        guess: Object exposing ``vae_key_prefix`` and
            ``text_encoder_key_prefix`` sequences; only element ``[0]`` of
            each is used.

    Returns:
        The same *sd* object that was passed in.
    """
    vae_key_prefix = guess.vae_key_prefix[0]
    text_encoder_key_prefix = guess.text_encoder_key_prefix[0]

    # (marker key looked up in asd, prefix purged from sd, prefix for new keys)
    rules = (
        (
            "decoder.conv_in.weight",
            vae_key_prefix,
            vae_key_prefix,
        ),
        (
            'text_model.encoder.layers.0.layer_norm1.weight',
            f"{text_encoder_key_prefix}clip_l.",
            f"{text_encoder_key_prefix}clip_l.transformer.",
        ),
        (
            'encoder.block.0.layer.0.SelfAttention.k.weight',
            f"{text_encoder_key_prefix}t5xxl.",
            f"{text_encoder_key_prefix}t5xxl.transformer.",
        ),
    )

    for marker, stale_prefix, new_prefix in rules:
        if marker in asd:
            _swap_component(sd, asd, stale_prefix, new_prefix)

    return sd
|
||||
|
||||
|
||||
def preprocess_state_dict(sd):
    """Prefix a bare state dict with ``model.diffusion_model.`` when needed.

    A new dict with every key prefixed is returned only when both hold:

    * at least one key contains the substring ``double_block``, and
    * no key already starts with ``model.diffusion_model``.

    In every other case the input object itself is returned unchanged.
    """
    has_double_block = False
    for name in sd:
        if "double_block" in name:
            has_double_block = True
            break
    if not has_double_block:
        return sd

    # Already wrapped somewhere: leave the whole dict as it is.
    for name in sd:
        if name.startswith("model.diffusion_model"):
            return sd

    wrapped = {}
    for name, value in sd.items():
        wrapped[f"model.diffusion_model.{name}"] = value
    return wrapped
|
||||
|
||||
|
||||
def split_state_dict(sd, additional_state_dicts: list = None):
|
||||
sd = load_torch_file(sd)
|
||||
sd = preprocess_state_dict(sd)
|
||||
guess = huggingface_guess.guess(sd)
|
||||
|
||||
if isinstance(additional_state_dicts, list):
|
||||
for asd in additional_state_dicts:
|
||||
asd = load_torch_file(asd)
|
||||
sd = replace_state_dict(sd, asd, guess)
|
||||
|
||||
guess.clip_target = guess.clip_target(sd)
|
||||
|
||||
state_dict = {
|
||||
guess.unet_target: try_filter_state_dict(sd, guess.unet_key_prefix),
|
||||
guess.vae_target: try_filter_state_dict(sd, guess.vae_key_prefix) if sd_vae is None else sd_vae
|
||||
guess.vae_target: try_filter_state_dict(sd, guess.vae_key_prefix)
|
||||
}
|
||||
|
||||
sd = guess.process_clip_state_dict(sd)
|
||||
@@ -176,9 +227,9 @@ def split_state_dict(sd, sd_vae=None):
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def forge_loader(sd, sd_vae=None):
|
||||
def forge_loader(sd, additional_state_dicts=None):
|
||||
try:
|
||||
state_dicts, estimated_config = split_state_dict(sd, sd_vae=sd_vae)
|
||||
state_dicts, estimated_config = split_state_dict(sd, additional_state_dicts=additional_state_dicts)
|
||||
except:
|
||||
raise ValueError('Failed to recognize model type!')
|
||||
|
||||
|
||||
@@ -384,6 +384,7 @@ class LoadedModel:
|
||||
if not do_not_need_cpu_swap:
|
||||
memory_in_swap = 0
|
||||
mem_counter = 0
|
||||
mem_cannot_cast = 0
|
||||
for m in self.real_model.modules():
|
||||
if hasattr(m, "parameters_manual_cast"):
|
||||
m.prev_parameters_manual_cast = m.parameters_manual_cast
|
||||
@@ -399,8 +400,12 @@ class LoadedModel:
|
||||
m._apply(lambda x: x.pin_memory())
|
||||
elif hasattr(m, "weight"):
|
||||
m.to(self.device)
|
||||
mem_counter += module_size(m)
|
||||
print(f"[Memory Management] Swap disabled for", type(m).__name__)
|
||||
module_mem = module_size(m)
|
||||
mem_counter += module_mem
|
||||
mem_cannot_cast += module_mem
|
||||
|
||||
if mem_cannot_cast > 0:
|
||||
print(f"[Memory Management] Loaded to GPU for backward capability: {mem_cannot_cast / (1024 * 1024):.2f} MB")
|
||||
|
||||
swap_flag = 'Shared' if PIN_SHARED_MEMORY else 'CPU'
|
||||
method_flag = 'asynchronous' if stream.should_use_stream() else 'blocked'
|
||||
|
||||
@@ -3,6 +3,10 @@ from backend.patcher.base import ModelPatcher
|
||||
from backend.nn.base import ModuleDict, ObjectDict
|
||||
|
||||
|
||||
class JointTextEncoder(ModuleDict):
    """A ``ModuleDict`` subclass that adds no attributes or methods of its own.

    NOTE(review): presumably exists so the text-encoder container has a
    distinct type name -- confirm against the code that inspects it.
    """
|
||||
|
||||
|
||||
class CLIP:
|
||||
def __init__(self, model_dict={}, tokenizer_dict={}, no_init=False):
|
||||
if no_init:
|
||||
@@ -11,7 +15,7 @@ class CLIP:
|
||||
load_device = memory_management.text_encoder_device()
|
||||
offload_device = memory_management.text_encoder_offload_device()
|
||||
|
||||
self.cond_stage_model = ModuleDict(model_dict)
|
||||
self.cond_stage_model = JointTextEncoder(model_dict)
|
||||
self.tokenizer = ObjectDict(tokenizer_dict)
|
||||
self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user