diff --git a/extensions_built_in/diffusion_models/chroma/pipeline.py b/extensions_built_in/diffusion_models/chroma/pipeline.py index 52b9b817..215be798 100644 --- a/extensions_built_in/diffusion_models/chroma/pipeline.py +++ b/extensions_built_in/diffusion_models/chroma/pipeline.py @@ -61,6 +61,8 @@ class ChromaPipeline(FluxPipeline): batch_size = prompt_embeds.shape[0] device = self._execution_device + if isinstance(device, str): + device = torch.device(device) text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=torch.bfloat16) if guidance_scale > 1.00001: