mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Switched to new bucket system that matched sdxl trained buckets. Fixed requirements. Updated embeddings to work with sdxl. Added method to train lora with an embedding at the trigger. Still testing but works amazingly well from what I can see
This commit is contained in:
@@ -26,11 +26,9 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self.sd.vae.to(self.device_torch)
|
||||
|
||||
# textual inversion
|
||||
if self.embedding is not None:
|
||||
# keep original embeddings as reference
|
||||
self.orig_embeds_params = self.sd.text_encoder.get_input_embeddings().weight.data.clone()
|
||||
# if self.embedding is not None:
|
||||
# set text encoder to train. Not sure if this is necessary but diffusers example did it
|
||||
self.sd.text_encoder.train()
|
||||
# self.sd.text_encoder.train()
|
||||
|
||||
def hook_train_loop(self, batch):
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
@@ -103,13 +101,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
if self.embedding is not None:
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
index_no_updates = torch.ones((len(self.sd.tokenizer),), dtype=torch.bool)
|
||||
index_no_updates[
|
||||
min(self.embedding.placeholder_token_ids): max(self.embedding.placeholder_token_ids) + 1] = False
|
||||
with torch.no_grad():
|
||||
self.sd.text_encoder.get_input_embeddings().weight[
|
||||
index_no_updates
|
||||
] = self.orig_embeds_params[index_no_updates]
|
||||
self.embedding.restore_embeddings()
|
||||
|
||||
loss_dict = OrderedDict(
|
||||
{'loss': loss.item()}
|
||||
|
||||
Reference in New Issue
Block a user