mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Fix issue where flux2 would ignore single control image on training
This commit is contained in:
@@ -312,7 +312,13 @@ class Flux2Model(BaseModel):
|
||||
img_cond_seq_ids: torch.Tensor | None = None
|
||||
|
||||
# handle control images
|
||||
if batch.control_tensor_list is not None:
|
||||
batch_control_tensor_list = batch.control_tensor_list
|
||||
if batch_control_tensor_list is None and batch.control_tensor is not None:
|
||||
batch_control_tensor_list = []
|
||||
for b in range(latent_model_input.shape[0]):
|
||||
batch_control_tensor_list.append(batch.control_tensor[b : b + 1])
|
||||
|
||||
if batch_control_tensor_list is not None:
|
||||
batch_size, num_channels_latents, height, width = (
|
||||
latent_model_input.shape
|
||||
)
|
||||
@@ -328,11 +334,11 @@ class Flux2Model(BaseModel):
|
||||
)
|
||||
control_image_max_res = control_image_res
|
||||
|
||||
if len(batch.control_tensor_list) != batch_size:
|
||||
if len(batch_control_tensor_list) != batch_size:
|
||||
raise ValueError(
|
||||
"Control tensor list length does not match batch size"
|
||||
)
|
||||
for control_tensor_list in batch.control_tensor_list:
|
||||
for control_tensor_list in batch_control_tensor_list:
|
||||
# control tensor list is a list of tensors for this batch item
|
||||
controls = []
|
||||
# pack control
|
||||
|
||||
Reference in New Issue
Block a user