diffusion in fp8 landed

This commit is contained in:
lllyasviel
2024-08-06 16:47:39 -07:00
committed by GitHub
parent dd8997ee2e
commit 71c94799d1
7 changed files with 96 additions and 46 deletions

View File

@@ -426,16 +426,45 @@ def get_obj_from_str(string, reload=False):
pass
@torch.no_grad()
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
checkpoint_info = checkpoint_info or select_checkpoint()
pass
def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
pass
def reload_model_weights(sd_model=None, info=None, forced_reload=False):
pass
def unload_model_weights(sd_model=None, info=None):
pass
def apply_token_merging(sd_model, token_merging_ratio):
if token_merging_ratio <= 0:
return
print(f'token_merging_ratio = {token_merging_ratio}')
from backend.misc.tomesd import TomePatcher
sd_model.forge_objects.unet = TomePatcher().patch(
model=sd_model.forge_objects.unet,
ratio=token_merging_ratio
)
return
@torch.no_grad()
def forge_model_reload():
checkpoint_info = select_checkpoint()
timer = Timer()
if model_data.sd_model:
if model_data.sd_model.filename == checkpoint_info.filename:
return model_data.sd_model
model_data.sd_model = None
model_data.loaded_sd_models = []
memory_management.unload_all_models()
@@ -444,10 +473,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
timer.record("unload existing model")
if already_loaded_state_dict is not None:
state_dict = already_loaded_state_dict
else:
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
if shared.opts.sd_checkpoint_cache > 0:
# cache newly loaded model
@@ -489,31 +515,3 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
print(f"Model loaded in {timer.summary()}.")
return sd_model
def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
pass
def reload_model_weights(sd_model=None, info=None, forced_reload=False):
pass
def unload_model_weights(sd_model=None, info=None):
pass
def apply_token_merging(sd_model, token_merging_ratio):
if token_merging_ratio <= 0:
return
print(f'token_merging_ratio = {token_merging_ratio}')
from backend.misc.tomesd import TomePatcher
sd_model.forge_objects.unet = TomePatcher().patch(
model=sd_model.forge_objects.unet,
ratio=token_merging_ratio
)
return