From c0084054801289b5c921a79c17a4b6e25bde7c88 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Tue, 9 Jul 2024 15:34:48 -0600 Subject: [PATCH] Added after model load hook --- jobs/process/BaseSDTrainProcess.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index b863dec9..78280288 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -550,6 +550,10 @@ class BaseSDTrainProcess(BaseTrainProcess): # override in subclass pass + def hook_after_model_load(self): + # override in subclass + pass + def hook_add_extra_train_params(self, params): # override in subclass return params @@ -1246,6 +1250,7 @@ class BaseSDTrainProcess(BaseTrainProcess): self.snr_gos.scale.data = torch.tensor(json_data['scale'], device=self.device_torch) self.snr_gos.gamma.data = torch.tensor(json_data['gamma'], device=self.device_torch) + self.hook_after_model_load() flush() if not self.is_fine_tuning: if self.network_config is not None: