Maintain patching-related code

1. Fix several problems related to layerdiffuse not being unloaded
2. Fix several problems related to Fooocus inpaint
3. Slightly speed up on-the-fly LoRAs by precomputing them in the computation dtype (see the sketch after this list)
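
As a rough illustration of item 3 (a minimal sketch; precompute_patches and the dict-of-tensors layout are assumptions for illustration, not Forge's actual API): casting every LoRA patch tensor to the computation dtype once at load time removes a per-step cast when the patches are applied on the fly.

    import torch

    def precompute_patches(patches, computation_dtype=torch.float16):
        # Cast each patch tensor up front so on-the-fly application can
        # add it to the weights without converting dtype on every
        # forward pass.
        return {k: v.to(dtype=computation_dtype) for k, v in patches.items()}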
layerdiffusion
2024-08-30 15:14:32 -07:00
parent a8483a3f79
commit d1d0ec46aa
11 changed files with 119 additions and 101 deletions

@@ -55,13 +55,14 @@ class FooocusInpaintPatcher(ControlModelPatcher):
     def try_build_from_state_dict(state_dict, ckpt_path):
         if 'diffusion_model.time_embed.0.weight' in state_dict:
             if len(state_dict['diffusion_model.time_embed.0.weight']) == 3:
-                return FooocusInpaintPatcher(state_dict)
+                return FooocusInpaintPatcher(state_dict, ckpt_path)
         return None

-    def __init__(self, state_dict):
+    def __init__(self, state_dict, filename):
         super().__init__()
         self.state_dict = state_dict
+        self.filename = filename
         self.inpaint_head = InpaintHead().to(device=torch.device('cpu'), dtype=torch.float32)
         self.inpaint_head.load_state_dict(load_torch_file(os.path.join(os.path.dirname(__file__), 'fooocus_inpaint_head')))
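
For context, the constructor loads the inpaint head with load_torch_file; that helper is assumed here to be a thin wrapper that reads either a safetensors file or a pickled torch checkpoint into a state dict. A minimal sketch under that assumption:

    import torch

    def load_torch_file(path, device='cpu'):
        if path.endswith('.safetensors'):
            import safetensors.torch
            return safetensors.torch.load_file(path, device=device)
        # weights_only avoids executing arbitrary pickled code.
        return torch.load(path, map_location=device, weights_only=True)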
@@ -95,8 +96,7 @@ class FooocusInpaintPatcher(ControlModelPatcher):
         lora_keys.update({x: x for x in unet.model.state_dict().keys()})
         loaded_lora = load_fooocus_patch(self.state_dict, lora_keys)
-        unet.lora_loader.clear_patches()  # TODO
-        patched = unet.lora_loader.add_patches(loaded_lora, 1.0)
+        patched = unet.add_patches(filename=self.filename, patches=loaded_lora)
         not_patched_count = sum(1 for x in loaded_lora if x not in patched)
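
The new unet.add_patches(filename=..., patches=...) call replaces the lora_loader round trip. A hedged sketch of the idea (PatchRegistry is illustrative, not Forge's real class): keying patches by their source filename lets patches from one checkpoint be unloaded as a unit, which addresses the "not unloaded" class of bug in item 1, and returning the set of applied keys is what the not_patched_count line above counts against.

    class PatchRegistry:
        def __init__(self, model_keys):
            self.model_keys = set(model_keys)
            self.patches_by_file = {}

        def add_patches(self, filename, patches):
            # Only keys that actually exist in the model can be patched;
            # the caller counts everything else as not patched.
            accepted = {k: v for k, v in patches.items() if k in self.model_keys}
            self.patches_by_file.setdefault(filename, {}).update(accepted)
            return set(accepted)

        def remove_patches(self, filename):
            # Dropping every patch registered under one file is what
            # makes clean unloading possible.
            self.patches_by_file.pop(filename, None)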