Save embeddings under their trigger name to match auto- and comfy-style loading. Also, FINALLY found why gradients were wonky and fixed it. The root problem was dropping out of the network state before the backward pass.

This commit is contained in:
Jaret Burkett
2023-09-09 12:02:07 -06:00
parent 408c50ead1
commit be804c9cf5
5 changed files with 64 additions and 71 deletions

View File

@@ -255,32 +255,37 @@ class BaseSDTrainProcess(BaseTrainProcess):
file_path = os.path.join(self.save_root, filename)
# prepare meta
save_meta = get_meta_for_safetensors(self.meta, self.job.name)
if self.network is not None:
prev_multiplier = self.network.multiplier
self.network.multiplier = 1.0
if self.network_config.normalize:
# apply the normalization
self.network.apply_stored_normalizer()
if self.network is not None or self.embedding is not None:
if self.network is not None:
prev_multiplier = self.network.multiplier
self.network.multiplier = 1.0
if self.network_config.normalize:
# apply the normalization
self.network.apply_stored_normalizer()
# if we are doing embedding training as well, add that
embedding_dict = self.embedding.state_dict() if self.embedding else None
self.network.save_weights(
file_path,
dtype=get_torch_dtype(self.save_config.dtype),
metadata=save_meta,
extra_state_dict=embedding_dict
)
self.network.multiplier = prev_multiplier
# if we have an embedding as well, pair it with the network
elif self.embedding is not None:
# for combo, above will get it
# set current step
self.embedding.step = self.step_num
# change filename to pt if that is set
if self.embed_config.save_format == "pt":
# replace extension
file_path = os.path.splitext(file_path)[0] + ".pt"
self.embedding.save(file_path)
# if we are doing embedding training as well, add that
embedding_dict = self.embedding.state_dict() if self.embedding else None
self.network.save_weights(
file_path,
dtype=get_torch_dtype(self.save_config.dtype),
metadata=save_meta,
extra_state_dict=embedding_dict
)
self.network.multiplier = prev_multiplier
# if we have an embedding as well, pair it with the network
# even if added to lora, still save the trigger version
if self.embedding is not None:
emb_filename = f'{self.embed_config.trigger}{step_num}.safetensors'
emb_file_path = os.path.join(self.save_root, emb_filename)
# for combo, above will get it
# set current step
self.embedding.step = self.step_num
# change filename to pt if that is set
if self.embed_config.save_format == "pt":
# replace extension
file_path = os.path.splitext(emb_file_path)[0] + ".pt"
self.embedding.save(file_path)
else:
self.sd.save(
file_path,
@@ -316,17 +321,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
# return loss
return 0.0
def get_latest_save_path(self):
def get_latest_save_path(self, name=None):
if name == None:
name = self.job.name
# get latest saved step
if os.path.exists(self.save_root):
latest_file = None
# pattern is {job_name}_{zero_filles_step}.safetensors or {job_name}.safetensors
pattern = f"{self.job.name}*.safetensors"
pattern = f"{name}*.safetensors"
files = glob.glob(os.path.join(self.save_root, pattern))
if len(files) > 0:
latest_file = max(files, key=os.path.getctime)
# try pt
pattern = f"{self.job.name}*.pt"
pattern = f"{name}*.pt"
files = glob.glob(os.path.join(self.save_root, pattern))
if len(files) > 0:
latest_file = max(files, key=os.path.getctime)
@@ -548,9 +555,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
# we are doing embedding training as well
self.embedding = Embedding(
sd=self.sd,
embed_config=self.embed_config,
state_dict=extra_weights
embed_config=self.embed_config
)
latest_save_path = self.get_latest_save_path(self.embed_config.trigger)
# load last saved weights
if latest_save_path is not None:
self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
params.append({
'params': self.embedding.get_trainable_params(),
'lr': self.train_config.embedding_lr
@@ -562,7 +572,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
sd=self.sd,
embed_config=self.embed_config
)
latest_save_path = self.get_latest_save_path()
latest_save_path = self.get_latest_save_path(self.embed_config.trigger)
# load last saved weights
if latest_save_path is not None:
self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)