diff --git a/extensions_built_in/diffusion_models/flux2/flux2_model.py b/extensions_built_in/diffusion_models/flux2/flux2_model.py index 726d0e40..83eb7ca1 100644 --- a/extensions_built_in/diffusion_models/flux2/flux2_model.py +++ b/extensions_built_in/diffusion_models/flux2/flux2_model.py @@ -412,8 +412,8 @@ class Flux2Model(BaseModel): assert img_cond_seq_ids is not None, ( "You need to provide either both or neither of the sequence conditioning" ) - img_input = torch.cat((img_input, img_cond_seq), dim=1) - img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1) + img_input = torch.cat((img_input, img_cond_seq.to(img_input.device, img_input.dtype)), dim=1) + img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids.to(img_input_ids.device)), dim=1) guidance_vec = torch.full( (img_input.shape[0],),