From 0efed794b43562b0d65f5c90aa8f36462490b7f2 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 17 Jan 2026 20:26:35 +0000 Subject: [PATCH] Fix issue where flux2 would ignore single control image on training --- .../diffusion_models/flux2/flux2_model.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/extensions_built_in/diffusion_models/flux2/flux2_model.py b/extensions_built_in/diffusion_models/flux2/flux2_model.py index 6467ead..3ed3f48 100644 --- a/extensions_built_in/diffusion_models/flux2/flux2_model.py +++ b/extensions_built_in/diffusion_models/flux2/flux2_model.py @@ -312,7 +312,13 @@ class Flux2Model(BaseModel): img_cond_seq_ids: torch.Tensor | None = None # handle control images - if batch.control_tensor_list is not None: + batch_control_tensor_list = batch.control_tensor_list + if batch_control_tensor_list is None and batch.control_tensor is not None: + batch_control_tensor_list = [] + for b in range(latent_model_input.shape[0]): + batch_control_tensor_list.append(batch.control_tensor[b : b + 1]) + + if batch_control_tensor_list is not None: batch_size, num_channels_latents, height, width = ( latent_model_input.shape ) @@ -328,11 +334,11 @@ class Flux2Model(BaseModel): ) control_image_max_res = control_image_res - if len(batch.control_tensor_list) != batch_size: + if len(batch_control_tensor_list) != batch_size: raise ValueError( "Control tensor list length does not match batch size" ) - for control_tensor_list in batch.control_tensor_list: + for control_tensor_list in batch_control_tensor_list: # control tensor list is a list of tensors for this batch item controls = [] # pack control