mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Save embeddings as their trigger to match auto and comfy style loading. Also, FINALLY found why gradients were wonkey and fixed it. The root problem is dropping out of network state before backward pass.
This commit is contained in:
@@ -255,32 +255,37 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
file_path = os.path.join(self.save_root, filename)
|
||||
# prepare meta
|
||||
save_meta = get_meta_for_safetensors(self.meta, self.job.name)
|
||||
if self.network is not None:
|
||||
prev_multiplier = self.network.multiplier
|
||||
self.network.multiplier = 1.0
|
||||
if self.network_config.normalize:
|
||||
# apply the normalization
|
||||
self.network.apply_stored_normalizer()
|
||||
if self.network is not None or self.embedding is not None:
|
||||
if self.network is not None:
|
||||
prev_multiplier = self.network.multiplier
|
||||
self.network.multiplier = 1.0
|
||||
if self.network_config.normalize:
|
||||
# apply the normalization
|
||||
self.network.apply_stored_normalizer()
|
||||
|
||||
# if we are doing embedding training as well, add that
|
||||
embedding_dict = self.embedding.state_dict() if self.embedding else None
|
||||
self.network.save_weights(
|
||||
file_path,
|
||||
dtype=get_torch_dtype(self.save_config.dtype),
|
||||
metadata=save_meta,
|
||||
extra_state_dict=embedding_dict
|
||||
)
|
||||
self.network.multiplier = prev_multiplier
|
||||
# if we have an embedding as well, pair it with the network
|
||||
elif self.embedding is not None:
|
||||
# for combo, above will get it
|
||||
# set current step
|
||||
self.embedding.step = self.step_num
|
||||
# change filename to pt if that is set
|
||||
if self.embed_config.save_format == "pt":
|
||||
# replace extension
|
||||
file_path = os.path.splitext(file_path)[0] + ".pt"
|
||||
self.embedding.save(file_path)
|
||||
# if we are doing embedding training as well, add that
|
||||
embedding_dict = self.embedding.state_dict() if self.embedding else None
|
||||
self.network.save_weights(
|
||||
file_path,
|
||||
dtype=get_torch_dtype(self.save_config.dtype),
|
||||
metadata=save_meta,
|
||||
extra_state_dict=embedding_dict
|
||||
)
|
||||
self.network.multiplier = prev_multiplier
|
||||
# if we have an embedding as well, pair it with the network
|
||||
|
||||
# even if added to lora, still save the trigger version
|
||||
if self.embedding is not None:
|
||||
emb_filename = f'{self.embed_config.trigger}{step_num}.safetensors'
|
||||
emb_file_path = os.path.join(self.save_root, emb_filename)
|
||||
# for combo, above will get it
|
||||
# set current step
|
||||
self.embedding.step = self.step_num
|
||||
# change filename to pt if that is set
|
||||
if self.embed_config.save_format == "pt":
|
||||
# replace extension
|
||||
file_path = os.path.splitext(emb_file_path)[0] + ".pt"
|
||||
self.embedding.save(file_path)
|
||||
else:
|
||||
self.sd.save(
|
||||
file_path,
|
||||
@@ -316,17 +321,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# return loss
|
||||
return 0.0
|
||||
|
||||
def get_latest_save_path(self):
|
||||
def get_latest_save_path(self, name=None):
|
||||
if name == None:
|
||||
name = self.job.name
|
||||
# get latest saved step
|
||||
if os.path.exists(self.save_root):
|
||||
latest_file = None
|
||||
# pattern is {job_name}_{zero_filles_step}.safetensors or {job_name}.safetensors
|
||||
pattern = f"{self.job.name}*.safetensors"
|
||||
pattern = f"{name}*.safetensors"
|
||||
files = glob.glob(os.path.join(self.save_root, pattern))
|
||||
if len(files) > 0:
|
||||
latest_file = max(files, key=os.path.getctime)
|
||||
# try pt
|
||||
pattern = f"{self.job.name}*.pt"
|
||||
pattern = f"{name}*.pt"
|
||||
files = glob.glob(os.path.join(self.save_root, pattern))
|
||||
if len(files) > 0:
|
||||
latest_file = max(files, key=os.path.getctime)
|
||||
@@ -548,9 +555,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# we are doing embedding training as well
|
||||
self.embedding = Embedding(
|
||||
sd=self.sd,
|
||||
embed_config=self.embed_config,
|
||||
state_dict=extra_weights
|
||||
embed_config=self.embed_config
|
||||
)
|
||||
latest_save_path = self.get_latest_save_path(self.embed_config.trigger)
|
||||
# load last saved weights
|
||||
if latest_save_path is not None:
|
||||
self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
|
||||
params.append({
|
||||
'params': self.embedding.get_trainable_params(),
|
||||
'lr': self.train_config.embedding_lr
|
||||
@@ -562,7 +572,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
sd=self.sd,
|
||||
embed_config=self.embed_config
|
||||
)
|
||||
latest_save_path = self.get_latest_save_path()
|
||||
latest_save_path = self.get_latest_save_path(self.embed_config.trigger)
|
||||
# load last saved weights
|
||||
if latest_save_path is not None:
|
||||
self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
|
||||
|
||||
Reference in New Issue
Block a user