Save embeddings as their trigger to match auto and comfy style loading. Also, FINALLY found why gradients were wonky and fixed it. The root problem was exiting the network state (context manager) before the backward pass, which reset the multipliers to 0.0 and zeroed the gradients.

This commit is contained in:
Jaret Burkett
2023-09-09 12:02:07 -06:00
parent 408c50ead1
commit be804c9cf5
5 changed files with 64 additions and 71 deletions

View File

@@ -69,29 +69,33 @@ class SDTrainer(BaseSDTrainProcess):
guidance_scale=1.0,
)
flush()
# 9.18 gb
noise = noise.to(self.device_torch, dtype=dtype).detach()
# 9.18 gb
noise = noise.to(self.device_torch, dtype=dtype).detach()
if self.sd.prediction_type == 'v_prediction':
# v-parameterization training
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
else:
target = noise
if self.sd.prediction_type == 'v_prediction':
# v-parameterization training
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
loss = loss.mean()
loss = loss.mean()
# back propagate loss to free ram
loss.backward()
torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
flush()
# IMPORTANT if gradient checkpointing do not leave with network when doing backward
# it will destroy the gradients. This is because the network is a context manager
# and will change the multipliers back to 0.0 when exiting. They will be
# 0.0 for the backward pass and the gradients will be 0.0
# I spent weeks on fighting this. DON'T DO IT
loss.backward()
torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
flush()
# apply gradients
self.optimizer.step()

View File

@@ -255,32 +255,37 @@ class BaseSDTrainProcess(BaseTrainProcess):
file_path = os.path.join(self.save_root, filename)
# prepare meta
save_meta = get_meta_for_safetensors(self.meta, self.job.name)
if self.network is not None:
prev_multiplier = self.network.multiplier
self.network.multiplier = 1.0
if self.network_config.normalize:
# apply the normalization
self.network.apply_stored_normalizer()
if self.network is not None or self.embedding is not None:
if self.network is not None:
prev_multiplier = self.network.multiplier
self.network.multiplier = 1.0
if self.network_config.normalize:
# apply the normalization
self.network.apply_stored_normalizer()
# if we are doing embedding training as well, add that
embedding_dict = self.embedding.state_dict() if self.embedding else None
self.network.save_weights(
file_path,
dtype=get_torch_dtype(self.save_config.dtype),
metadata=save_meta,
extra_state_dict=embedding_dict
)
self.network.multiplier = prev_multiplier
# if we have an embedding as well, pair it with the network
elif self.embedding is not None:
# for combo, above will get it
# set current step
self.embedding.step = self.step_num
# change filename to pt if that is set
if self.embed_config.save_format == "pt":
# replace extension
file_path = os.path.splitext(file_path)[0] + ".pt"
self.embedding.save(file_path)
# if we are doing embedding training as well, add that
embedding_dict = self.embedding.state_dict() if self.embedding else None
self.network.save_weights(
file_path,
dtype=get_torch_dtype(self.save_config.dtype),
metadata=save_meta,
extra_state_dict=embedding_dict
)
self.network.multiplier = prev_multiplier
# if we have an embedding as well, pair it with the network
# even if added to lora, still save the trigger version
if self.embedding is not None:
emb_filename = f'{self.embed_config.trigger}{step_num}.safetensors'
emb_file_path = os.path.join(self.save_root, emb_filename)
# for combo, above will get it
# set current step
self.embedding.step = self.step_num
# change filename to pt if that is set
if self.embed_config.save_format == "pt":
# replace extension
file_path = os.path.splitext(emb_file_path)[0] + ".pt"
self.embedding.save(file_path)
else:
self.sd.save(
file_path,
@@ -316,17 +321,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
# return loss
return 0.0
def get_latest_save_path(self):
def get_latest_save_path(self, name=None):
if name == None:
name = self.job.name
# get latest saved step
if os.path.exists(self.save_root):
latest_file = None
# pattern is {job_name}_{zero_filles_step}.safetensors or {job_name}.safetensors
pattern = f"{self.job.name}*.safetensors"
pattern = f"{name}*.safetensors"
files = glob.glob(os.path.join(self.save_root, pattern))
if len(files) > 0:
latest_file = max(files, key=os.path.getctime)
# try pt
pattern = f"{self.job.name}*.pt"
pattern = f"{name}*.pt"
files = glob.glob(os.path.join(self.save_root, pattern))
if len(files) > 0:
latest_file = max(files, key=os.path.getctime)
@@ -548,9 +555,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
# we are doing embedding training as well
self.embedding = Embedding(
sd=self.sd,
embed_config=self.embed_config,
state_dict=extra_weights
embed_config=self.embed_config
)
latest_save_path = self.get_latest_save_path(self.embed_config.trigger)
# load last saved weights
if latest_save_path is not None:
self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
params.append({
'params': self.embedding.get_trainable_params(),
'lr': self.train_config.embedding_lr
@@ -562,7 +572,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
sd=self.sd,
embed_config=self.embed_config
)
latest_save_path = self.get_latest_save_path()
latest_save_path = self.get_latest_save_path(self.embed_config.trigger)
# load last saved weights
if latest_save_path is not None:
self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)

View File

@@ -50,10 +50,7 @@ class LoRAModule(ToolkitModuleMixin, torch.nn.Module):
**kwargs
):
"""if alpha == 0 or None, alpha is rank (no scaling)."""
super().__init__(
org_module=org_module,
parent=parent
)
super().__init__()
self.lora_name = lora_name
self.scalar = torch.tensor(1.0)

View File

@@ -36,12 +36,7 @@ class LoConSpecialModule(ToolkitModuleMixin, LoConModule):
# call super of super
torch.nn.Module.__init__(self)
# call super of
super().__init__(
org_module=org_module,
call_super_init=False,
parent=parent,
**kwargs
)
super().__init__(call_super_init=False)
self.lora_name = lora_name
self.lora_dim = lora_dim
self.cp = False

View File

@@ -52,18 +52,9 @@ class ToolkitModuleMixin:
):
if call_super_init:
super().__init__(*args, **kwargs)
self.tk_orig_module: torch.nn.Module = kwargs.get('org_module', None)
self.tk_orig_parent = kwargs.get('parent', None)
self.is_checkpointing = False
self.is_normalizing = False
self.normalize_scaler = 1.0
# see if is conv or linear
self.is_conv = False
self.is_linear = False
if self.tk_orig_module.__class__.__name__ in LINEAR_MODULES:
self.is_linear = True
elif self.tk_orig_module.__class__.__name__ in CONV_MODULES:
self.is_conv = True
self._multiplier: Union[float, list, torch.Tensor] = 1.0
# this allows us to set different multipliers on a per item in a batch basis
@@ -140,10 +131,6 @@ class ToolkitModuleMixin:
lora_output_batch_size = lora_output.size(0)
multiplier_batch_size = multiplier.size(0)
if lora_output_batch_size != multiplier_batch_size:
print(
f"Warning: lora_output_batch_size {lora_output_batch_size} != multiplier_batch_size {multiplier_batch_size}")
# doing cfg
# should be 1 for if total batch size was 1
num_interleaves = (lora_output_batch_size // 2) // multiplier_batch_size
multiplier = multiplier.repeat_interleave(num_interleaves)
# multiplier = 1.0