mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 11:41:35 +00:00
Sampling tests and added fixes for cleanups
This commit is contained in:
@@ -282,12 +282,34 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
items = glob.glob(os.path.join(self.save_root, pattern))
|
||||
# Separate files and directories
|
||||
safetensors_files = [f for f in items if f.endswith('.safetensors')]
|
||||
pt_files = [f for f in items if f.endswith('.pt')]
|
||||
directories = [d for d in items if os.path.isdir(d) and not d.endswith('.safetensors')]
|
||||
# Combine the list and sort by creation time
|
||||
combined_items = safetensors_files + directories
|
||||
|
||||
# Sort the lists by creation time if they are not empty
|
||||
if safetensors_files:
|
||||
safetensors_files.sort(key=os.path.getctime)
|
||||
if pt_files:
|
||||
pt_files.sort(key=os.path.getctime)
|
||||
if directories:
|
||||
directories.sort(key=os.path.getctime)
|
||||
|
||||
# Combine and sort the lists
|
||||
combined_items = safetensors_files + directories + pt_files
|
||||
combined_items.sort(key=os.path.getctime)
|
||||
|
||||
# Use slicing with a check to avoid 'NoneType' error
|
||||
safetensors_to_remove = safetensors_files[
|
||||
:-self.save_config.max_step_saves_to_keep] if safetensors_files else []
|
||||
pt_files_to_remove = pt_files[:-self.save_config.max_step_saves_to_keep] if pt_files else []
|
||||
directories_to_remove = directories[:-self.save_config.max_step_saves_to_keep] if directories else []
|
||||
combined_items_to_remove = combined_items[
|
||||
:-self.save_config.max_step_saves_to_keep] if combined_items else []
|
||||
|
||||
items_to_remove = safetensors_to_remove + pt_files_to_remove + directories_to_remove
|
||||
|
||||
# remove all but the latest max_step_saves_to_keep
|
||||
items_to_remove = combined_items[:-self.save_config.max_step_saves_to_keep]
|
||||
# items_to_remove = combined_items[:-self.save_config.max_step_saves_to_keep]
|
||||
|
||||
for item in items_to_remove:
|
||||
self.print(f"Removing old save: {item}")
|
||||
if os.path.isdir(item):
|
||||
@@ -655,14 +677,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
do_double = False
|
||||
|
||||
with self.timer('prepare_noise'):
|
||||
num_train_timesteps = self.sd.noise_scheduler.config['num_train_timesteps']
|
||||
|
||||
if self.train_config.noise_scheduler == 'lcm':
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
1000, device=self.device_torch, original_inference_steps=1000
|
||||
num_train_timesteps, device=self.device_torch, original_inference_steps=num_train_timesteps
|
||||
)
|
||||
else:
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
1000, device=self.device_torch
|
||||
num_train_timesteps, device=self.device_torch
|
||||
)
|
||||
|
||||
# if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
|
||||
|
||||
Reference in New Issue
Block a user