mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Tons of bug fixes and improvements to special training. Fixed slider training.
This commit is contained in:
@@ -293,6 +293,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# will end in safetensors or pt
|
||||
embed_files = [f for f in embed_items if f.endswith('.safetensors') or f.endswith('.pt')]
|
||||
|
||||
# check for critic files
|
||||
critic_pattern = f"CRITIC_{self.job.name}_*"
|
||||
critic_items = glob.glob(os.path.join(self.save_root, critic_pattern))
|
||||
|
||||
# Sort the lists by creation time if they are not empty
|
||||
if safetensors_files:
|
||||
safetensors_files.sort(key=os.path.getctime)
|
||||
@@ -302,6 +306,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
directories.sort(key=os.path.getctime)
|
||||
if embed_files:
|
||||
embed_files.sort(key=os.path.getctime)
|
||||
if critic_items:
|
||||
critic_items.sort(key=os.path.getctime)
|
||||
|
||||
# Combine and sort the lists
|
||||
combined_items = safetensors_files + directories + pt_files
|
||||
@@ -313,8 +319,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
pt_files_to_remove = pt_files[:-self.save_config.max_step_saves_to_keep] if pt_files else []
|
||||
directories_to_remove = directories[:-self.save_config.max_step_saves_to_keep] if directories else []
|
||||
embeddings_to_remove = embed_files[:-self.save_config.max_step_saves_to_keep] if embed_files else []
|
||||
critic_to_remove = critic_items[:-self.save_config.max_step_saves_to_keep] if critic_items else []
|
||||
|
||||
items_to_remove = safetensors_to_remove + pt_files_to_remove + directories_to_remove + embeddings_to_remove
|
||||
items_to_remove = safetensors_to_remove + pt_files_to_remove + directories_to_remove + embeddings_to_remove + critic_to_remove
|
||||
|
||||
# remove all but the latest max_step_saves_to_keep
|
||||
# items_to_remove = combined_items[:-self.save_config.max_step_saves_to_keep]
|
||||
@@ -1041,8 +1048,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
train_text_encoder=self.train_config.train_text_encoder,
|
||||
conv_lora_dim=self.network_config.conv,
|
||||
conv_alpha=self.network_config.conv_alpha,
|
||||
is_sdxl=self.model_config.is_xl,
|
||||
is_sdxl=self.model_config.is_xl or self.model_config.is_ssd,
|
||||
is_v2=self.model_config.is_v2,
|
||||
is_ssd=self.model_config.is_ssd,
|
||||
dropout=self.network_config.dropout,
|
||||
use_text_encoder_1=self.model_config.use_text_encoder_1,
|
||||
use_text_encoder_2=self.model_config.use_text_encoder_2,
|
||||
|
||||
@@ -371,7 +371,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
|
||||
# get a random number of steps
|
||||
timesteps_to = torch.randint(
|
||||
1, self.train_config.max_denoising_steps, (1,)
|
||||
1, self.train_config.max_denoising_steps - 1, (1,)
|
||||
).item()
|
||||
|
||||
# get noise
|
||||
@@ -389,7 +389,8 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
assert not self.network.is_active
|
||||
self.sd.unet.eval()
|
||||
# pass the multiplier list to the network
|
||||
self.network.multiplier = prompt_pair.multiplier_list
|
||||
# double up since we are doing cfg
|
||||
self.network.multiplier = prompt_pair.multiplier_list + prompt_pair.multiplier_list
|
||||
denoised_latents = self.sd.diffuse_some_steps(
|
||||
latents, # pass simple noise latents
|
||||
train_tools.concat_prompt_embeddings(
|
||||
@@ -507,7 +508,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
for anchor_chunk, denoised_latent_chunk, anchor_target_noise_chunk in zip(
|
||||
anchor_chunks, denoised_latent_chunks, anchor_target_noise_chunks
|
||||
):
|
||||
self.network.multiplier = anchor_chunk.multiplier_list
|
||||
self.network.multiplier = anchor_chunk.multiplier_list + anchor_chunk.multiplier_list
|
||||
|
||||
anchor_pred_noise = get_noise_pred(
|
||||
anchor_chunk.neg_prompt, anchor_chunk.prompt, 1, current_timestep, denoised_latent_chunk
|
||||
@@ -582,7 +583,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
mask_multiplier_chunks,
|
||||
unmasked_target_chunks
|
||||
):
|
||||
self.network.multiplier = prompt_pair_chunk.multiplier_list
|
||||
self.network.multiplier = prompt_pair_chunk.multiplier_list + prompt_pair_chunk.multiplier_list
|
||||
target_latents = get_noise_pred(
|
||||
prompt_pair_chunk.positive_target,
|
||||
prompt_pair_chunk.target_class,
|
||||
@@ -611,6 +612,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
offset_neutral = neutral_latents_chunk
|
||||
# offsets are already adjusted on a per-batch basis
|
||||
offset_neutral += offset
|
||||
offset_neutral = offset_neutral.detach().requires_grad_(False)
|
||||
|
||||
# 16.15 GB RAM for 512x512 -> 4.20GB RAM for 512x512 with new grad_checkpointing
|
||||
loss = torch.nn.functional.mse_loss(target_latents.float(), offset_neutral.float(), reduction="none")
|
||||
|
||||
Reference in New Issue
Block a user