diff --git a/extensions_built_in/diffusion_models/flux_kontext/flux_kontext.py b/extensions_built_in/diffusion_models/flux_kontext/flux_kontext.py index 49a79b9b..e9eeee57 100644 --- a/extensions_built_in/diffusion_models/flux_kontext/flux_kontext.py +++ b/extensions_built_in/diffusion_models/flux_kontext/flux_kontext.py @@ -291,7 +291,8 @@ class FluxKontextModel(BaseModel): else: guidance = torch.tensor( [guidance_embedding_scale], device=self.device_torch) - guidance = guidance.expand(latent_model_input.shape[0]) + # Expand guidance to match original batch_size + guidance = guidance.expand(bs) else: guidance = None