Added refiner fine tuning. Works, but needs some polish.

This commit is contained in:
Jaret Burkett
2023-11-05 17:15:03 -07:00
parent 8a9e8f708f
commit 93ea955d7c
14 changed files with 4541 additions and 128 deletions

View File

@@ -152,6 +152,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
train_lora=self.network_config is not None,
train_adapter=is_training_adapter,
train_embedding=self.embed_config is not None,
train_refiner=self.train_config.train_refiner,
)
# fine_tuning here is for training actual SD network, not LoRA, embeddings, etc. it is (Dreambooth, etc)
@@ -382,16 +383,29 @@ class BaseSDTrainProcess(BaseTrainProcess):
file_path = file_path.replace('.safetensors', '')
# convert it back to normal object
save_meta = parse_metadata_from_safetensors(save_meta)
self.sd.save(
file_path,
save_meta,
get_torch_dtype(self.save_config.dtype)
)
if self.sd.refiner_unet and self.train_config.train_refiner:
# save refiner
refiner_name = self.job.name + '_refiner'
filename = f'{refiner_name}{step_num}.safetensors'
file_path = os.path.join(self.save_root, filename)
self.sd.save_refiner(
file_path,
save_meta,
get_torch_dtype(self.save_config.dtype)
)
if self.train_config.train_unet or self.train_config.train_text_encoder:
self.sd.save(
file_path,
save_meta,
get_torch_dtype(self.save_config.dtype)
)
# save learnable params as json if we have thim
if self.snr_gos:
json_data = {
'offset': self.snr_gos.offset.item(),
'offset_1': self.snr_gos.offset_1.item(),
'offset_2': self.snr_gos.offset_2.item(),
'scale': self.snr_gos.scale.item(),
'gamma': self.snr_gos.gamma.item(),
}
@@ -447,7 +461,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
# Filter out non-existent paths and sort by creation time
if paths:
paths = [p for p in paths if os.path.exists(p)]
latest_path = max(paths, key=os.path.getctime)
# remove false positives
if '_LoRA' not in name:
paths = [p for p in paths if '_LoRA' not in p]
if '_refiner' not in name:
paths = [p for p in paths if '_refiner' not in p]
if '_t2i' not in name:
paths = [p for p in paths if '_t2i' not in p]
if len(paths) > 0:
latest_path = max(paths, key=os.path.getctime)
return latest_path
@@ -540,6 +563,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
# double batch and add short captions to the end
prompts = prompts + batch.get_caption_short_list()
is_reg_list = is_reg_list + is_reg_list
if self.model_config.refiner_name_or_path is not None and self.train_config.train_unet:
prompts = prompts + prompts
is_reg_list = is_reg_list + is_reg_list
conditioned_prompts = []
@@ -587,6 +613,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)
batch_size = len(batch.file_items)
min_noise_steps = self.train_config.min_denoising_steps
max_noise_steps = self.train_config.max_denoising_steps
if self.model_config.refiner_name_or_path is not None:
# if we are not training the unet, then we are only doing refiner and do not need to double up
if self.train_config.train_unet:
max_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at)
do_double = True
else:
min_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at)
do_double = False
with self.timer('prepare_noise'):
@@ -615,18 +651,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
timesteps,
0,
self.sd.noise_scheduler.config['num_train_timesteps'] - 1,
self.train_config.min_denoising_steps,
self.train_config.max_denoising_steps
min_noise_steps,
max_noise_steps
)
timesteps = timesteps.long().clamp(
self.train_config.min_denoising_steps + 1,
self.train_config.max_denoising_steps - 1
min_noise_steps + 1,
max_noise_steps - 1
)
elif self.train_config.content_or_style == 'balanced':
timesteps = torch.randint(
self.train_config.min_denoising_steps,
self.train_config.max_denoising_steps,
min_noise_steps,
max_noise_steps,
(batch_size,),
device=self.device_torch
)
@@ -678,9 +714,27 @@ class BaseSDTrainProcess(BaseTrainProcess):
return torch.cat([tensor, tensor], dim=0)
if do_double:
noisy_latents = double_up_tensor(noisy_latents)
if self.model_config.refiner_name_or_path:
# apply refiner double up
refiner_timesteps = torch.randint(
max_noise_steps,
self.train_config.max_denoising_steps,
(batch_size,),
device=self.device_torch
)
refiner_timesteps = refiner_timesteps.long()
# add our new timesteps on to end
timesteps = torch.cat([timesteps, refiner_timesteps], dim=0)
refiner_noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, refiner_timesteps)
noisy_latents = torch.cat([noisy_latents, refiner_noisy_latents], dim=0)
else:
# just double it
noisy_latents = double_up_tensor(noisy_latents)
timesteps = double_up_tensor(timesteps)
noise = double_up_tensor(noise)
timesteps = double_up_tensor(timesteps)
# prompts are already updated above
imgs = double_up_tensor(imgs)
batch.mask_tensor = double_up_tensor(batch.mask_tensor)
@@ -772,6 +826,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
# get the noise scheduler
sampler = get_sampler(self.train_config.noise_scheduler)
if self.train_config.train_refiner and self.model_config.refiner_name_or_path is not None and self.network_config is None:
previous_refiner_save = self.get_latest_save_path(self.job.name + '_refiner')
if previous_refiner_save is not None:
model_config_to_load.refiner_name_or_path = previous_refiner_save
self.load_training_state_from_metadata(previous_refiner_save)
self.sd = StableDiffusion(
device=self.device,
model_config=model_config_to_load,
@@ -818,6 +878,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
if hasattr(text_encoder, "gradient_checkpointing_enable"):
text_encoder.gradient_checkpointing_enable()
if self.sd.refiner_unet is not None:
self.sd.refiner_unet.to(self.device_torch, dtype=dtype)
self.sd.refiner_unet.requires_grad_(False)
self.sd.refiner_unet.eval()
if self.train_config.xformers:
self.sd.refiner_unet.enable_xformers_memory_efficient_attention()
if self.train_config.gradient_checkpointing:
self.sd.refiner_unet.enable_gradient_checkpointing()
if isinstance(text_encoder, list):
for te in text_encoder:
te.requires_grad_(False)
@@ -840,7 +909,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
if os.path.exists(path_to_load):
with open(path_to_load, 'r') as f:
json_data = json.load(f)
self.snr_gos.offset.data = torch.tensor(json_data['offset'], device=self.device_torch)
if 'offset' in json_data:
# legacy
self.snr_gos.offset_2.data = torch.tensor(json_data['offset'], device=self.device_torch)
else:
self.snr_gos.offset_1.data = torch.tensor(json_data['offset_1'], device=self.device_torch)
self.snr_gos.offset_2.data = torch.tensor(json_data['offset_2'], device=self.device_torch)
self.snr_gos.scale.data = torch.tensor(json_data['scale'], device=self.device_torch)
self.snr_gos.gamma.data = torch.tensor(json_data['gamma'], device=self.device_torch)
@@ -1018,7 +1092,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
text_encoder=self.train_config.train_text_encoder,
text_encoder_lr=self.train_config.lr,
unet_lr=self.train_config.lr,
default_lr=self.train_config.lr
default_lr=self.train_config.lr,
refiner=self.train_config.train_refiner and self.sd.refiner_unet is not None,
refiner_lr=self.train_config.refiner_lr,
)
# we may be using it for prompt injections
if self.adapter_config is not None:
@@ -1158,6 +1234,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
else:
batch = None
# if we are doing a reg step, always accumulate
if is_reg_step:
self.is_grad_accumulation_step = True
# setup accumulation
if self.train_config.gradient_accumulation_steps == -1:
# epoch is handling the accumulation, dont touch it