diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index b863dec9..78280288 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -550,6 +550,10 @@ class BaseSDTrainProcess(BaseTrainProcess): # override in subclass pass + def hook_after_model_load(self): + # override in subclass + pass + def hook_add_extra_train_params(self, params): # override in subclass return params @@ -1246,6 +1250,7 @@ class BaseSDTrainProcess(BaseTrainProcess): self.snr_gos.scale.data = torch.tensor(json_data['scale'], device=self.device_torch) self.snr_gos.gamma.data = torch.tensor(json_data['gamma'], device=self.device_torch) + self.hook_after_model_load() flush() if not self.is_fine_tuning: if self.network_config is not None: