mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-13 17:20:01 +00:00
[Weight-adapter/Trainer] Bypass forward mode in Weight adapter system (#11958)
* Add API of bypass forward module * bypass implementation * add bypass fwd into nodes list/trainer
This commit is contained in:
@@ -18,6 +18,7 @@ import comfy_extras.nodes_custom_sampler
|
||||
import folder_paths
|
||||
import node_helpers
|
||||
from comfy.weight_adapter import adapters, adapter_maps
|
||||
from comfy.weight_adapter.bypass import BypassInjectionManager
|
||||
from comfy_api.latest import ComfyExtension, io, ui
|
||||
from comfy.utils import ProgressBar
|
||||
|
||||
@@ -339,6 +340,11 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
|
||||
|
||||
if (i + 1) % self.grad_acc == 0:
|
||||
for param_groups in self.optimizer.param_groups:
|
||||
for param in param_groups["params"]:
|
||||
if param.grad is None:
|
||||
continue
|
||||
param.grad.data = param.grad.data.to(param.data.dtype)
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
ui_pbar.update(1)
|
||||
@@ -498,9 +504,9 @@ def _prepare_latents_and_count(latents, dtype, bucket_mode):
|
||||
num_images = sum(t.shape[0] for t in latents)
|
||||
multi_res = False # Not using multi_res path in bucket mode
|
||||
|
||||
logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
|
||||
logging.debug(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
|
||||
for i, lat in enumerate(latents):
|
||||
logging.info(f" Bucket {i}: shape {lat.shape}")
|
||||
logging.debug(f" Bucket {i}: shape {lat.shape}")
|
||||
return latents, num_images, multi_res
|
||||
|
||||
# Non-bucket mode
|
||||
@@ -509,7 +515,7 @@ def _prepare_latents_and_count(latents, dtype, bucket_mode):
|
||||
latents = [t.to(dtype) for t in latents]
|
||||
for latent in latents:
|
||||
all_shapes.add(latent.shape)
|
||||
logging.info(f"Latent shapes: {all_shapes}")
|
||||
logging.debug(f"Latent shapes: {all_shapes}")
|
||||
if len(all_shapes) > 1:
|
||||
multi_res = True
|
||||
else:
|
||||
@@ -545,7 +551,7 @@ def _validate_and_expand_conditioning(positive, num_images, bucket_mode):
|
||||
if bucket_mode:
|
||||
return positive # Skip validation in bucket mode
|
||||
|
||||
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
|
||||
logging.debug(f"Total Images: {num_images}, Total Captions: {len(positive)}")
|
||||
if len(positive) == 1 and num_images > 1:
|
||||
return positive * num_images
|
||||
elif len(positive) != num_images:
|
||||
@@ -596,6 +602,8 @@ def _create_weight_adapter(
|
||||
shape = module.weight.shape
|
||||
lora_params = {}
|
||||
|
||||
logging.debug(f"Creating weight adapter for {key} with shape {shape}")
|
||||
|
||||
if len(shape) >= 2:
|
||||
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
|
||||
dora_scale = existing_weights.get(f"{key}.dora_scale", None)
|
||||
@@ -690,6 +698,61 @@ def _setup_lora_adapters(mp, existing_weights, algorithm, lora_dtype, rank):
|
||||
return lora_sd, all_weight_adapters
|
||||
|
||||
|
||||
def _setup_lora_adapters_bypass(mp, existing_weights, algorithm, lora_dtype, rank):
|
||||
"""Setup LoRA adapters in bypass mode.
|
||||
|
||||
In bypass mode:
|
||||
- Weight adapters (lora/lokr/oft) use bypass injection (forward hook)
|
||||
- Bias/norm adapters (BiasDiff) still use weight wrapper (direct modification)
|
||||
|
||||
This is useful when the base model weights are quantized and cannot be
|
||||
directly modified.
|
||||
|
||||
Args:
|
||||
mp: Model patcher
|
||||
existing_weights: Dict of existing LoRA weights
|
||||
algorithm: Algorithm name for new adapters
|
||||
lora_dtype: dtype for LoRA weights
|
||||
rank: Rank for new LoRA adapters
|
||||
|
||||
Returns:
|
||||
tuple: (lora_sd dict, all_weight_adapters list, bypass_manager)
|
||||
"""
|
||||
lora_sd = {}
|
||||
all_weight_adapters = []
|
||||
bypass_manager = BypassInjectionManager()
|
||||
|
||||
for n, m in mp.model.named_modules():
|
||||
if hasattr(m, "weight_function"):
|
||||
if m.weight is not None:
|
||||
adapter, params = _create_weight_adapter(
|
||||
m, n, existing_weights, algorithm, lora_dtype, rank
|
||||
)
|
||||
lora_sd.update(params)
|
||||
all_weight_adapters.append(adapter)
|
||||
|
||||
key = f"{n}.weight"
|
||||
# BiasDiff (for 1D weights like norm) uses weight wrapper, not bypass
|
||||
# Only use bypass for adapters that have h() method (lora/lokr/oft)
|
||||
if isinstance(adapter, BiasDiff):
|
||||
mp.add_weight_wrapper(key, adapter)
|
||||
logging.debug(f"[BypassMode] Added 1D weight adapter (weight wrapper) for {key}")
|
||||
else:
|
||||
bypass_manager.add_adapter(key, adapter, strength=1.0)
|
||||
logging.debug(f"[BypassMode] Added weight adapter (bypass) for {key}")
|
||||
|
||||
if hasattr(m, "bias") and m.bias is not None:
|
||||
# Bias adapters still use weight wrapper (bias is usually not quantized)
|
||||
bias_adapter, bias_params = _create_bias_adapter(m, n, lora_dtype)
|
||||
lora_sd.update(bias_params)
|
||||
key = f"{n}.bias"
|
||||
mp.add_weight_wrapper(key, bias_adapter)
|
||||
all_weight_adapters.append(bias_adapter)
|
||||
logging.debug(f"[BypassMode] Added bias adapter (weight wrapper) for {key}")
|
||||
|
||||
return lora_sd, all_weight_adapters, bypass_manager
|
||||
|
||||
|
||||
def _create_optimizer(optimizer_name, parameters, learning_rate):
|
||||
"""Create optimizer based on name.
|
||||
|
||||
@@ -884,11 +947,13 @@ class TrainLoraNode(io.ComfyNode):
|
||||
default=False,
|
||||
tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.",
|
||||
),
|
||||
io.Boolean.Input(
|
||||
"bypass_mode",
|
||||
default=False,
|
||||
tooltip="Enable bypass mode for training. When enabled, adapters are applied via forward hooks instead of weight modification. Useful for quantized models where weights cannot be directly modified.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(
|
||||
display_name="model", tooltip="Model with LoRA applied"
|
||||
),
|
||||
io.Custom("LORA_MODEL").Output(
|
||||
display_name="lora", tooltip="LoRA weights"
|
||||
),
|
||||
@@ -919,6 +984,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
gradient_checkpointing,
|
||||
existing_lora,
|
||||
bucket_mode,
|
||||
bypass_mode,
|
||||
):
|
||||
# Extract scalars from lists (due to is_input_list=True)
|
||||
model = model[0]
|
||||
@@ -936,6 +1002,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
gradient_checkpointing = gradient_checkpointing[0]
|
||||
existing_lora = existing_lora[0]
|
||||
bucket_mode = bucket_mode[0]
|
||||
bypass_mode = bypass_mode[0]
|
||||
|
||||
# Process latents based on mode
|
||||
if bucket_mode:
|
||||
@@ -968,9 +1035,16 @@ class TrainLoraNode(io.ComfyNode):
|
||||
existing_weights, existing_steps = _load_existing_lora(existing_lora)
|
||||
|
||||
# Setup LoRA adapters
|
||||
lora_sd, all_weight_adapters = _setup_lora_adapters(
|
||||
mp, existing_weights, algorithm, lora_dtype, rank
|
||||
)
|
||||
bypass_manager = None
|
||||
if bypass_mode:
|
||||
logging.debug("Using bypass mode for training")
|
||||
lora_sd, all_weight_adapters, bypass_manager = _setup_lora_adapters_bypass(
|
||||
mp, existing_weights, algorithm, lora_dtype, rank
|
||||
)
|
||||
else:
|
||||
lora_sd, all_weight_adapters = _setup_lora_adapters(
|
||||
mp, existing_weights, algorithm, lora_dtype, rank
|
||||
)
|
||||
|
||||
# Create optimizer and loss function
|
||||
optimizer = _create_optimizer(
|
||||
@@ -1029,6 +1103,14 @@ class TrainLoraNode(io.ComfyNode):
|
||||
guider = TrainGuider(mp)
|
||||
guider.set_conds(positive)
|
||||
|
||||
# Inject bypass hooks if bypass mode is enabled
|
||||
bypass_injections = None
|
||||
if bypass_manager is not None:
|
||||
bypass_injections = bypass_manager.create_injections(mp.model)
|
||||
for injection in bypass_injections:
|
||||
injection.inject(mp)
|
||||
logging.debug(f"[BypassMode] Injected {bypass_manager.get_hook_count()} bypass hooks")
|
||||
|
||||
# Run training loop
|
||||
try:
|
||||
_run_training_loop(
|
||||
@@ -1041,6 +1123,11 @@ class TrainLoraNode(io.ComfyNode):
|
||||
multi_res,
|
||||
)
|
||||
finally:
|
||||
# Eject bypass hooks if they were injected
|
||||
if bypass_injections is not None:
|
||||
for injection in bypass_injections:
|
||||
injection.eject(mp)
|
||||
logging.debug("[BypassMode] Ejected bypass hooks")
|
||||
for m in mp.model.modules():
|
||||
unpatch(m)
|
||||
del train_sampler, optimizer
|
||||
@@ -1052,7 +1139,9 @@ class TrainLoraNode(io.ComfyNode):
|
||||
for param in lora_sd:
|
||||
lora_sd[param] = lora_sd[param].to(lora_dtype)
|
||||
|
||||
return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps)
|
||||
# mp in train node is highly specialized for training
|
||||
# use it in inference will result in bad behavior so we don't return it
|
||||
return io.NodeOutput(lora_sd, loss_map, steps + existing_steps)
|
||||
|
||||
|
||||
class LoraModelLoader(io.ComfyNode):#
|
||||
|
||||
Reference in New Issue
Block a user