load non comfy weights

This commit is contained in:
Rattus
2026-02-09 12:40:59 +10:00
parent 6198f7562e
commit 959b3014bb
3 changed files with 13 additions and 9 deletions

View File

@@ -1206,8 +1206,6 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
if dtype is None:
dtype = weight._model_dtype
r = torch.empty_like(weight, dtype=dtype, device=device)
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
if signature is not None:
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
@@ -1216,11 +1214,9 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
weight._v_signature = signature
#Send it over
v_tensor.copy_(weight, non_blocking=non_blocking)
#always take a deep copy even if _v is good, as we have no reasonable point to unpin
#a non comfy weight
r.copy_(v_tensor)
comfy_aimdo.model_vbar.vbar_unpin(weight._v)
return r
return v_tensor.to(dtype=dtype)
r = torch.empty_like(weight, dtype=dtype, device=device)
if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
#Offloaded casting could skip this, however it would make the quantizations

View File

@@ -1492,7 +1492,9 @@ class ModelPatcherDynamic(ModelPatcher):
if vbar is not None:
vbar.prioritize()
#We have way more tools for acceleration on comfy weight offloading, so always
#We force reserve VRAM for the non comfy-weight so we dont have to deal
#with pin and unpin syncrhonization which can be expensive for small weights
#with a high layer rate (e.g. autoregressive LLMs).
#prioritize the non-comfy weights (note the order reverse).
loading = self._load_list(prio_comfy_cast_weights=True)
loading.sort(reverse=True)
@@ -1557,6 +1559,7 @@ class ModelPatcherDynamic(ModelPatcher):
weight._v = vbar.alloc(weight_size)
weight._model_dtype = model_dtype
allocated_size += weight_size
vbar.set_watermark_limit(allocated_size)
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")

View File

@@ -13,8 +13,11 @@ from contextlib import nullcontext
import torch
from comfy.cli_args import args
import comfy.memory_management
import comfy.model_management
import comfy_aimdo.model_vbar
from latent_preview import set_preview_method
import nodes
from comfy_execution.caching import (
@@ -527,8 +530,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
finally:
if allocator is not None:
if args.verbose == "DEBUG":
comfy_aimdo.model_vbar.vbars_analyze()
comfy.model_management.reset_cast_buffers()
torch.cuda.synchronize()
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
if has_pending_tasks:
pending_async_nodes[unique_id] = output_data