mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-27 17:51:41 +00:00
Added refiner fine tuning. Works, but needs some polish.
This commit is contained in:
@@ -152,6 +152,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
train_lora=self.network_config is not None,
|
||||
train_adapter=is_training_adapter,
|
||||
train_embedding=self.embed_config is not None,
|
||||
train_refiner=self.train_config.train_refiner,
|
||||
)
|
||||
|
||||
# fine_tuning here means training the actual SD network itself (Dreambooth, etc.), not LoRA, embeddings, etc.
|
||||
@@ -382,16 +383,29 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
file_path = file_path.replace('.safetensors', '')
|
||||
# convert it back to normal object
|
||||
save_meta = parse_metadata_from_safetensors(save_meta)
|
||||
self.sd.save(
|
||||
file_path,
|
||||
save_meta,
|
||||
get_torch_dtype(self.save_config.dtype)
|
||||
)
|
||||
|
||||
if self.sd.refiner_unet and self.train_config.train_refiner:
|
||||
# save refiner
|
||||
refiner_name = self.job.name + '_refiner'
|
||||
filename = f'{refiner_name}{step_num}.safetensors'
|
||||
file_path = os.path.join(self.save_root, filename)
|
||||
self.sd.save_refiner(
|
||||
file_path,
|
||||
save_meta,
|
||||
get_torch_dtype(self.save_config.dtype)
|
||||
)
|
||||
if self.train_config.train_unet or self.train_config.train_text_encoder:
|
||||
self.sd.save(
|
||||
file_path,
|
||||
save_meta,
|
||||
get_torch_dtype(self.save_config.dtype)
|
||||
)
|
||||
|
||||
# save learnable params as json if we have them
|
||||
if self.snr_gos:
|
||||
json_data = {
|
||||
'offset': self.snr_gos.offset.item(),
|
||||
'offset_1': self.snr_gos.offset_1.item(),
|
||||
'offset_2': self.snr_gos.offset_2.item(),
|
||||
'scale': self.snr_gos.scale.item(),
|
||||
'gamma': self.snr_gos.gamma.item(),
|
||||
}
|
||||
@@ -447,7 +461,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# Filter out non-existent paths and sort by creation time
|
||||
if paths:
|
||||
paths = [p for p in paths if os.path.exists(p)]
|
||||
latest_path = max(paths, key=os.path.getctime)
|
||||
# remove false positives
|
||||
if '_LoRA' not in name:
|
||||
paths = [p for p in paths if '_LoRA' not in p]
|
||||
if '_refiner' not in name:
|
||||
paths = [p for p in paths if '_refiner' not in p]
|
||||
if '_t2i' not in name:
|
||||
paths = [p for p in paths if '_t2i' not in p]
|
||||
|
||||
if len(paths) > 0:
|
||||
latest_path = max(paths, key=os.path.getctime)
|
||||
|
||||
return latest_path
|
||||
|
||||
@@ -540,6 +563,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# double batch and add short captions to the end
|
||||
prompts = prompts + batch.get_caption_short_list()
|
||||
is_reg_list = is_reg_list + is_reg_list
|
||||
if self.model_config.refiner_name_or_path is not None and self.train_config.train_unet:
|
||||
prompts = prompts + prompts
|
||||
is_reg_list = is_reg_list + is_reg_list
|
||||
|
||||
conditioned_prompts = []
|
||||
|
||||
@@ -587,6 +613,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)
|
||||
|
||||
batch_size = len(batch.file_items)
|
||||
min_noise_steps = self.train_config.min_denoising_steps
|
||||
max_noise_steps = self.train_config.max_denoising_steps
|
||||
if self.model_config.refiner_name_or_path is not None:
|
||||
# if we are not training the unet, then we are only doing refiner and do not need to double up
|
||||
if self.train_config.train_unet:
|
||||
max_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at)
|
||||
do_double = True
|
||||
else:
|
||||
min_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at)
|
||||
do_double = False
|
||||
|
||||
with self.timer('prepare_noise'):
|
||||
|
||||
@@ -615,18 +651,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
timesteps,
|
||||
0,
|
||||
self.sd.noise_scheduler.config['num_train_timesteps'] - 1,
|
||||
self.train_config.min_denoising_steps,
|
||||
self.train_config.max_denoising_steps
|
||||
min_noise_steps,
|
||||
max_noise_steps
|
||||
)
|
||||
timesteps = timesteps.long().clamp(
|
||||
self.train_config.min_denoising_steps + 1,
|
||||
self.train_config.max_denoising_steps - 1
|
||||
min_noise_steps + 1,
|
||||
max_noise_steps - 1
|
||||
)
|
||||
|
||||
elif self.train_config.content_or_style == 'balanced':
|
||||
timesteps = torch.randint(
|
||||
self.train_config.min_denoising_steps,
|
||||
self.train_config.max_denoising_steps,
|
||||
min_noise_steps,
|
||||
max_noise_steps,
|
||||
(batch_size,),
|
||||
device=self.device_torch
|
||||
)
|
||||
@@ -678,9 +714,27 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
return torch.cat([tensor, tensor], dim=0)
|
||||
|
||||
if do_double:
|
||||
noisy_latents = double_up_tensor(noisy_latents)
|
||||
if self.model_config.refiner_name_or_path:
|
||||
# apply refiner double up
|
||||
refiner_timesteps = torch.randint(
|
||||
max_noise_steps,
|
||||
self.train_config.max_denoising_steps,
|
||||
(batch_size,),
|
||||
device=self.device_torch
|
||||
)
|
||||
refiner_timesteps = refiner_timesteps.long()
|
||||
# add our new timesteps on to end
|
||||
timesteps = torch.cat([timesteps, refiner_timesteps], dim=0)
|
||||
|
||||
refiner_noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, refiner_timesteps)
|
||||
noisy_latents = torch.cat([noisy_latents, refiner_noisy_latents], dim=0)
|
||||
|
||||
else:
|
||||
# just double it
|
||||
noisy_latents = double_up_tensor(noisy_latents)
|
||||
timesteps = double_up_tensor(timesteps)
|
||||
|
||||
noise = double_up_tensor(noise)
|
||||
timesteps = double_up_tensor(timesteps)
|
||||
# prompts are already updated above
|
||||
imgs = double_up_tensor(imgs)
|
||||
batch.mask_tensor = double_up_tensor(batch.mask_tensor)
|
||||
@@ -772,6 +826,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# get the noise scheduler
|
||||
sampler = get_sampler(self.train_config.noise_scheduler)
|
||||
|
||||
if self.train_config.train_refiner and self.model_config.refiner_name_or_path is not None and self.network_config is None:
|
||||
previous_refiner_save = self.get_latest_save_path(self.job.name + '_refiner')
|
||||
if previous_refiner_save is not None:
|
||||
model_config_to_load.refiner_name_or_path = previous_refiner_save
|
||||
self.load_training_state_from_metadata(previous_refiner_save)
|
||||
|
||||
self.sd = StableDiffusion(
|
||||
device=self.device,
|
||||
model_config=model_config_to_load,
|
||||
@@ -818,6 +878,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if hasattr(text_encoder, "gradient_checkpointing_enable"):
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
|
||||
if self.sd.refiner_unet is not None:
|
||||
self.sd.refiner_unet.to(self.device_torch, dtype=dtype)
|
||||
self.sd.refiner_unet.requires_grad_(False)
|
||||
self.sd.refiner_unet.eval()
|
||||
if self.train_config.xformers:
|
||||
self.sd.refiner_unet.enable_xformers_memory_efficient_attention()
|
||||
if self.train_config.gradient_checkpointing:
|
||||
self.sd.refiner_unet.enable_gradient_checkpointing()
|
||||
|
||||
if isinstance(text_encoder, list):
|
||||
for te in text_encoder:
|
||||
te.requires_grad_(False)
|
||||
@@ -840,7 +909,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if os.path.exists(path_to_load):
|
||||
with open(path_to_load, 'r') as f:
|
||||
json_data = json.load(f)
|
||||
self.snr_gos.offset.data = torch.tensor(json_data['offset'], device=self.device_torch)
|
||||
if 'offset' in json_data:
|
||||
# legacy
|
||||
self.snr_gos.offset_2.data = torch.tensor(json_data['offset'], device=self.device_torch)
|
||||
else:
|
||||
self.snr_gos.offset_1.data = torch.tensor(json_data['offset_1'], device=self.device_torch)
|
||||
self.snr_gos.offset_2.data = torch.tensor(json_data['offset_2'], device=self.device_torch)
|
||||
self.snr_gos.scale.data = torch.tensor(json_data['scale'], device=self.device_torch)
|
||||
self.snr_gos.gamma.data = torch.tensor(json_data['gamma'], device=self.device_torch)
|
||||
|
||||
@@ -1018,7 +1092,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
text_encoder=self.train_config.train_text_encoder,
|
||||
text_encoder_lr=self.train_config.lr,
|
||||
unet_lr=self.train_config.lr,
|
||||
default_lr=self.train_config.lr
|
||||
default_lr=self.train_config.lr,
|
||||
refiner=self.train_config.train_refiner and self.sd.refiner_unet is not None,
|
||||
refiner_lr=self.train_config.refiner_lr,
|
||||
)
|
||||
# we may be using it for prompt injections
|
||||
if self.adapter_config is not None:
|
||||
@@ -1158,6 +1234,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
else:
|
||||
batch = None
|
||||
|
||||
# if we are doing a reg step, always accumulate
|
||||
if is_reg_step:
|
||||
self.is_grad_accumulation_step = True
|
||||
|
||||
# setup accumulation
|
||||
if self.train_config.gradient_accumulation_steps == -1:
|
||||
# epoch is handling the accumulation, don't touch it
|
||||
|
||||
Reference in New Issue
Block a user