Added te aug adapter

This commit is contained in:
Jaret Burkett
2024-02-21 21:30:26 -07:00
parent 49c41e6a5f
commit b68c3ef734
5 changed files with 310 additions and 8 deletions

View File

@@ -930,6 +930,9 @@ class SDTrainer(BaseSDTrainProcess):
if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
grad_on_text_encoder = True
if self.adapter_config.type == 'te_augmenter':
grad_on_text_encoder = True
# have a blank network so we can wrap it in a context and set multipliers without checking every time
if self.network is not None:
network = self.network
@@ -1045,6 +1048,8 @@ class SDTrainer(BaseSDTrainProcess):
unconditional_embeds = None
if grad_on_text_encoder:
with torch.set_grad_enabled(True):
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
conditional_embeds = self.sd.encode_prompt(
conditioned_prompts, prompt_2,
dropout_prob=self.train_config.prompt_dropout_prob,
@@ -1053,6 +1058,8 @@ class SDTrainer(BaseSDTrainProcess):
dtype=dtype)
if self.train_config.do_cfg:
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = True
# todo only do one and repeat it
unconditional_embeds = self.sd.encode_prompt(
self.batch_negative_prompt,
@@ -1061,6 +1068,8 @@ class SDTrainer(BaseSDTrainProcess):
long_prompts=self.do_long_prompts).to(
self.device_torch,
dtype=dtype)
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
else:
with torch.set_grad_enabled(False):
# make sure it is in eval mode
@@ -1069,6 +1078,8 @@ class SDTrainer(BaseSDTrainProcess):
te.eval()
else:
self.sd.text_encoder.eval()
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
conditional_embeds = self.sd.encode_prompt(
conditioned_prompts, prompt_2,
dropout_prob=self.train_config.prompt_dropout_prob,
@@ -1076,12 +1087,16 @@ class SDTrainer(BaseSDTrainProcess):
self.device_torch,
dtype=dtype)
if self.train_config.do_cfg:
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = True
unconditional_embeds = self.sd.encode_prompt(
self.batch_negative_prompt,
dropout_prob=self.train_config.prompt_dropout_prob,
long_prompts=self.do_long_prompts).to(
self.device_torch,
dtype=dtype)
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
# detach the embeddings
conditional_embeds = conditional_embeds.detach()