Bug fixes for IP adapter training. Added a CLIP preprocessor that can be trained alongside the IP adapter to augment the CLIP input and squeeze more detail out of a larger input image. Moved CLIP processing to the dataloader for speed.
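A minimal sketch of the idea behind a trainable CLIP preprocessor, assuming a learnable strided-convolution downsample; the class name, sizes, and layer choice here are illustrative, not the commit's actual implementation:

```python
import torch
import torch.nn as nn

class CLIPImagePreProcessor(nn.Module):
    """Downsample a larger input to CLIP's expected resolution with a
    learnable layer instead of a fixed resize, so training can decide
    which fine detail survives the reduction. (Illustrative only.)"""

    def __init__(self, in_size: int = 448, clip_size: int = 224, channels: int = 3):
        super().__init__()
        stride = in_size // clip_size  # 2 for 448 -> 224
        self.down = nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, 3, in_size, in_size) -> (batch, 3, clip_size, clip_size)
        return self.down(x)

# usage: shrink a 448px batch before handing it to the CLIP vision encoder
imgs = torch.rand(1, 3, 448, 448)
clip_input = CLIPImagePreProcessor()(imgs)  # shape (1, 3, 224, 224)
```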

This commit is contained in:
Jaret Burkett
2024-01-04 12:59:38 -07:00
parent 65c08b09c3
commit 645b27f97a
8 changed files with 253 additions and 64 deletions


@@ -678,6 +678,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
 if batch.tensor is not None:
     imgs = batch.tensor
     imgs = imgs.to(self.device_torch, dtype=dtype)
+    if self.train_config.img_multiplier is not None:
+        imgs = imgs * self.train_config.img_multiplier
 if batch.latents is not None:
     latents = batch.latents.to(self.device_torch, dtype=dtype)
     batch.latents = latents
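The added guard simply scales the pixel batch by an optional config scalar before encoding; a standalone illustration of the behavior (the multiplier value here is hypothetical):

```python
import torch

img_multiplier = 1.2               # hypothetical config value; None leaves images untouched
imgs = torch.rand(4, 3, 512, 512)  # dummy image batch
if img_multiplier is not None:
    imgs = imgs * img_multiplier   # element-wise scale of the whole batch
```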
@@ -1113,6 +1115,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
 self.snr_gos.scale.data = torch.tensor(json_data['scale'], device=self.device_torch)
 self.snr_gos.gamma.data = torch.tensor(json_data['gamma'], device=self.device_torch)
+# load the adapters before the dataset as they may use the clip encoders
+if self.adapter_config is not None:
+    self.setup_adapter()
 flush()
 ### HOOk ###
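This hunk moves adapter setup ahead of dataset creation, since the dataloader now runs the CLIP image processing and needs the adapter's encoders to already exist. A toy sketch of that ordering constraint (all names other than setup_adapter are illustrative):

```python
class Process:
    def __init__(self, adapter_config=None):
        self.adapter_config = adapter_config
        self.adapter = None

    def setup_adapter(self):
        self.adapter = "ip-adapter"  # stand-in for the adapter + CLIP encoder

    def setup_dataset(self):
        # the dataloader captures the adapter to run CLIP processing per batch
        assert self.adapter is not None, "adapter must be set up before the dataset"
        self.dataset = [self.adapter]

p = Process(adapter_config={})
if p.adapter_config is not None:
    p.setup_adapter()  # must come first...
p.setup_dataset()      # ...so the dataloader can use the CLIP encoder
```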
@@ -1249,7 +1254,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
 flush()
 if self.adapter_config is not None:
-    self.setup_adapter()
+    # self.setup_adapter()
     # set trainable params
     params.append({
         'params': self.adapter.parameters(),
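The appended dict is a standard PyTorch parameter group; a generic example of how such groups feed an optimizer (the model and learning rate are placeholders, not the repo's values):

```python
import torch

adapter = torch.nn.Linear(8, 8)  # stand-in for self.adapter
params = []
params.append({
    'params': adapter.parameters(),
    'lr': 1e-4,  # hypothetical learning rate
})
optimizer = torch.optim.AdamW(params)
```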