Compare commits

..

5 Commits

Author SHA1 Message Date
Terry Jia
79cd9d09eb add OUTPUT_NODE = True to ImageCropV2 2026-02-27 23:20:15 -05:00
comfyanonymous
94f1a1cc9d Limit overlap in image tile and combine nodes to prevent issues. (#12688) 2026-02-27 20:16:24 -05:00
rattus
e721e24136 ops: implement lora requanting for non QuantizedTensor fp8 (#12668)
Allow a non-QuantizedTensor layer to set want_requant so that the post-lora
calculation is stochastically cast back down to the original input dtype.

This is then used by the legacy fp8 Linear implementation to set the
compute_dtype to the preferred lora dtype but then want_requant it back
down to fp8.

This fixes the issue that occurs when --fast fp8_matrix_mult is combined with
--fast dynamic_vram while applying a lora to an fp8 non-QuantizedTensor model.
2026-02-27 19:05:51 -05:00
Reiner "Tiles" Prokein
25ec3d96a3 Class WanVAE, def encode, feat_map is using self.decoder instead of self.encoder (#12682) 2026-02-27 19:03:45 -05:00
Christian Byrne
1f1ec377ce feat: add ResolutionSelector node for aspect ratio and megapixel-based resolution calculation (#12199)
Amp-Thread-ID: https://ampcode.com/threads/T-019c179e-cd8c-768f-ae66-207c7a53c01d

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-27 09:13:57 -08:00
3 changed files with 15 additions and 42 deletions

View File

@@ -485,7 +485,7 @@ class WanVAE(nn.Module):
iter_ = 1 + (t - 1) // 4 iter_ = 1 + (t - 1) // 4
feat_map = None feat_map = None
if iter_ > 1: if iter_ > 1:
feat_map = [None] * count_conv3d(self.decoder) feat_map = [None] * count_conv3d(self.encoder)
## 对encode输入的x按时间拆分为1、4、4、4.... ## 对encode输入的x按时间拆分为1、4、4、4....
for i in range(iter_): for i in range(iter_):
conv_idx = [0] conv_idx = [0]

View File

@@ -167,17 +167,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
x = to_dequant(x, dtype) x = to_dequant(x, dtype)
if not resident and lowvram_fn is not None: if not resident and lowvram_fn is not None:
x = to_dequant(x, dtype if compute_dtype is None else compute_dtype) x = to_dequant(x, dtype if compute_dtype is None else compute_dtype)
#FIXME: this is not accurate, we need to be sensitive to the compute dtype
x = lowvram_fn(x) x = lowvram_fn(x)
if (isinstance(orig, QuantizedTensor) and if (want_requant and len(fns) == 0 or update_weight):
(want_requant and len(fns) == 0 or update_weight)):
seed = comfy.utils.string_to_seed(s.seed_key) seed = comfy.utils.string_to_seed(s.seed_key)
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed) if isinstance(orig, QuantizedTensor):
if want_requant and len(fns) == 0: y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
#The layer actually wants our freshly saved QT else:
x = y y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed)
elif update_weight: if want_requant and len(fns) == 0:
y = comfy.float.stochastic_rounding(x, orig.dtype, seed = comfy.utils.string_to_seed(s.seed_key)) x = y
if update_weight: if update_weight:
orig.copy_(y) orig.copy_(y)
for f in fns: for f in fns:
@@ -617,7 +615,8 @@ def fp8_linear(self, input):
if input.ndim != 2: if input.ndim != 2:
return None return None
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True) lora_compute_dtype=comfy.model_management.lora_compute_dtype(input.device)
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True, compute_dtype=lora_compute_dtype, want_requant=True)
scale_weight = torch.ones((), device=input.device, dtype=torch.float32) scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
scale_input = torch.ones((), device=input.device, dtype=torch.float32) scale_input = torch.ones((), device=input.device, dtype=torch.float32)

View File

@@ -65,6 +65,8 @@ class ImageCropV2(IO.ComfyNode):
outputs=[IO.Image.Output()], outputs=[IO.Image.Output()],
) )
OUTPUT_NODE = True
@classmethod @classmethod
def execute(cls, image, crop_region) -> IO.NodeOutput: def execute(cls, image, crop_region) -> IO.NodeOutput:
x = crop_region.get("x", 0) x = crop_region.get("x", 0)
@@ -706,8 +708,8 @@ class SplitImageToTileList(IO.ComfyNode):
@staticmethod @staticmethod
def get_grid_coords(width, height, tile_width, tile_height, overlap): def get_grid_coords(width, height, tile_width, tile_height, overlap):
coords = [] coords = []
stride_x = max(1, tile_width - overlap) stride_x = round(max(tile_width * 0.25, tile_width - overlap))
stride_y = max(1, tile_height - overlap) stride_y = round(max(tile_width * 0.25, tile_height - overlap))
y = 0 y = 0
while y < height: while y < height:
@@ -764,34 +766,6 @@ class ImageMergeTileList(IO.ComfyNode):
], ],
) )
@staticmethod
def get_grid_coords(width, height, tile_width, tile_height, overlap):
coords = []
stride_x = max(1, tile_width - overlap)
stride_y = max(1, tile_height - overlap)
y = 0
while y < height:
x = 0
y_end = min(y + tile_height, height)
y_start = max(0, y_end - tile_height)
while x < width:
x_end = min(x + tile_width, width)
x_start = max(0, x_end - tile_width)
coords.append((x_start, y_start, x_end, y_end))
if x_end >= width:
break
x += stride_x
if y_end >= height:
break
y += stride_y
return coords
@classmethod @classmethod
def execute(cls, image_list, final_width, final_height, overlap): def execute(cls, image_list, final_width, final_height, overlap):
w = final_width[0] w = final_width[0]
@@ -804,7 +778,7 @@ class ImageMergeTileList(IO.ComfyNode):
device = first_tile.device device = first_tile.device
dtype = first_tile.dtype dtype = first_tile.dtype
coords = cls.get_grid_coords(w, h, t_w, t_h, ovlp) coords = SplitImageToTileList.get_grid_coords(w, h, t_w, t_h, ovlp)
canvas = torch.zeros((b, h, w, c), device=device, dtype=dtype) canvas = torch.zeros((b, h, w, c), device=device, dtype=dtype)
weights = torch.zeros((b, h, w, 1), device=device, dtype=dtype) weights = torch.zeros((b, h, w, 1), device=device, dtype=dtype)