Fix issue where flux2 would ignore single control image on training

This commit is contained in:
Jaret Burkett
2026-01-17 20:26:35 +00:00
parent e132dbae76
commit 0efed794b4

View File

@@ -312,7 +312,13 @@ class Flux2Model(BaseModel):
img_cond_seq_ids: torch.Tensor | None = None img_cond_seq_ids: torch.Tensor | None = None
# handle control images # handle control images
if batch.control_tensor_list is not None: batch_control_tensor_list = batch.control_tensor_list
if batch_control_tensor_list is None and batch.control_tensor is not None:
batch_control_tensor_list = []
for b in range(latent_model_input.shape[0]):
batch_control_tensor_list.append(batch.control_tensor[b : b + 1])
if batch_control_tensor_list is not None:
batch_size, num_channels_latents, height, width = ( batch_size, num_channels_latents, height, width = (
latent_model_input.shape latent_model_input.shape
) )
@@ -328,11 +334,11 @@ class Flux2Model(BaseModel):
) )
control_image_max_res = control_image_res control_image_max_res = control_image_res
if len(batch.control_tensor_list) != batch_size: if len(batch_control_tensor_list) != batch_size:
raise ValueError( raise ValueError(
"Control tensor list length does not match batch size" "Control tensor list length does not match batch size"
) )
for control_tensor_list in batch.control_tensor_list: for control_tensor_list in batch_control_tensor_list:
# control tensor list is a list of tensors for this batch item # control tensor list is a list of tensors for this batch item
controls = [] controls = []
# pack control # pack control