Added after model load hook

This commit is contained in:
Jaret Burkett
2024-07-09 15:34:48 -06:00
parent 93e5df1d59
commit c008405480

View File

@@ -550,6 +550,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
# override in subclass
pass
def hook_after_model_load(self):
# override in subclass
pass
def hook_add_extra_train_params(self, params):
# override in subclass
return params
@@ -1246,6 +1250,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.snr_gos.scale.data = torch.tensor(json_data['scale'], device=self.device_torch)
self.snr_gos.gamma.data = torch.tensor(json_data['gamma'], device=self.device_torch)
self.hook_after_model_load()
flush()
if not self.is_fine_tuning:
if self.network_config is not None: