From 489b1942315c2064383475c2ae00ce77a49bcb55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A7=91=E6=9E=97=20KELIN?= <167663109+HuangYuChuh@users.noreply.github.com> Date: Thu, 26 Mar 2026 01:45:38 +0800 Subject: [PATCH] Fix CPU/CUDA device mismatch in Klein edit control image encoding (#742) When training Klein models with a `control_path` (edit/kontext-style paired datasets), `encode_image_refs()` returns tensors that reside on the VAE's device (CPU, since the VAE weights are loaded via `load_file(..., device="cpu")` and are never explicitly moved to the training device). Concatenating those CPU tensors with the training latents (`packed_latents`) that live on CUDA raises: RuntimeError: Expected all tensors to be on the same device Fix: before concatenation, move `img_cond_seq` to the device and dtype of `img_input`, and move `img_cond_seq_ids` to the device of `img_input_ids` (its dtype is left unchanged). Co-authored-by: HuangYuChuh Co-authored-by: Claude Sonnet 4.6 --- extensions_built_in/diffusion_models/flux2/flux2_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extensions_built_in/diffusion_models/flux2/flux2_model.py b/extensions_built_in/diffusion_models/flux2/flux2_model.py index 726d0e40..83eb7ca1 100644 --- a/extensions_built_in/diffusion_models/flux2/flux2_model.py +++ b/extensions_built_in/diffusion_models/flux2/flux2_model.py @@ -412,8 +412,8 @@ class Flux2Model(BaseModel): assert img_cond_seq_ids is not None, ( "You need to provide either both or neither of the sequence conditioning" ) - img_input = torch.cat((img_input, img_cond_seq), dim=1) - img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1) + img_input = torch.cat((img_input, img_cond_seq.to(img_input.device, img_input.dtype)), dim=1) + img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids.to(img_input_ids.device)), dim=1) guidance_vec = torch.full( (img_input.shape[0],),