Bug fixes, negative prompting during training, hardened exception catching

This commit is contained in:
Jaret Burkett
2023-11-24 07:25:11 -07:00
parent fbec68681d
commit d7e55b6ad4
6 changed files with 93 additions and 9 deletions

View File

@@ -254,6 +254,7 @@ class ToolkitModuleMixin:
multiplier_batch_size = multiplier.size(0)
if lora_output_batch_size != multiplier_batch_size:
num_interleaves = lora_output_batch_size // multiplier_batch_size
# todo check if this is correct, do we just concat when doing cfg?
multiplier = multiplier.repeat_interleave(num_interleaves)
x = org_forwarded + broadcast_and_multiply(lora_output, multiplier)
@@ -470,11 +471,11 @@ class ToolkitNetworkMixin:
self.torch_multiplier = tensor_multiplier.clone().detach()
@property
def multiplier(self) -> Union[float, List[float]]:
def multiplier(self) -> Union[float, List[float], List[List[float]]]:
return self._multiplier
@multiplier.setter
def multiplier(self, value: Union[float, List[float]]):
def multiplier(self, value: Union[float, List[float], List[List[float]]]):
# it takes time to update all the multipliers, so we only do it if the value has changed
if self._multiplier == value:
return