mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 18:21:16 +00:00
Bug fixes for IP adapter training. Added a CLIP pre-processor that can be trained with the IP adapter to help augment the CLIP input and squeeze more detail out of a larger input. Moved CLIP processing to the dataloader for speed.
This commit is contained in:
@@ -678,6 +678,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if batch.tensor is not None:
|
||||
imgs = batch.tensor
|
||||
imgs = imgs.to(self.device_torch, dtype=dtype)
|
||||
if self.train_config.img_multiplier is not None:
|
||||
imgs = imgs * self.train_config.img_multiplier
|
||||
if batch.latents is not None:
|
||||
latents = batch.latents.to(self.device_torch, dtype=dtype)
|
||||
batch.latents = latents
|
||||
@@ -1113,6 +1115,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.snr_gos.scale.data = torch.tensor(json_data['scale'], device=self.device_torch)
|
||||
self.snr_gos.gamma.data = torch.tensor(json_data['gamma'], device=self.device_torch)
|
||||
|
||||
# load the adapters before the dataset as they may use the clip encoders
|
||||
if self.adapter_config is not None:
|
||||
self.setup_adapter()
|
||||
flush()
|
||||
|
||||
### HOOk ###
|
||||
@@ -1249,7 +1254,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
flush()
|
||||
|
||||
if self.adapter_config is not None:
|
||||
self.setup_adapter()
|
||||
# self.setup_adapter()
|
||||
# set trainable params
|
||||
params.append({
|
||||
'params': self.adapter.parameters(),
|
||||
|
||||
Reference in New Issue
Block a user