mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Fix issue where flux2 would ignore single control image on training
This commit is contained in:
@@ -312,7 +312,13 @@ class Flux2Model(BaseModel):
|
|||||||
img_cond_seq_ids: torch.Tensor | None = None
|
img_cond_seq_ids: torch.Tensor | None = None
|
||||||
|
|
||||||
# handle control images
|
# handle control images
|
||||||
if batch.control_tensor_list is not None:
|
batch_control_tensor_list = batch.control_tensor_list
|
||||||
|
if batch_control_tensor_list is None and batch.control_tensor is not None:
|
||||||
|
batch_control_tensor_list = []
|
||||||
|
for b in range(latent_model_input.shape[0]):
|
||||||
|
batch_control_tensor_list.append(batch.control_tensor[b : b + 1])
|
||||||
|
|
||||||
|
if batch_control_tensor_list is not None:
|
||||||
batch_size, num_channels_latents, height, width = (
|
batch_size, num_channels_latents, height, width = (
|
||||||
latent_model_input.shape
|
latent_model_input.shape
|
||||||
)
|
)
|
||||||
@@ -328,11 +334,11 @@ class Flux2Model(BaseModel):
|
|||||||
)
|
)
|
||||||
control_image_max_res = control_image_res
|
control_image_max_res = control_image_res
|
||||||
|
|
||||||
if len(batch.control_tensor_list) != batch_size:
|
if len(batch_control_tensor_list) != batch_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Control tensor list length does not match batch size"
|
"Control tensor list length does not match batch size"
|
||||||
)
|
)
|
||||||
for control_tensor_list in batch.control_tensor_list:
|
for control_tensor_list in batch_control_tensor_list:
|
||||||
# control tensor list is a list of tensors for this batch item
|
# control tensor list is a list of tensors for this batch item
|
||||||
controls = []
|
controls = []
|
||||||
# pack control
|
# pack control
|
||||||
|
|||||||
Reference in New Issue
Block a user