mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-28 00:03:57 +00:00
Save embeddings as their trigger to match auto and comfy style loading. Also, FINALLY found why gradients were wonky and fixed it. The root problem was dropping out of the network state before the backward pass.
This commit is contained in:
@@ -69,29 +69,33 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
guidance_scale=1.0,
|
||||
)
|
||||
flush()
|
||||
# 9.18 gb
|
||||
noise = noise.to(self.device_torch, dtype=dtype).detach()
|
||||
# 9.18 gb
|
||||
noise = noise.to(self.device_torch, dtype=dtype).detach()
|
||||
|
||||
|
||||
if self.sd.prediction_type == 'v_prediction':
|
||||
# v-parameterization training
|
||||
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
if self.sd.prediction_type == 'v_prediction':
|
||||
# v-parameterization training
|
||||
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
||||
# add min_snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
||||
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
||||
# add min_snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
||||
|
||||
loss = loss.mean()
|
||||
loss = loss.mean()
|
||||
|
||||
# back propagate loss to free ram
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
|
||||
flush()
|
||||
# IMPORTANT if gradient checkpointing do not leave with network when doing backward
|
||||
# it will destroy the gradients. This is because the network is a context manager
|
||||
# and will change the multipliers back to 0.0 when exiting. They will be
|
||||
# 0.0 for the backward pass and the gradients will be 0.0
|
||||
# I spent weeks on fighting this. DON'T DO IT
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
|
||||
flush()
|
||||
|
||||
# apply gradients
|
||||
self.optimizer.step()
|
||||
|
||||
@@ -255,32 +255,37 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
file_path = os.path.join(self.save_root, filename)
|
||||
# prepare meta
|
||||
save_meta = get_meta_for_safetensors(self.meta, self.job.name)
|
||||
if self.network is not None:
|
||||
prev_multiplier = self.network.multiplier
|
||||
self.network.multiplier = 1.0
|
||||
if self.network_config.normalize:
|
||||
# apply the normalization
|
||||
self.network.apply_stored_normalizer()
|
||||
if self.network is not None or self.embedding is not None:
|
||||
if self.network is not None:
|
||||
prev_multiplier = self.network.multiplier
|
||||
self.network.multiplier = 1.0
|
||||
if self.network_config.normalize:
|
||||
# apply the normalization
|
||||
self.network.apply_stored_normalizer()
|
||||
|
||||
# if we are doing embedding training as well, add that
|
||||
embedding_dict = self.embedding.state_dict() if self.embedding else None
|
||||
self.network.save_weights(
|
||||
file_path,
|
||||
dtype=get_torch_dtype(self.save_config.dtype),
|
||||
metadata=save_meta,
|
||||
extra_state_dict=embedding_dict
|
||||
)
|
||||
self.network.multiplier = prev_multiplier
|
||||
# if we have an embedding as well, pair it with the network
|
||||
elif self.embedding is not None:
|
||||
# for combo, above will get it
|
||||
# set current step
|
||||
self.embedding.step = self.step_num
|
||||
# change filename to pt if that is set
|
||||
if self.embed_config.save_format == "pt":
|
||||
# replace extension
|
||||
file_path = os.path.splitext(file_path)[0] + ".pt"
|
||||
self.embedding.save(file_path)
|
||||
# if we are doing embedding training as well, add that
|
||||
embedding_dict = self.embedding.state_dict() if self.embedding else None
|
||||
self.network.save_weights(
|
||||
file_path,
|
||||
dtype=get_torch_dtype(self.save_config.dtype),
|
||||
metadata=save_meta,
|
||||
extra_state_dict=embedding_dict
|
||||
)
|
||||
self.network.multiplier = prev_multiplier
|
||||
# if we have an embedding as well, pair it with the network
|
||||
|
||||
# even if added to lora, still save the trigger version
|
||||
if self.embedding is not None:
|
||||
emb_filename = f'{self.embed_config.trigger}{step_num}.safetensors'
|
||||
emb_file_path = os.path.join(self.save_root, emb_filename)
|
||||
# for combo, above will get it
|
||||
# set current step
|
||||
self.embedding.step = self.step_num
|
||||
# change filename to pt if that is set
|
||||
if self.embed_config.save_format == "pt":
|
||||
# replace extension
|
||||
file_path = os.path.splitext(emb_file_path)[0] + ".pt"
|
||||
self.embedding.save(file_path)
|
||||
else:
|
||||
self.sd.save(
|
||||
file_path,
|
||||
@@ -316,17 +321,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# return loss
|
||||
return 0.0
|
||||
|
||||
def get_latest_save_path(self):
|
||||
def get_latest_save_path(self, name=None):
|
||||
if name == None:
|
||||
name = self.job.name
|
||||
# get latest saved step
|
||||
if os.path.exists(self.save_root):
|
||||
latest_file = None
|
||||
# pattern is {job_name}_{zero_filles_step}.safetensors or {job_name}.safetensors
|
||||
pattern = f"{self.job.name}*.safetensors"
|
||||
pattern = f"{name}*.safetensors"
|
||||
files = glob.glob(os.path.join(self.save_root, pattern))
|
||||
if len(files) > 0:
|
||||
latest_file = max(files, key=os.path.getctime)
|
||||
# try pt
|
||||
pattern = f"{self.job.name}*.pt"
|
||||
pattern = f"{name}*.pt"
|
||||
files = glob.glob(os.path.join(self.save_root, pattern))
|
||||
if len(files) > 0:
|
||||
latest_file = max(files, key=os.path.getctime)
|
||||
@@ -548,9 +555,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# we are doing embedding training as well
|
||||
self.embedding = Embedding(
|
||||
sd=self.sd,
|
||||
embed_config=self.embed_config,
|
||||
state_dict=extra_weights
|
||||
embed_config=self.embed_config
|
||||
)
|
||||
latest_save_path = self.get_latest_save_path(self.embed_config.trigger)
|
||||
# load last saved weights
|
||||
if latest_save_path is not None:
|
||||
self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
|
||||
params.append({
|
||||
'params': self.embedding.get_trainable_params(),
|
||||
'lr': self.train_config.embedding_lr
|
||||
@@ -562,7 +572,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
sd=self.sd,
|
||||
embed_config=self.embed_config
|
||||
)
|
||||
latest_save_path = self.get_latest_save_path()
|
||||
latest_save_path = self.get_latest_save_path(self.embed_config.trigger)
|
||||
# load last saved weights
|
||||
if latest_save_path is not None:
|
||||
self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
|
||||
|
||||
@@ -50,10 +50,7 @@ class LoRAModule(ToolkitModuleMixin, torch.nn.Module):
|
||||
**kwargs
|
||||
):
|
||||
"""if alpha == 0 or None, alpha is rank (no scaling)."""
|
||||
super().__init__(
|
||||
org_module=org_module,
|
||||
parent=parent
|
||||
)
|
||||
super().__init__()
|
||||
self.lora_name = lora_name
|
||||
self.scalar = torch.tensor(1.0)
|
||||
|
||||
|
||||
@@ -36,12 +36,7 @@ class LoConSpecialModule(ToolkitModuleMixin, LoConModule):
|
||||
# call super of super
|
||||
torch.nn.Module.__init__(self)
|
||||
# call super of
|
||||
super().__init__(
|
||||
org_module=org_module,
|
||||
call_super_init=False,
|
||||
parent=parent,
|
||||
**kwargs
|
||||
)
|
||||
super().__init__(call_super_init=False)
|
||||
self.lora_name = lora_name
|
||||
self.lora_dim = lora_dim
|
||||
self.cp = False
|
||||
|
||||
@@ -52,18 +52,9 @@ class ToolkitModuleMixin:
|
||||
):
|
||||
if call_super_init:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.tk_orig_module: torch.nn.Module = kwargs.get('org_module', None)
|
||||
self.tk_orig_parent = kwargs.get('parent', None)
|
||||
self.is_checkpointing = False
|
||||
self.is_normalizing = False
|
||||
self.normalize_scaler = 1.0
|
||||
# see if is conv or linear
|
||||
self.is_conv = False
|
||||
self.is_linear = False
|
||||
if self.tk_orig_module.__class__.__name__ in LINEAR_MODULES:
|
||||
self.is_linear = True
|
||||
elif self.tk_orig_module.__class__.__name__ in CONV_MODULES:
|
||||
self.is_conv = True
|
||||
self._multiplier: Union[float, list, torch.Tensor] = 1.0
|
||||
|
||||
# this allows us to set different multipliers on a per item in a batch basis
|
||||
@@ -140,10 +131,6 @@ class ToolkitModuleMixin:
|
||||
lora_output_batch_size = lora_output.size(0)
|
||||
multiplier_batch_size = multiplier.size(0)
|
||||
if lora_output_batch_size != multiplier_batch_size:
|
||||
print(
|
||||
f"Warning: lora_output_batch_size {lora_output_batch_size} != multiplier_batch_size {multiplier_batch_size}")
|
||||
# doing cfg
|
||||
# should be 1 for if total batch size was 1
|
||||
num_interleaves = (lora_output_batch_size // 2) // multiplier_batch_size
|
||||
multiplier = multiplier.repeat_interleave(num_interleaves)
|
||||
# multiplier = 1.0
|
||||
|
||||
Reference in New Issue
Block a user