mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-11 08:20:35 +00:00
Move compiling the model after accelerate manipulation
This commit is contained in:
@@ -1615,14 +1615,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# run base sd process run
|
||||
self.sd.load_model()
|
||||
|
||||
# compile the model if needed
|
||||
if self.model_config.compile:
|
||||
try:
|
||||
self.sd.unet = torch.compile(self.sd.unet, dynamic=True, fullgraph=True, mode='max-autotune')
|
||||
except Exception as e:
|
||||
print_acc(f"Failed to compile model: {e}")
|
||||
print_acc("Continuing without compilation")
|
||||
|
||||
self.sd.add_after_sample_image_hook(self.sample_step_hook)
|
||||
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
@@ -1959,6 +1951,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if self.adapter_config is not None and self.adapter is None:
|
||||
self.setup_adapter()
|
||||
flush()
|
||||
|
||||
### HOOK ###
|
||||
params = self.hook_add_extra_train_params(params)
|
||||
self.params = params
|
||||
@@ -2053,6 +2046,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
### HOOK ###
|
||||
self.hook_before_train_loop()
|
||||
|
||||
# compile the model if needed (must be after LoRA/adapter injection AND accelerator.prepare)
|
||||
if self.model_config.compile:
|
||||
try:
|
||||
print_acc(f"Compiling model with torch.compile")
|
||||
self.sd.unet = torch.compile(self.sd.unet, dynamic=True, mode='reduce-overhead')
|
||||
except Exception as e:
|
||||
print_acc(f"Failed to compile model: {e}")
|
||||
print_acc("Continuing without compilation")
|
||||
|
||||
if self.has_first_sample_requested and self.step_num <= 1 and not self.train_config.disable_sampling:
|
||||
print_acc("Generating first sample from first sample config")
|
||||
self.sample(0, is_first=True)
|
||||
|
||||
Reference in New Issue
Block a user