Switched to a new bucket system that matches the buckets SDXL was trained on. Fixed requirements. Updated embeddings to work with SDXL. Added a method to train a LoRA with an embedding at the trigger. Still testing, but it works amazingly well from what I can see.

Jaret Burkett
2023-09-07 13:06:18 -06:00
parent 436bf0c6a3
commit 3feb663a51
10 changed files with 208 additions and 140 deletions
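For context on the "new bucket system": SDXL-style aspect buckets keep roughly a constant pixel area (1024x1024) while varying aspect ratio, with sides snapped to a multiple of 64. A minimal sketch of generating such a bucket list, as a hypothetical illustration and not this repo's implementation; the exact resolutions the trainer uses may differ:

    # Hypothetical illustration: enumerate aspect-ratio buckets at ~1024^2
    # total pixels, with both sides snapped to a 64-pixel grid.
    def make_buckets(base_area=1024 * 1024, step=64, min_dim=512, max_dim=2048):
        buckets = set()
        width = min_dim
        while width <= max_dim:
            height = min((base_area // width) // step * step, max_dim)
            if height >= min_dim:
                buckets.add((width, height))
                buckets.add((height, width))  # mirror for portrait/landscape
            width += step
        return sorted(buckets)

    print(make_buckets()[:3])  # e.g. [(512, 2048), (576, 1792), (640, 1600)]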

@@ -26,11 +26,9 @@ class SDTrainer(BaseSDTrainProcess):
         self.sd.vae.to(self.device_torch)
         # textual inversion
-        if self.embedding is not None:
-            # keep original embeddings as reference
-            self.orig_embeds_params = self.sd.text_encoder.get_input_embeddings().weight.data.clone()
+        # if self.embedding is not None:
             # set text encoder to train. Not sure if this is necessary but diffusers example did it
-            self.sd.text_encoder.train()
+            # self.sd.text_encoder.train()

     def hook_train_loop(self, batch):
         dtype = get_torch_dtype(self.train_config.dtype)
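This hunk drops the trainer-side bookkeeping: the cached copy of the original embedding table is removed and the forced train() call on the text encoder is commented out, with the embedding object presumably taking over that state. A minimal sketch of the caching step that was removed here (hypothetical helper, not this repo's API):

    import torch

    def cache_original_embeddings(text_encoder) -> torch.Tensor:
        # clone() snapshots the token embedding table so optimizer steps on
        # the live weights do not mutate the reference copy used later for
        # restoring the unchanged rows
        return text_encoder.get_input_embeddings().weight.data.clone()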
@@ -103,13 +101,7 @@ class SDTrainer(BaseSDTrainProcess):
             if self.embedding is not None:
                 # Let's make sure we don't update any embedding weights besides the newly added token
-                index_no_updates = torch.ones((len(self.sd.tokenizer),), dtype=torch.bool)
-                index_no_updates[
-                    min(self.embedding.placeholder_token_ids): max(self.embedding.placeholder_token_ids) + 1] = False
-                with torch.no_grad():
-                    self.sd.text_encoder.get_input_embeddings().weight[
-                        index_no_updates
-                    ] = self.orig_embeds_params[index_no_updates]
+                self.embedding.restore_embeddings()

             loss_dict = OrderedDict(
                 {'loss': loss.item()}
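The inline mask-and-restore logic deleted above becomes a single call to self.embedding.restore_embeddings(). A plausible sketch of that method, reconstructed from the deleted trainer code; the method body and attribute names are assumptions, only the deleted lines above are from the source:

    import torch

    class Embedding:
        # Assumed attributes, mirroring the deleted trainer code:
        # self.sd (holds tokenizer and text_encoder),
        # self.placeholder_token_ids (ids of the newly added trigger tokens),
        # self.orig_embeds_params (clone of the embedding table from setup).

        def restore_embeddings(self):
            # mask every token row EXCEPT the newly added placeholder token(s)
            index_no_updates = torch.ones((len(self.sd.tokenizer),), dtype=torch.bool)
            index_no_updates[
                min(self.placeholder_token_ids): max(self.placeholder_token_ids) + 1
            ] = False
            # copy the original weights back, so after each optimizer step
            # only the trigger embedding has actually changed
            with torch.no_grad():
                self.sd.text_encoder.get_input_embeddings().weight[
                    index_no_updates
                ] = self.orig_embeds_params[index_no_updates]

Folding this into the embedding class keeps the per-step restore next to the state it depends on, which is what lets the new LoRA-with-trigger training path reuse it without duplicating the masking logic in the trainer.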