some extra lora support, inc. new glora (#2715)

* support new glora (via ComfyUI)
* support BFL FluxTools loras (mostly via ComfyUI)
* also support using loras (like Hyper, Turbo) with FluxTools models
This commit is contained in:
DenOfEquity
2025-03-04 00:26:43 +00:00
committed by GitHub
parent 4f825bc070
commit 5e1dcd35a8
2 changed files with 91 additions and 28 deletions

View File

@@ -129,10 +129,15 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else:
weight += strength * memory_management.cast_to_device(w1, weight.device, weight.dtype)
elif patch_type == "set":
weight.copy_(v[0])
elif patch_type == "lora":
mat1 = memory_management.cast_to_device(v[0], weight.device, computation_dtype)
mat2 = memory_management.cast_to_device(v[1], weight.device, computation_dtype)
dora_scale = v[4]
if v[2] is not None:
alpha = v[2] / mat2.shape[0]
else:
@@ -142,12 +147,26 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t
mat3 = memory_management.cast_to_device(v[3], weight.device, computation_dtype)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))
try:
lora_diff = lora_diff.reshape(weight.shape)
except:
if weight.shape[1] < lora_diff.shape[1]:
expand_factor = (lora_diff.shape[1] - weight.shape[1])
weight = torch.nn.functional.pad(weight, (0, expand_factor), mode='constant', value=0)
elif weight.shape[1] > lora_diff.shape[1]:
# expand factor should be 1*64 (for FluxTools Canny or Depth), or 5*64 (for FluxTools Fill)
expand_factor = (weight.shape[1] - lora_diff.shape[1])
lora_diff = torch.nn.functional.pad(lora_diff, (0, expand_factor), mode='constant', value=0)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
print("ERROR {} {} {}".format(patch_type, key, e))
raise e
@@ -236,23 +255,45 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t
except Exception as e:
print("ERROR {} {} {}".format(patch_type, key, e))
raise e
elif patch_type == "glora":
if v[4] is not None:
alpha = v[4] / v[0].shape[0]
else:
alpha = 1.0
elif patch_type == "glora":
dora_scale = v[5]
old_glora = False
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
old_glora = True
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
pass
else:
old_glora = False
a1 = memory_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, computation_dtype)
a2 = memory_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, computation_dtype)
b1 = memory_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, computation_dtype)
b2 = memory_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, computation_dtype)
if v[4] is None:
alpha = 1.0
else:
if old_glora:
alpha = v[4] / v[0].shape[0]
else:
alpha = v[4] / v[1].shape[0]
try:
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
if old_glora:
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=computation_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
else:
if weight.dim() > 2:
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=computation_dtype), a1), a2).reshape(weight.shape)
else:
lora_diff = torch.mm(torch.mm(weight.to(dtype=computation_dtype), a1), a2).reshape(weight.shape)
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
@@ -299,7 +340,7 @@ class LoraLoader:
self.loaded_hash = str([])
@torch.inference_mode()
def refresh(self, lora_patches, offload_device=torch.device('cpu'), force_refresh = False):
def refresh(self, lora_patches, offload_device=torch.device('cpu'), force_refresh=False):
hashes = str(list(lora_patches.keys()))
if hashes == self.loaded_hash and not force_refresh: