mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added after model load hook
This commit is contained in:
@@ -550,6 +550,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# override in subclass
|
||||
pass
|
||||
|
||||
def hook_after_model_load(self):
|
||||
# override in subclass
|
||||
pass
|
||||
|
||||
def hook_add_extra_train_params(self, params):
|
||||
# override in subclass
|
||||
return params
|
||||
@@ -1246,6 +1250,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.snr_gos.scale.data = torch.tensor(json_data['scale'], device=self.device_torch)
|
||||
self.snr_gos.gamma.data = torch.tensor(json_data['gamma'], device=self.device_torch)
|
||||
|
||||
self.hook_after_model_load()
|
||||
flush()
|
||||
if not self.is_fine_tuning:
|
||||
if self.network_config is not None:
|
||||
|
||||
Reference in New Issue
Block a user