mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Bugfixes for slider reference
This commit is contained in:
@@ -159,7 +159,13 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
|||||||
|
|
||||||
# if training text encoder enable grads, else do context of no grad
|
# if training text encoder enable grads, else do context of no grad
|
||||||
with torch.set_grad_enabled(self.train_config.train_text_encoder):
|
with torch.set_grad_enabled(self.train_config.train_text_encoder):
|
||||||
conditional_embeds = self.sd.encode_prompt(prompts).to(self.device_torch, dtype=dtype)
|
# fix issue with them being tuples sometimes
|
||||||
|
prompt_list = []
|
||||||
|
for prompt in prompts:
|
||||||
|
if isinstance(prompt, tuple):
|
||||||
|
prompt = prompt[0]
|
||||||
|
prompt_list.append(prompt)
|
||||||
|
conditional_embeds = self.sd.encode_prompt(prompt_list).to(self.device_torch, dtype=dtype)
|
||||||
conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
|
conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
|
||||||
|
|
||||||
# if self.model_config.is_xl:
|
# if self.model_config.is_xl:
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ class ToolkitModuleMixin:
|
|||||||
self.is_checkpointing = False
|
self.is_checkpointing = False
|
||||||
self.is_normalizing = False
|
self.is_normalizing = False
|
||||||
self.normalize_scaler = 1.0
|
self.normalize_scaler = 1.0
|
||||||
self._multiplier: Union[float, list, torch.Tensor] = 1.0
|
self._multiplier: Union[float, list, torch.Tensor] = None
|
||||||
|
|
||||||
# this allows us to set different multipliers on a per item in a batch basis
|
# this allows us to set different multipliers on a per item in a batch basis
|
||||||
# allowing us to run positive and negative weights in the same batch
|
# allowing us to run positive and negative weights in the same batch
|
||||||
@@ -123,6 +123,8 @@ class ToolkitModuleMixin:
|
|||||||
return lx * scale
|
return lx * scale
|
||||||
|
|
||||||
def forward(self: Module, x):
|
def forward(self: Module, x):
|
||||||
|
if self._multiplier is None:
|
||||||
|
self.set_multiplier(0.0)
|
||||||
|
|
||||||
org_forwarded = self.org_forward(x)
|
org_forwarded = self.org_forward(x)
|
||||||
lora_output = self._call_forward(x)
|
lora_output = self._call_forward(x)
|
||||||
|
|||||||
Reference in New Issue
Block a user