mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 02:01:29 +00:00
Bug fixes, negative prompting during training, hardened catching
This commit is contained in:
@@ -193,6 +193,9 @@ class TrainConfig:
|
||||
self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
|
||||
self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
|
||||
self.img_multiplier = kwargs.get('img_multiplier', 1.0)
|
||||
self.negative_prompt = kwargs.get('negative_prompt', None)
|
||||
# multiplier applied to loos on regularization images
|
||||
self.reg_weight = kwargs.get('reg_weight', 1.0)
|
||||
|
||||
# dropout that happens before encoding. It functions independently per text encoder
|
||||
self.prompt_dropout_prob = kwargs.get('prompt_dropout_prob', 0.0)
|
||||
|
||||
@@ -77,6 +77,10 @@ def parse_metadata_from_safetensors(meta: OrderedDict) -> OrderedDict:
|
||||
|
||||
|
||||
def load_metadata_from_safetensors(file_path: str) -> OrderedDict:
|
||||
with safe_open(file_path, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
return parse_metadata_from_safetensors(metadata)
|
||||
try:
|
||||
with safe_open(file_path, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
return parse_metadata_from_safetensors(metadata)
|
||||
except Exception as e:
|
||||
print(f"Error loading metadata from {file_path}: {e}")
|
||||
return OrderedDict()
|
||||
|
||||
@@ -254,6 +254,7 @@ class ToolkitModuleMixin:
|
||||
multiplier_batch_size = multiplier.size(0)
|
||||
if lora_output_batch_size != multiplier_batch_size:
|
||||
num_interleaves = lora_output_batch_size // multiplier_batch_size
|
||||
# todo check if this is correct, do we just concat when doing cfg?
|
||||
multiplier = multiplier.repeat_interleave(num_interleaves)
|
||||
|
||||
x = org_forwarded + broadcast_and_multiply(lora_output, multiplier)
|
||||
@@ -470,11 +471,11 @@ class ToolkitNetworkMixin:
|
||||
self.torch_multiplier = tensor_multiplier.clone().detach()
|
||||
|
||||
@property
|
||||
def multiplier(self) -> Union[float, List[float]]:
|
||||
def multiplier(self) -> Union[float, List[float], List[List[float]]]:
|
||||
return self._multiplier
|
||||
|
||||
@multiplier.setter
|
||||
def multiplier(self, value: Union[float, List[float]]):
|
||||
def multiplier(self, value: Union[float, List[float], List[List[float]]]):
|
||||
# it takes time to update all the multipliers, so we only do it if the value has changed
|
||||
if self._multiplier == value:
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user