Sampling tests and added fixes for cleanups

This commit is contained in:
Jaret Burkett
2023-11-16 08:33:23 -07:00
parent e47006ed70
commit ad50921c41
4 changed files with 495 additions and 5 deletions

View File

@@ -282,12 +282,34 @@ class BaseSDTrainProcess(BaseTrainProcess):
items = glob.glob(os.path.join(self.save_root, pattern))
# Separate files and directories
safetensors_files = [f for f in items if f.endswith('.safetensors')]
pt_files = [f for f in items if f.endswith('.pt')]
directories = [d for d in items if os.path.isdir(d) and not d.endswith('.safetensors')]
# Combine the list and sort by creation time
combined_items = safetensors_files + directories
# Sort the lists by creation time if they are not empty
if safetensors_files:
safetensors_files.sort(key=os.path.getctime)
if pt_files:
pt_files.sort(key=os.path.getctime)
if directories:
directories.sort(key=os.path.getctime)
# Combine and sort the lists
combined_items = safetensors_files + directories + pt_files
combined_items.sort(key=os.path.getctime)
# Use slicing with a check to avoid 'NoneType' error
safetensors_to_remove = safetensors_files[
:-self.save_config.max_step_saves_to_keep] if safetensors_files else []
pt_files_to_remove = pt_files[:-self.save_config.max_step_saves_to_keep] if pt_files else []
directories_to_remove = directories[:-self.save_config.max_step_saves_to_keep] if directories else []
combined_items_to_remove = combined_items[
:-self.save_config.max_step_saves_to_keep] if combined_items else []
items_to_remove = safetensors_to_remove + pt_files_to_remove + directories_to_remove
# remove all but the latest max_step_saves_to_keep
items_to_remove = combined_items[:-self.save_config.max_step_saves_to_keep]
# items_to_remove = combined_items[:-self.save_config.max_step_saves_to_keep]
for item in items_to_remove:
self.print(f"Removing old save: {item}")
if os.path.isdir(item):
@@ -655,14 +677,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
do_double = False
with self.timer('prepare_noise'):
num_train_timesteps = self.sd.noise_scheduler.config['num_train_timesteps']
if self.train_config.noise_scheduler == 'lcm':
self.sd.noise_scheduler.set_timesteps(
1000, device=self.device_torch, original_inference_steps=1000
num_train_timesteps, device=self.device_torch, original_inference_steps=num_train_timesteps
)
else:
self.sd.noise_scheduler.set_timesteps(
1000, device=self.device_torch
num_train_timesteps, device=self.device_torch
)
# if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':