support all flux models

lllyasviel
2024-08-13 05:42:17 -07:00
committed by GitHub
parent 3589b57ec1
commit 61f83dd610
8 changed files with 177 additions and 61 deletions

View File

@@ -10,7 +10,7 @@ from diffusers import DiffusionPipeline
 from transformers import modeling_utils
 from backend import memory_management
-from backend.utils import read_arbitrary_config
+from backend.utils import read_arbitrary_config, load_torch_file
 from backend.state_dict import try_filter_state_dict, load_state_dict
 from backend.operations import using_forge_operations
 from backend.nn.vae import IntegratedAutoencoderKL
@@ -46,6 +46,8 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
         comp._eventual_warn_about_too_long_sequence = lambda *args, **kwargs: None
         return comp
     if cls_name in ['AutoencoderKL']:
+        assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have VAE state dict!'
+
         config = IntegratedAutoencoderKL.load_config(config_path)
         with using_forge_operations(device=memory_management.cpu, dtype=memory_management.vae_dtype()):
@@ -54,6 +56,8 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
         load_state_dict(model, state_dict, ignore_start='loss.')
         return model
     if component_name.startswith('text_encoder') and cls_name in ['CLIPTextModel', 'CLIPTextModelWithProjection']:
+        assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have CLIP state dict!'
+
         from transformers import CLIPTextConfig, CLIPTextModel
         config = CLIPTextConfig.from_pretrained(config_path)
@@ -71,6 +75,8 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
         return model
     if cls_name == 'T5EncoderModel':
+        assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have T5 state dict!'
+
         from backend.nn.t5 import IntegratedT5
         config = read_arbitrary_config(config_path)
@@ -78,17 +84,21 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
         sd_dtype = memory_management.state_dict_dtype(state_dict)
         if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
-            print(f'Using Detected T5 Data Type: {sd_dtype}')
             dtype = sd_dtype
+            print(f'Using Detected T5 Data Type: {dtype}')
+        else:
+            print(f'Using Default T5 Data Type: {dtype}')

         with modeling_utils.no_init_weights():
             with using_forge_operations(device=memory_management.cpu, dtype=dtype, manual_cast_enabled=True):
                 model = IntegratedT5(config)

-        load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight'])
+        load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight', 'logit_scale'])
         return model

     if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel']:
+        assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have model state dict!'
+
         model_loader = None
         if cls_name == 'UNet2DConditionModel':
             model_loader = lambda c: IntegratedUNet2DConditionModel.from_config(c)
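Note: the fp8 branch above keeps a float8-stored T5 checkpoint in its on-disk dtype instead of upcasting it. A minimal sketch of what a probe like memory_management.state_dict_dtype could look like (an assumption for illustration; the real Forge helper may differ):

import torch

def state_dict_dtype(state_dict):
    # Sketch: if any tensor is stored as float8, report that dtype so the
    # loader keeps the checkpoint's quantized format; otherwise fall back
    # to the dtype of the first tensor in the dict.
    for tensor in state_dict.values():
        if tensor.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
            return tensor.dtype
    return next(iter(state_dict.values())).dtype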
@@ -148,16 +158,57 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
     return None


-def split_state_dict(sd, sd_vae=None):
-    guess = huggingface_guess.guess(sd)
-    guess.clip_target = guess.clip_target(sd)
+def replace_state_dict(sd, asd, guess):
+    vae_key_prefix = guess.vae_key_prefix[0]
+    text_encoder_key_prefix = guess.text_encoder_key_prefix[0]

-    if sd_vae is not None:
-        print(f'Using external VAE state dict: {len(sd_vae)}')
+    if "decoder.conv_in.weight" in asd:
+        keys_to_delete = [k for k in sd if k.startswith(vae_key_prefix)]
+        for k in keys_to_delete:
+            del sd[k]
+        for k, v in asd.items():
+            sd[vae_key_prefix + k] = v
+
+    if 'text_model.encoder.layers.0.layer_norm1.weight' in asd:
+        keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}clip_l.")]
+        for k in keys_to_delete:
+            del sd[k]
+        for k, v in asd.items():
+            sd[f"{text_encoder_key_prefix}clip_l.transformer.{k}"] = v
+
+    if 'encoder.block.0.layer.0.SelfAttention.k.weight' in asd:
+        keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}t5xxl.")]
+        for k in keys_to_delete:
+            del sd[k]
+        for k, v in asd.items():
+            sd[f"{text_encoder_key_prefix}t5xxl.transformer.{k}"] = v
+
+    return sd
+
+
+def preprocess_state_dict(sd):
+    if any("double_block" in k for k in sd.keys()):
+        if not any(k.startswith("model.diffusion_model") for k in sd.keys()):
+            sd = {f"model.diffusion_model.{k}": v for k, v in sd.items()}
+    return sd
+
+
+def split_state_dict(sd, additional_state_dicts: list = None):
+    sd = load_torch_file(sd)
+    sd = preprocess_state_dict(sd)
+    guess = huggingface_guess.guess(sd)
+
+    if isinstance(additional_state_dicts, list):
+        for asd in additional_state_dicts:
+            asd = load_torch_file(asd)
+            sd = replace_state_dict(sd, asd, guess)
+
+    guess.clip_target = guess.clip_target(sd)

     state_dict = {
         guess.unet_target: try_filter_state_dict(sd, guess.unet_key_prefix),
-        guess.vae_target: try_filter_state_dict(sd, guess.vae_key_prefix) if sd_vae is None else sd_vae
+        guess.vae_target: try_filter_state_dict(sd, guess.vae_key_prefix)
     }

     sd = guess.process_clip_state_dict(sd)
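Each branch of replace_state_dict identifies a component by a fingerprint key: decoder.conv_in.weight marks a VAE, text_model.encoder.layers.0.layer_norm1.weight marks CLIP-L, and encoder.block.0.layer.0.SelfAttention.k.weight marks T5-XXL. A toy illustration of the prefix rewrite (made-up one-entry dicts; the prefix value is an assumption for Flux-style checkpoints):

# Toy dictionaries; real checkpoints map these keys to tensors.
sd = {'text_encoders.t5xxl.transformer.shared.weight': 0}
asd = {'encoder.block.0.layer.0.SelfAttention.k.weight': 1}

prefix = 'text_encoders.'  # assumed guess.text_encoder_key_prefix[0] for Flux-style checkpoints
keys_to_delete = [k for k in sd if k.startswith(f'{prefix}t5xxl.')]
for k in keys_to_delete:
    del sd[k]
for k, v in asd.items():
    sd[f'{prefix}t5xxl.transformer.{k}'] = v

print(sd)
# {'text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight': 1}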
@@ -176,9 +227,9 @@ def split_state_dict(sd, sd_vae=None):
 @torch.no_grad()
-def forge_loader(sd, sd_vae=None):
+def forge_loader(sd, additional_state_dicts=None):
     try:
-        state_dicts, estimated_config = split_state_dict(sd, sd_vae=sd_vae)
+        state_dicts, estimated_config = split_state_dict(sd, additional_state_dicts=additional_state_dicts)
     except:
         raise ValueError('Failed to recognize model type!')
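With the new signature, callers hand forge_loader a base checkpoint plus any separately downloaded component files, and the fingerprint matching routes each one. A usage sketch with hypothetical file names, assuming load_torch_file accepts paths to safetensors files:

# Hypothetical paths for illustration only.
model = forge_loader(
    'models/flux1-dev.safetensors',
    additional_state_dicts=[
        'models/ae.safetensors',        # VAE, matched via decoder.conv_in.weight
        'models/clip_l.safetensors',    # CLIP-L, matched via text_model.encoder...
        'models/t5xxl_fp8.safetensors', # T5-XXL, matched via encoder.block.0...
    ],
)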

View File

@@ -384,6 +384,7 @@ class LoadedModel:
         if not do_not_need_cpu_swap:
             memory_in_swap = 0
             mem_counter = 0
+            mem_cannot_cast = 0
             for m in self.real_model.modules():
                 if hasattr(m, "parameters_manual_cast"):
                     m.prev_parameters_manual_cast = m.parameters_manual_cast
@@ -399,8 +400,12 @@ class LoadedModel:
                         m._apply(lambda x: x.pin_memory())
                 elif hasattr(m, "weight"):
                     m.to(self.device)
-                    mem_counter += module_size(m)
-                    print(f"[Memory Management] Swap disabled for", type(m).__name__)
+                    module_mem = module_size(m)
+                    mem_counter += module_mem
+                    mem_cannot_cast += module_mem
+
+            if mem_cannot_cast > 0:
+                print(f"[Memory Management] Loaded to GPU for backward capability: {mem_cannot_cast / (1024 * 1024):.2f} MB")

             swap_flag = 'Shared' if PIN_SHARED_MEMORY else 'CPU'
             method_flag = 'asynchronous' if stream.should_use_stream() else 'blocked'
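module_size is used here to total the memory of modules that cannot be manually cast and must therefore stay on the GPU. A plausible sketch of such a helper (an assumption; Forge's own implementation may count differently):

import torch.nn as nn

def module_size(module: nn.Module) -> int:
    # Sum the storage footprint of all parameters and buffers, in bytes.
    return sum(t.numel() * t.element_size()
               for t in list(module.parameters()) + list(module.buffers()))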

View File

@@ -3,6 +3,10 @@ from backend.patcher.base import ModelPatcher
 from backend.nn.base import ModuleDict, ObjectDict


+class JointTextEncoder(ModuleDict):
+    pass
+
+
 class CLIP:
     def __init__(self, model_dict={}, tokenizer_dict={}, no_init=False):
         if no_init:
@@ -11,7 +15,7 @@ class CLIP:
         load_device = memory_management.text_encoder_device()
         offload_device = memory_management.text_encoder_offload_device()

-        self.cond_stage_model = ModuleDict(model_dict)
+        self.cond_stage_model = JointTextEncoder(model_dict)
         self.tokenizer = ObjectDict(tokenizer_dict)

         self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
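JointTextEncoder adds no behavior over ModuleDict; the subclass presumably exists so downstream code can recognize a multi-encoder container (for example clip_l plus t5xxl for Flux) by type rather than by inspecting keys. A hypothetical usage sketch, assuming the surrounding Forge classes are importable:

import torch.nn as nn

# Stand-ins for the real encoders; only the container shape matters here.
clip = CLIP(model_dict={'clip_l': nn.Identity(), 't5xxl': nn.Identity()},
            tokenizer_dict={})

# Downstream code can now branch on the container type alone:
assert isinstance(clip.cond_stage_model, JointTextEncoder)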