mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 08:29:45 +00:00
Bugfixes for slider reference
This commit is contained in:
@@ -159,7 +159,13 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
||||
|
||||
# if training text encoder enable grads, else do context of no grad
|
||||
with torch.set_grad_enabled(self.train_config.train_text_encoder):
|
||||
conditional_embeds = self.sd.encode_prompt(prompts).to(self.device_torch, dtype=dtype)
|
||||
# fix issue with them being tuples sometimes
|
||||
prompt_list = []
|
||||
for prompt in prompts:
|
||||
if isinstance(prompt, tuple):
|
||||
prompt = prompt[0]
|
||||
prompt_list.append(prompt)
|
||||
conditional_embeds = self.sd.encode_prompt(prompt_list).to(self.device_torch, dtype=dtype)
|
||||
conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
|
||||
|
||||
# if self.model_config.is_xl:
|
||||
|
||||
@@ -55,7 +55,7 @@ class ToolkitModuleMixin:
|
||||
self.is_checkpointing = False
|
||||
self.is_normalizing = False
|
||||
self.normalize_scaler = 1.0
|
||||
self._multiplier: Union[float, list, torch.Tensor] = 1.0
|
||||
self._multiplier: Union[float, list, torch.Tensor] = None
|
||||
|
||||
# this allows us to set different multipliers on a per item in a batch basis
|
||||
# allowing us to run positive and negative weights in the same batch
|
||||
@@ -123,6 +123,8 @@ class ToolkitModuleMixin:
|
||||
return lx * scale
|
||||
|
||||
def forward(self: Module, x):
|
||||
if self._multiplier is None:
|
||||
self.set_multiplier(0.0)
|
||||
|
||||
org_forwarded = self.org_forward(x)
|
||||
lora_output = self._call_forward(x)
|
||||
|
||||
Reference in New Issue
Block a user