mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-24 00:09:11 +00:00
diffusion in fp8 landed
This commit is contained in:
@@ -2,10 +2,13 @@ import os
|
||||
import torch
|
||||
import logging
|
||||
import importlib
|
||||
|
||||
import huggingface_guess
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
from transformers import modeling_utils
|
||||
from backend import memory_management
|
||||
|
||||
from backend.state_dict import try_filter_state_dict, load_state_dict
|
||||
from backend.operations import using_forge_operations
|
||||
from backend.nn.vae import IntegratedAutoencoderKL
|
||||
@@ -57,9 +60,13 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
|
||||
return model
|
||||
if cls_name == 'UNet2DConditionModel':
|
||||
unet_config = guess.unet_config.copy()
|
||||
state_dict_size = memory_management.state_dict_size(state_dict)
|
||||
unet_config['dtype'] = memory_management.unet_dtype(model_params=state_dict_size)
|
||||
|
||||
with using_forge_operations():
|
||||
model = IntegratedUNet2DConditionModel.from_config(guess.unet_config)
|
||||
model._internal_dict = guess.unet_config
|
||||
model = IntegratedUNet2DConditionModel.from_config(unet_config)
|
||||
model._internal_dict = unet_config
|
||||
|
||||
load_state_dict(model, state_dict)
|
||||
return model
|
||||
|
||||
@@ -8,7 +8,7 @@ import platform
|
||||
|
||||
from enum import Enum
|
||||
from backend import stream
|
||||
from backend.args import args
|
||||
from backend.args import args, dynamic_args
|
||||
|
||||
|
||||
class VRAMState(Enum):
|
||||
@@ -289,9 +289,8 @@ if 'rtx' in torch_device_name.lower():
|
||||
current_loaded_models = []
|
||||
|
||||
|
||||
def module_size(module, exclude_device=None):
|
||||
def state_dict_size(sd, exclude_device=None):
|
||||
module_mem = 0
|
||||
sd = module.state_dict()
|
||||
for k in sd:
|
||||
t = sd[k]
|
||||
|
||||
@@ -303,6 +302,10 @@ def module_size(module, exclude_device=None):
|
||||
return module_mem
|
||||
|
||||
|
||||
def module_size(module, exclude_device=None):
|
||||
return state_dict_size(module.state_dict(), exclude_device=exclude_device)
|
||||
|
||||
|
||||
class LoadedModel:
|
||||
def __init__(self, model, memory_required):
|
||||
self.model = model
|
||||
@@ -563,20 +566,31 @@ def unet_inital_load_device(parameters, dtype):
|
||||
|
||||
|
||||
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
||||
unet_storage_dtype_overwrite = dynamic_args.get('forge_unet_storage_dtype')
|
||||
|
||||
if unet_storage_dtype_overwrite is not None:
|
||||
return unet_storage_dtype_overwrite
|
||||
|
||||
if args.unet_in_bf16:
|
||||
return torch.bfloat16
|
||||
|
||||
if args.unet_in_fp16:
|
||||
return torch.float16
|
||||
|
||||
if args.unet_in_fp8_e4m3fn:
|
||||
return torch.float8_e4m3fn
|
||||
|
||||
if args.unet_in_fp8_e5m2:
|
||||
return torch.float8_e5m2
|
||||
|
||||
if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
|
||||
if torch.float16 in supported_dtypes:
|
||||
return torch.float16
|
||||
|
||||
if should_use_bf16(device, model_params=model_params, manual_cast=True):
|
||||
if torch.bfloat16 in supported_dtypes:
|
||||
return torch.bfloat16
|
||||
|
||||
return torch.float32
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
|
||||
from backend import memory_management, attention
|
||||
from backend import memory_management, attention, operations
|
||||
from backend.modules.k_prediction import k_prediction_from_diffusers_scheduler
|
||||
|
||||
|
||||
@@ -11,6 +11,11 @@ class KModel(torch.nn.Module):
|
||||
self.storage_dtype = storage_dtype
|
||||
self.computation_dtype = computation_dtype
|
||||
|
||||
need_manual_cast = self.storage_dtype != self.computation_dtype
|
||||
operations.shift_manual_cast(model, enabled=need_manual_cast)
|
||||
|
||||
print(f'K-Model Created: {dict(storage_dtype=storage_dtype, computation_dtype=computation_dtype, manual_cast=need_manual_cast)}')
|
||||
|
||||
self.diffusion_model = model
|
||||
self.predictor = k_prediction_from_diffusers_scheduler(diffusers_scheduler)
|
||||
|
||||
|
||||
@@ -171,3 +171,10 @@ def using_forge_operations(parameters_manual_cast=False, operations=None):
|
||||
for op_name in op_names:
|
||||
setattr(torch.nn, op_name, backups[op_name])
|
||||
return
|
||||
|
||||
|
||||
def shift_manual_cast(model, enabled):
|
||||
for m in model.modules():
|
||||
if hasattr(m, 'parameters_manual_cast'):
|
||||
m.parameters_manual_cast = enabled
|
||||
return
|
||||
|
||||
@@ -426,16 +426,45 @@ def get_obj_from_str(string, reload=False):
|
||||
pass
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||
pass
|
||||
|
||||
|
||||
def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
|
||||
pass
|
||||
|
||||
|
||||
def reload_model_weights(sd_model=None, info=None, forced_reload=False):
|
||||
pass
|
||||
|
||||
|
||||
def unload_model_weights(sd_model=None, info=None):
|
||||
pass
|
||||
|
||||
|
||||
def apply_token_merging(sd_model, token_merging_ratio):
|
||||
if token_merging_ratio <= 0:
|
||||
return
|
||||
|
||||
print(f'token_merging_ratio = {token_merging_ratio}')
|
||||
|
||||
from backend.misc.tomesd import TomePatcher
|
||||
|
||||
sd_model.forge_objects.unet = TomePatcher().patch(
|
||||
model=sd_model.forge_objects.unet,
|
||||
ratio=token_merging_ratio
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def forge_model_reload():
|
||||
checkpoint_info = select_checkpoint()
|
||||
|
||||
timer = Timer()
|
||||
|
||||
if model_data.sd_model:
|
||||
if model_data.sd_model.filename == checkpoint_info.filename:
|
||||
return model_data.sd_model
|
||||
|
||||
model_data.sd_model = None
|
||||
model_data.loaded_sd_models = []
|
||||
memory_management.unload_all_models()
|
||||
@@ -444,10 +473,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
|
||||
timer.record("unload existing model")
|
||||
|
||||
if already_loaded_state_dict is not None:
|
||||
state_dict = already_loaded_state_dict
|
||||
else:
|
||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||
|
||||
if shared.opts.sd_checkpoint_cache > 0:
|
||||
# cache newly loaded model
|
||||
@@ -489,31 +515,3 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
print(f"Model loaded in {timer.summary()}.")
|
||||
|
||||
return sd_model
|
||||
|
||||
|
||||
def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
|
||||
pass
|
||||
|
||||
|
||||
def reload_model_weights(sd_model=None, info=None, forced_reload=False):
|
||||
pass
|
||||
|
||||
|
||||
def unload_model_weights(sd_model=None, info=None):
|
||||
pass
|
||||
|
||||
|
||||
def apply_token_merging(sd_model, token_merging_ratio):
|
||||
if token_merging_ratio <= 0:
|
||||
return
|
||||
|
||||
print(f'token_merging_ratio = {token_merging_ratio}')
|
||||
|
||||
from backend.misc.tomesd import TomePatcher
|
||||
|
||||
sd_model.forge_objects.unet = TomePatcher().patch(
|
||||
model=sd_model.forge_objects.unet,
|
||||
ratio=token_merging_ratio
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
@@ -1,21 +1,31 @@
|
||||
import torch
|
||||
import gradio as gr
|
||||
|
||||
from modules import shared_items, shared, ui_common, sd_models
|
||||
from modules import sd_vae as sd_vae_module
|
||||
from modules_forge import main_thread
|
||||
from backend import args as backend_args
|
||||
|
||||
|
||||
ui_checkpoint: gr.Dropdown = None
|
||||
ui_vae: gr.Dropdown = None
|
||||
ui_clip_skip: gr.Slider = None
|
||||
|
||||
forge_unet_storage_dtype_options = {
|
||||
'None': None,
|
||||
'fp8e4m3': torch.float8_e4m3fn,
|
||||
'fp8e5m2': torch.float8_e5m2,
|
||||
}
|
||||
|
||||
def bind_to_opts(comp, k, save=False):
|
||||
|
||||
def bind_to_opts(comp, k, save=False, callback=None):
|
||||
def on_change(v):
|
||||
print(f'Setting Changed: {k} = {v}')
|
||||
shared.opts.set(k, v)
|
||||
if save:
|
||||
shared.opts.save(shared.config_filename)
|
||||
if callback is not None:
|
||||
callback()
|
||||
return
|
||||
|
||||
comp.change(on_change, inputs=[comp], show_progress=False)
|
||||
@@ -35,6 +45,7 @@ def make_checkpoint_manager_ui():
|
||||
ui_checkpoint = gr.Dropdown(
|
||||
value=shared.opts.sd_model_checkpoint,
|
||||
label="Checkpoint",
|
||||
elem_classes=['model_selection'],
|
||||
**sd_model_checkpoint_args()
|
||||
)
|
||||
ui_common.create_refresh_button(ui_checkpoint, shared_items.refresh_checkpoints, sd_model_checkpoint_args, f"forge_refresh_checkpoint")
|
||||
@@ -47,6 +58,9 @@ def make_checkpoint_manager_ui():
|
||||
)
|
||||
ui_common.create_refresh_button(ui_vae, shared_items.refresh_vae_list, sd_vae_args, f"forge_refresh_vae")
|
||||
|
||||
ui_forge_unet_storage_dtype_options = gr.Radio(label="Diffusion in FP8", value=shared.opts.forge_unet_storage_dtype, choices=list(forge_unet_storage_dtype_options.keys()))
|
||||
bind_to_opts(ui_forge_unet_storage_dtype_options, 'forge_unet_storage_dtype', save=True, callback=lambda: main_thread.async_run(model_load_entry))
|
||||
|
||||
ui_clip_skip = gr.Slider(label="Clip skip", value=shared.opts.CLIP_stop_at_last_layers, **{"minimum": 1, "maximum": 12, "step": 1})
|
||||
bind_to_opts(ui_clip_skip, 'CLIP_stop_at_last_layers', save=False)
|
||||
|
||||
@@ -54,7 +68,11 @@ def make_checkpoint_manager_ui():
|
||||
|
||||
|
||||
def model_load_entry():
|
||||
sd_models.load_model()
|
||||
backend_args.dynamic_args.update(dict(
|
||||
forge_unet_storage_dtype=forge_unet_storage_dtype_options[shared.opts.forge_unet_storage_dtype]
|
||||
))
|
||||
|
||||
sd_models.forge_model_reload()
|
||||
return
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user