Move compiiling the model after accelerate manipulation

This commit is contained in:
Jaret Burkett
2026-03-31 09:52:27 -06:00
parent 0045260af7
commit 358d684f6f

View File

@@ -1615,14 +1615,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
# run base sd process run
self.sd.load_model()
# compile the model if needed
if self.model_config.compile:
try:
self.sd.unet = torch.compile(self.sd.unet, dynamic=True, fullgraph=True, mode='max-autotune')
except Exception as e:
print_acc(f"Failed to compile model: {e}")
print_acc("Continuing without compilation")
self.sd.add_after_sample_image_hook(self.sample_step_hook)
dtype = get_torch_dtype(self.train_config.dtype)
@@ -1959,6 +1951,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.adapter_config is not None and self.adapter is None:
self.setup_adapter()
flush()
### HOOK ###
params = self.hook_add_extra_train_params(params)
self.params = params
@@ -2053,6 +2046,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
### HOOK ###
self.hook_before_train_loop()
# compile the model if needed (must be after LoRA/adapter injection AND accelerator.prepare)
if self.model_config.compile:
try:
print_acc(f"Compiling model with torch.compile")
self.sd.unet = torch.compile(self.sd.unet, dynamic=True, mode='reduce-overhead')
except Exception as e:
print_acc(f"Failed to compile model: {e}")
print_acc("Continuing without compilation")
if self.has_first_sample_requested and self.step_num <= 1 and not self.train_config.disable_sampling:
print_acc("Generating first sample from first sample config")
self.sample(0, is_first=True)