diff --git a/extensions_built_in/diffusion_models/flux2/flux2_model.py b/extensions_built_in/diffusion_models/flux2/flux2_model.py index 6467ead0..3ed3f482 100644 --- a/extensions_built_in/diffusion_models/flux2/flux2_model.py +++ b/extensions_built_in/diffusion_models/flux2/flux2_model.py @@ -312,7 +312,13 @@ class Flux2Model(BaseModel): img_cond_seq_ids: torch.Tensor | None = None # handle control images - if batch.control_tensor_list is not None: + batch_control_tensor_list = batch.control_tensor_list + if batch_control_tensor_list is None and batch.control_tensor is not None: + batch_control_tensor_list = [] + for b in range(latent_model_input.shape[0]): + batch_control_tensor_list.append(batch.control_tensor[b : b + 1]) + + if batch_control_tensor_list is not None: batch_size, num_channels_latents, height, width = ( latent_model_input.shape ) @@ -328,11 +334,11 @@ class Flux2Model(BaseModel): ) control_image_max_res = control_image_res - if len(batch.control_tensor_list) != batch_size: + if len(batch_control_tensor_list) != batch_size: raise ValueError( "Control tensor list length does not match batch size" ) - for control_tensor_list in batch.control_tensor_list: + for control_tensor_list in batch_control_tensor_list: # control tensor list is a list of tensors for this batch item controls = [] # pack control