diff --git a/extensions_built_in/diffusion_models/wan22/wan22_pipeline.py b/extensions_built_in/diffusion_models/wan22/wan22_pipeline.py index 442b5c7a..6c7b93b7 100644 --- a/extensions_built_in/diffusion_models/wan22/wan22_pipeline.py +++ b/extensions_built_in/diffusion_models/wan22/wan22_pipeline.py @@ -256,6 +256,14 @@ class Wan22Pipeline(WanPipeline): # Offload all models self.maybe_free_model_hooks() + + # move transformer back to device + if self._aggressive_offload: + print("Moving transformer back to device") + self.transformer.to(self._execution_device) + if self.transformer_2 is not None: + self.transformer_2.to(self._execution_device) + flush() if not return_dict: return (video,)