Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git
Implement many kernels from scratch
@@ -12,7 +12,7 @@ from backend import memory_management
 from backend.state_dict import try_filter_state_dict, load_state_dict
 from backend.operations import using_forge_operations
 from backend.nn.vae import IntegratedAutoencoderKL
-from backend.nn.clip import IntegratedCLIP, CLIPTextConfig
+from backend.nn.clip import IntegratedCLIP
 from backend.nn.unet import IntegratedUNet2DConditionModel
 
 from backend.diffusion_engine.sd15 import StableDiffusion
@@ -40,17 +40,18 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
     if cls_name in ['AutoencoderKL']:
         config = IntegratedAutoencoderKL.load_config(config_path)
 
-        with using_forge_operations():
+        with using_forge_operations(device=memory_management.cpu, dtype=memory_management.vae_dtype()):
             model = IntegratedAutoencoderKL.from_config(config)
 
         load_state_dict(model, state_dict)
         return model
     if component_name.startswith('text_encoder') and cls_name in ['CLIPTextModel', 'CLIPTextModelWithProjection']:
+        from transformers import CLIPTextConfig, CLIPTextModel
         config = CLIPTextConfig.from_pretrained(config_path)
 
         with modeling_utils.no_init_weights():
-            with using_forge_operations():
-                model = IntegratedCLIP(config)
+            with using_forge_operations(device=memory_management.cpu, dtype=memory_management.text_encoder_dtype()):
+                model = IntegratedCLIP(CLIPTextModel, config, add_text_projection=True)
 
         load_state_dict(model, state_dict, ignore_errors=[
             'transformer.text_projection.weight',
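
Note: the recurring change in this hunk is passing an explicit device and dtype into using_forge_operations, so modules are constructed directly in the target precision on CPU instead of being initialized in default float32 and converted afterwards (transformers' modeling_utils.no_init_weights() additionally skips random weight initialization, since a state dict is loaded immediately after). A minimal sketch of the construction-time idea, assuming a patch of parameter registration; this is illustrative, not Forge's actual operations code:

    import contextlib
    import torch

    @contextlib.contextmanager
    def construct_on(device, dtype):
        # Illustrative only: temporarily patch parameter registration so any
        # module built inside the block gets its weights on the target
        # device/dtype instead of the default float32.
        original = torch.nn.Module.register_parameter

        def patched(module, name, param):
            if param is not None:
                param = torch.nn.Parameter(
                    param.data.to(device=device, dtype=dtype),
                    requires_grad=param.requires_grad,
                )
            original(module, name, param)

        torch.nn.Module.register_parameter = patched
        try:
            yield
        finally:
            torch.nn.Module.register_parameter = original

    # Usage mirroring the hunk above: build in half precision on CPU.
    with construct_on(torch.device('cpu'), torch.float16):
        layer = torch.nn.Linear(768, 768)
    assert layer.weight.dtype == torch.float16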
@@ -58,13 +59,30 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
             'logit_scale'
         ], log_name=cls_name)
 
         return model
+    if component_name.startswith('text_encoder') and cls_name in ['T5EncoderModel']:
+        from transformers import T5EncoderModel, T5Config
+        config = T5Config.from_pretrained(config_path)
+
+        dtype = memory_management.text_encoder_dtype()
+        sd_dtype = state_dict['transformer.encoder.block.0.layer.0.SelfAttention.k.weight'].dtype
+
+        if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
+            dtype = sd_dtype
+
+        with modeling_utils.no_init_weights():
+            with using_forge_operations(device=memory_management.cpu, dtype=dtype):
+                model = IntegratedCLIP(T5EncoderModel, config)
+
+        load_state_dict(model, state_dict, log_name=cls_name)
+
+        return model
     if cls_name == 'UNet2DConditionModel':
         unet_config = guess.unet_config.copy()
         state_dict_size = memory_management.state_dict_size(state_dict)
         unet_config['dtype'] = memory_management.unet_dtype(model_params=state_dict_size)
 
-        with using_forge_operations():
+        with using_forge_operations(device=memory_management.cpu, dtype=unet_config['dtype']):
             model = IntegratedUNet2DConditionModel.from_config(unet_config)
             model._internal_dict = unet_config
 
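
Note: the new T5EncoderModel branch keeps float8 checkpoints in their stored precision: it probes one known weight and, if it was serialized as fp8, runs the encoder in that dtype instead of the usual text-encoder dtype. Restated as a standalone helper (the name and fallback argument are illustrative, not from the source):

    import torch

    def pick_text_encoder_dtype(state_dict, probe_key, fallback=torch.float16):
        # Keep the checkpoint's float8 format if the weights were stored in
        # fp8; otherwise fall back to the normally selected encoder dtype.
        sd_dtype = state_dict[probe_key].dtype
        if sd_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
            return sd_dtype
        return fallback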
@@ -77,14 +95,14 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
 
 
 def split_state_dict(sd):
     guess = huggingface_guess.guess(sd)
-    guess.clip_target = guess.clip_target(sd)
 
     state_dict = {
-        'unet': try_filter_state_dict(sd, ['model.diffusion_model.']),
-        'vae': try_filter_state_dict(sd, guess.vae_key_prefix)
+        guess.unet_target: try_filter_state_dict(sd, guess.unet_key_prefix),
+        guess.vae_target: try_filter_state_dict(sd, guess.vae_key_prefix)
     }
 
     sd = guess.process_clip_state_dict(sd)
+    guess.clip_target = guess.clip_target(sd)
 
     for k, v in guess.clip_target.items():
         state_dict[v] = try_filter_state_dict(sd, [k + '.'])
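
Note: split_state_dict now routes UNet and VAE tensors by the prefixes and target names reported by huggingface_guess rather than the hard-coded 'model.diffusion_model.' / 'unet' / 'vae' keys, and it resolves clip_target only after process_clip_state_dict has renamed the text-encoder keys. A rough sketch of the prefix filtering this relies on, assuming try_filter_state_dict strips matching prefixes (the real implementation may differ):

    def filter_by_prefix(sd, prefixes):
        # Move every key starting with one of the prefixes into a new dict,
        # with the prefix stripped, leaving all other keys untouched.
        out = {}
        for key in list(sd.keys()):
            for prefix in prefixes:
                if key.startswith(prefix):
                    out[key[len(prefix):]] = sd.pop(key)
                    break
        return out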