MPDynamic: force load flux img_in weight (Fixes flux1 canny+depth lora crash) (#12446)

* lora: add weight shape calculations.

This lets the loader know if a lora will change the shape of a weight
so it can take appropriate action.

* MPDynamic: force load flux img_in weight

This weight is a bit special, in that the lora changes its geometry.
This is rather unique, not handled by existing estimate and doesn't
work for either offloading or dynamic_vram.

Fix for dynamic_vram as a special case. Ideally we can fully precalculate
these lora geometry changes at load time, but just get these models
working first.
This commit is contained in:
rattus
2026-02-15 17:30:09 -08:00
committed by GitHub
parent ecd2a19661
commit c0370044cd
4 changed files with 65 additions and 8 deletions

View File

@@ -374,6 +374,31 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
return padded_tensor
def calculate_shape(patches, weight, key, original_weights=None):
current_shape = weight.shape
for p in patches:
v = p[1]
offset = p[3]
# Offsets restore the old shape; lists force a diff without metadata
if offset is not None or isinstance(v, list):
continue
if isinstance(v, weight_adapter.WeightAdapterBase):
adapter_shape = v.calculate_shape(key)
if adapter_shape is not None:
current_shape = adapter_shape
continue
# Standard diff logic with padding
if len(v) == 2:
patch_type, patch_data = v[0], v[1]
if patch_type == "diff" and len(patch_data) > 1 and patch_data[1]['pad_weight']:
current_shape = patch_data[0].shape
return current_shape
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
for p in patches:
strength = p[0]

View File

@@ -1514,8 +1514,10 @@ class ModelPatcherDynamic(ModelPatcher):
weight, _, _ = get_key_weight(self.model, key)
if weight is None:
return 0
return (False, 0)
if key in self.patches:
if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape:
return (True, 0)
setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
num_patches += 1
else:
@@ -1529,7 +1531,13 @@ class ModelPatcherDynamic(ModelPatcher):
model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype
weight._model_dtype = model_dtype
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
return comfy.memory_management.vram_aligned_size(geometry)
return (False, comfy.memory_management.vram_aligned_size(geometry))
def force_load_param(self, param_key, device_to):
key = key_param_name_to_key(n, param_key)
if key in self.backup:
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
self.patch_weight_to_device(key, device_to=device_to)
if hasattr(m, "comfy_cast_weights"):
m.comfy_cast_weights = True
@@ -1537,13 +1545,19 @@ class ModelPatcherDynamic(ModelPatcher):
m.seed_key = n
set_dirty(m, dirty)
v_weight_size = 0
v_weight_size += setup_param(self, m, n, "weight")
v_weight_size += setup_param(self, m, n, "bias")
force_load, v_weight_size = setup_param(self, m, n, "weight")
force_load_bias, v_weight_bias = setup_param(self, m, n, "bias")
force_load = force_load or force_load_bias
v_weight_size += v_weight_bias
if vbar is not None and not hasattr(m, "_v"):
m._v = vbar.alloc(v_weight_size)
allocated_size += v_weight_size
if force_load:
logging.info(f"Module {n} has resizing Lora - force loading")
force_load_param(self, "weight", device_to)
force_load_param(self, "bias", device_to)
else:
if vbar is not None and not hasattr(m, "_v"):
m._v = vbar.alloc(v_weight_size)
allocated_size += v_weight_size
else:
for param in params:
@@ -1606,6 +1620,11 @@ class ModelPatcherDynamic(ModelPatcher):
for m in self.model.modules():
move_weight_functions(m, device_to)
keys = list(self.backup.keys())
for k in keys:
bk = self.backup[k]
comfy.utils.set_attr_param(self.model, k, bk.weight)
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
assert not force_patch_weights #See above
with self.use_ejected(skip_and_inject_on_exit_only=True):

View File

@@ -49,6 +49,12 @@ class WeightAdapterBase:
"""
raise NotImplementedError
def calculate_shape(
self,
key
):
return None
def calculate_weight(
self,
weight,

View File

@@ -214,6 +214,13 @@ class LoRAAdapter(WeightAdapterBase):
else:
return None
def calculate_shape(
self,
key
):
reshape = self.weights[5]
return tuple(reshape) if reshape is not None else None
def calculate_weight(
self,
weight,