From 358d684f6f2a5d41b51ed9c2c09bd88f2ceb9c64 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Tue, 31 Mar 2026 09:52:27 -0600 Subject: [PATCH] Move compiling the model after accelerate manipulation --- jobs/process/BaseSDTrainProcess.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 0ca4680c..1624fc4f 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -1615,14 +1615,6 @@ class BaseSDTrainProcess(BaseTrainProcess): # run base sd process run self.sd.load_model() - # compile the model if needed - if self.model_config.compile: - try: - self.sd.unet = torch.compile(self.sd.unet, dynamic=True, fullgraph=True, mode='max-autotune') - except Exception as e: - print_acc(f"Failed to compile model: {e}") - print_acc("Continuing without compilation") - self.sd.add_after_sample_image_hook(self.sample_step_hook) dtype = get_torch_dtype(self.train_config.dtype) @@ -1959,6 +1951,7 @@ class BaseSDTrainProcess(BaseTrainProcess): if self.adapter_config is not None and self.adapter is None: self.setup_adapter() flush() + ### HOOK ### params = self.hook_add_extra_train_params(params) self.params = params @@ -2053,6 +2046,15 @@ class BaseSDTrainProcess(BaseTrainProcess): ### HOOK ### self.hook_before_train_loop() + # compile the model if needed (must be after LoRA/adapter injection AND accelerator.prepare) + if self.model_config.compile: + try: + print_acc(f"Compiling model with torch.compile") + self.sd.unet = torch.compile(self.sd.unet, dynamic=True, mode='reduce-overhead') + except Exception as e: + print_acc(f"Failed to compile model: {e}") + print_acc("Continuing without compilation") + if self.has_first_sample_requested and self.step_num <= 1 and not self.train_config.disable_sampling: print_acc("Generating first sample from first sample config") self.sample(0, is_first=True)