From e190fbaeb88aac1924131e01d4c405c9de1c12f4 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Fri, 12 Jan 2024 06:41:15 -0700 Subject: [PATCH] Prepwork for ilora --- jobs/process/BaseSDTrainProcess.py | 18 +++++++++--------- toolkit/config_modules.py | 2 +- toolkit/ip_adapter.py | 3 +++ 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index f39c1871..1b816161 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -1119,15 +1119,6 @@ class BaseSDTrainProcess(BaseTrainProcess): if self.adapter_config is not None: self.setup_adapter() flush() - - ### HOOk ### - self.before_dataset_load() - # load datasets if passed in the root process - if self.datasets is not None: - self.data_loader = get_dataloader_from_datasets(self.datasets, self.train_config.batch_size, self.sd) - if self.datasets_reg is not None: - self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size, - self.sd) if not self.is_fine_tuning: if self.network_config is not None: # TODO should we completely switch to LycorisSpecialNetwork? @@ -1334,6 +1325,15 @@ class BaseSDTrainProcess(BaseTrainProcess): ) self.lr_scheduler = lr_scheduler + ### HOOk ### + self.before_dataset_load() + # load datasets if passed in the root process + if self.datasets is not None: + self.data_loader = get_dataloader_from_datasets(self.datasets, self.train_config.batch_size, self.sd) + if self.datasets_reg is not None: + self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size, + self.sd) + flush() ### HOOK ### self.hook_before_train_loop() diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 0b5732cb..e2856544 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -127,7 +127,7 @@ class NetworkConfig: self.conv = 4 -AdapterTypes = Literal['t2i', 'ip', 'ip+'] +AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora'] class AdapterConfig: diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py index 5d58a627..b7bb7034 100644 --- a/toolkit/ip_adapter.py +++ b/toolkit/ip_adapter.py @@ -288,6 +288,9 @@ class IPAdapter(torch.nn.Module): output_dim=sd.unet.config['cross_attention_dim'], ff_mult=4 ) + elif adapter_config.type == 'ilora': + # we apply the clip encodings to the LoRA + image_proj_model = None else: raise ValueError(f"unknown adapter type: {adapter_config.type}")