diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index dae17596..0ca4680c 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -1618,7 +1618,7 @@ class BaseSDTrainProcess(BaseTrainProcess): # compile the model if needed if self.model_config.compile: try: - torch.compile(self.sd.unet, dynamic=True, fullgraph=True, mode='max-autotune') + self.sd.unet = torch.compile(self.sd.unet, dynamic=True, fullgraph=True, mode='max-autotune') except Exception as e: print_acc(f"Failed to compile model: {e}") print_acc("Continuing without compilation")