mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 18:21:16 +00:00
Fixed cleanup of emebddings.
This commit is contained in:
@@ -175,6 +175,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
return generate_image_config_list
|
return generate_image_config_list
|
||||||
|
|
||||||
def sample(self, step=None, is_first=False):
|
def sample(self, step=None, is_first=False):
|
||||||
|
flush()
|
||||||
sample_folder = os.path.join(self.save_root, 'samples')
|
sample_folder = os.path.join(self.save_root, 'samples')
|
||||||
gen_img_config_list = []
|
gen_img_config_list = []
|
||||||
|
|
||||||
@@ -284,6 +285,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
safetensors_files = [f for f in items if f.endswith('.safetensors')]
|
safetensors_files = [f for f in items if f.endswith('.safetensors')]
|
||||||
pt_files = [f for f in items if f.endswith('.pt')]
|
pt_files = [f for f in items if f.endswith('.pt')]
|
||||||
directories = [d for d in items if os.path.isdir(d) and not d.endswith('.safetensors')]
|
directories = [d for d in items if os.path.isdir(d) and not d.endswith('.safetensors')]
|
||||||
|
embed_files = []
|
||||||
|
# do embedding files
|
||||||
|
if self.embed_config is not None:
|
||||||
|
embed_pattern = f"{self.embed_config.trigger}_*"
|
||||||
|
embed_items = glob.glob(os.path.join(self.save_root, embed_pattern))
|
||||||
|
# will end in safetensors or pt
|
||||||
|
embed_files = [f for f in embed_items if f.endswith('.safetensors') or f.endswith('.pt')]
|
||||||
|
|
||||||
# Sort the lists by creation time if they are not empty
|
# Sort the lists by creation time if they are not empty
|
||||||
if safetensors_files:
|
if safetensors_files:
|
||||||
@@ -292,6 +300,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
pt_files.sort(key=os.path.getctime)
|
pt_files.sort(key=os.path.getctime)
|
||||||
if directories:
|
if directories:
|
||||||
directories.sort(key=os.path.getctime)
|
directories.sort(key=os.path.getctime)
|
||||||
|
if embed_files:
|
||||||
|
embed_files.sort(key=os.path.getctime)
|
||||||
|
|
||||||
# Combine and sort the lists
|
# Combine and sort the lists
|
||||||
combined_items = safetensors_files + directories + pt_files
|
combined_items = safetensors_files + directories + pt_files
|
||||||
@@ -302,10 +312,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
:-self.save_config.max_step_saves_to_keep] if safetensors_files else []
|
:-self.save_config.max_step_saves_to_keep] if safetensors_files else []
|
||||||
pt_files_to_remove = pt_files[:-self.save_config.max_step_saves_to_keep] if pt_files else []
|
pt_files_to_remove = pt_files[:-self.save_config.max_step_saves_to_keep] if pt_files else []
|
||||||
directories_to_remove = directories[:-self.save_config.max_step_saves_to_keep] if directories else []
|
directories_to_remove = directories[:-self.save_config.max_step_saves_to_keep] if directories else []
|
||||||
combined_items_to_remove = combined_items[
|
embeddings_to_remove = embed_files[:-self.save_config.max_step_saves_to_keep] if embed_files else []
|
||||||
:-self.save_config.max_step_saves_to_keep] if combined_items else []
|
|
||||||
|
|
||||||
items_to_remove = safetensors_to_remove + pt_files_to_remove + directories_to_remove
|
items_to_remove = safetensors_to_remove + pt_files_to_remove + directories_to_remove + embeddings_to_remove
|
||||||
|
|
||||||
# remove all but the latest max_step_saves_to_keep
|
# remove all but the latest max_step_saves_to_keep
|
||||||
# items_to_remove = combined_items[:-self.save_config.max_step_saves_to_keep]
|
# items_to_remove = combined_items[:-self.save_config.max_step_saves_to_keep]
|
||||||
|
|||||||
@@ -13,6 +13,21 @@ def flush(garbage_collect=True):
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
|
||||||
|
def get_mean_std(tensor):
|
||||||
|
if len(tensor.shape) == 3:
|
||||||
|
tensor = tensor.unsqueeze(0)
|
||||||
|
elif len(tensor.shape) != 4:
|
||||||
|
raise Exception("Expected tensor of shape (batch_size, channels, width, height)")
|
||||||
|
mean, variance = torch.mean(
|
||||||
|
tensor, dim=[2, 3], keepdim=True
|
||||||
|
), torch.var(
|
||||||
|
tensor, dim=[2, 3],
|
||||||
|
keepdim=True
|
||||||
|
)
|
||||||
|
std = torch.sqrt(variance + 1e-5)
|
||||||
|
return mean, std
|
||||||
|
|
||||||
|
|
||||||
def adain(content_features, style_features):
|
def adain(content_features, style_features):
|
||||||
# Assumes that the content and style features are of shape (batch_size, channels, width, height)
|
# Assumes that the content and style features are of shape (batch_size, channels, width, height)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user