t2i training working from what I can tell at least

Author: Jaret Burkett
Date: 2023-09-17 15:56:43 -06:00
Parent: 181f237a7b
Commit: 61badf85a7
5 changed files with 214 additions and 174 deletions

@@ -54,16 +54,20 @@ class SDTrainer(BaseSDTrainProcess):
if os.path.exists(os.path.join(adapter_folder_path, file_name_no_ext + ext)):
adapter_images.append(os.path.join(adapter_folder_path, file_name_no_ext + ext))
break
width, height = batch.file_items[0].crop_width, batch.file_items[0].crop_height
adapter_tensors = []
# load images with torch transforms
for adapter_image in adapter_images:
for idx, adapter_image in enumerate(adapter_images):
img = Image.open(adapter_image)
# resize to match batch shape
img = img.resize((width, height))
img = adapter_transforms(img)
adapter_tensors.append(img)
# stack them
adapter_tensors = torch.stack(adapter_tensors)
adapter_tensors = torch.stack(adapter_tensors).to(
self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)
)
return adapter_tensors
def hook_train_loop(self, batch):
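The first hunk tightens `get_adapter_images`: conditioning images are resized to the batch's crop size, transformed, stacked, and now moved to the training device and dtype in one call instead of being left on CPU in float32. A minimal sketch of that pipeline, using a torchvision `ToTensor` as a stand-in for the repo's `adapter_transforms` (an assumption, not its actual transform):

```python
import torch
from PIL import Image
from torchvision import transforms

# stand-in for the repo's adapter_transforms
adapter_transforms = transforms.ToTensor()

def load_adapter_batch(paths, width, height, device, dtype):
    tensors = []
    for path in paths:
        img = Image.open(path).convert("RGB")
        img = img.resize((width, height))   # match the batch crop size
        tensors.append(adapter_transforms(img))
    # one stacked (B, C, H, W) transfer instead of per-image .to() calls
    return torch.stack(tensors).to(device, dtype=dtype)
```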
@@ -79,8 +83,8 @@ class SDTrainer(BaseSDTrainProcess):
adapter_images = self.get_adapter_images(batch)
# not 100% sure what this does. But they do it here
# https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170
sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
# sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
# noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
# flush()
self.optimizer.zero_grad()
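The commented-out lines were the Euler-style input scaling copied from diffusers' `train_t2i_adapter_sdxl.py`; this commit disables them. What they computed, as a standalone sketch:

```python
import torch

def scale_model_input(noisy_latents: torch.Tensor, sigmas: torch.Tensor) -> torch.Tensor:
    # sigmas is broadcast against the latents, e.g. shape (B, 1, 1, 1)
    return noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
```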
@@ -126,39 +130,38 @@ class SDTrainer(BaseSDTrainProcess):
**pred_kwargs
)
if self.adapter:
# todo, diffusers does this on t2i training, is it better approach?
# Denoise the latents
denoised_latents = noise_pred * (-sigmas) + noisy_latents
weighing = sigmas ** -2.0
# Get the target for loss depending on the prediction type
if self.sd.noise_scheduler.config.prediction_type == "epsilon":
target = batch.latents # we are computing loss against denoised latents
elif self.sd.noise_scheduler.config.prediction_type == "v_prediction":
target = self.sd.noise_scheduler.get_velocity(batch.latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}")
# MSE loss
loss = torch.mean(
(weighing.float() * (denoised_latents.float() - target.float()) ** 2).reshape(target.shape[0], -1),
dim=1,
)
# if self.adapter:
# # todo, diffusers does this on t2i training, is it better approach?
# # Denoise the latents
# denoised_latents = noise_pred * (-sigmas) + noisy_latents
# weighing = sigmas ** -2.0
#
# # Get the target for loss depending on the prediction type
# if self.sd.noise_scheduler.config.prediction_type == "epsilon":
# target = batch.latents # we are computing loss against denoised latents
# elif self.sd.noise_scheduler.config.prediction_type == "v_prediction":
# target = self.sd.noise_scheduler.get_velocity(batch.latents, noise, timesteps)
# else:
# raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}")
#
# # MSE loss
# loss = torch.mean(
# (weighing.float() * (denoised_latents.float() - target.float()) ** 2).reshape(target.shape[0], -1),
# dim=1,
# )
# else:
noise = noise.to(self.device_torch, dtype=dtype).detach()
if self.sd.prediction_type == 'v_prediction':
# v-parameterization training
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
else:
noise = noise.to(self.device_torch, dtype=dtype).detach()
if self.sd.prediction_type == 'v_prediction':
# v-parameterization training
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
# TODO: I think the sigma method does not need this. Check
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
loss = loss.mean()
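With the sigma-weighted branch commented out, the loss falls through to the plain per-sample MSE against either the noise (epsilon) or the velocity (v-prediction), optionally reweighted by min-SNR-gamma (Hang et al., 2023). A hedged sketch of that weighting, assuming a DDPM-style scheduler exposing `alphas_cumprod`; the repo's `apply_snr_weight` may differ in detail:

```python
import torch

def min_snr_weight(loss, timesteps, noise_scheduler, gamma):
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(loss.device)
    alpha_bar = alphas_cumprod[timesteps.to(loss.device)]
    snr = alpha_bar / (1.0 - alpha_bar)         # signal-to-noise ratio per timestep
    weight = torch.clamp(snr, max=gamma) / snr  # min(SNR, gamma) / SNR
    return loss * weight
```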

@@ -160,6 +160,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
train_embedding=self.embed_config is not None,
)
# fine_tuning here is for training actual SD network, not LoRA, embeddings, etc. it is (Dreambooth, etc)
self.is_fine_tuning = True
if self.network_config is not None or self.adapter_config is not None or self.embed_config is not None:
self.is_fine_tuning = False
def sample(self, step=None, is_first=False):
sample_folder = os.path.join(self.save_root, 'samples')
gen_img_config_list = []
@@ -194,6 +199,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
prompt, self.trigger_word, add_if_not_present=False
)
extra_args = {}
if self.adapter_config is not None:
extra_args['adapter_image_path'] = self.adapter_config.test_img_path
gen_img_config_list.append(GenerateImageConfig(
prompt=prompt, # it will autoparse the prompt
width=sample_config.width,
@@ -206,6 +215,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
network_multiplier=sample_config.network_multiplier,
output_path=output_path,
output_ext=sample_config.ext,
**extra_args
))
# send to be generated
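The sampling hunks thread the adapter's test image into each sample job through an `extra_args` dict that is splatted into `GenerateImageConfig`, so the field is only passed when an adapter is configured. A reduced sketch (the dataclass here is a stand-in, not the repo's full class):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class GenerateImageConfig:  # reduced stand-in
    prompt: str
    adapter_image_path: Optional[str] = None

def build_sample_config(prompt: str, test_img_path: Optional[str]) -> GenerateImageConfig:
    extra_args = {}
    if test_img_path is not None:
        extra_args['adapter_image_path'] = test_img_path  # only when an adapter exists
    return GenerateImageConfig(prompt=prompt, **extra_args)
```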
@@ -287,8 +297,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
file_path = os.path.join(self.save_root, filename)
# prepare meta
save_meta = get_meta_for_safetensors(self.meta, self.job.name)
if self.network is not None or self.embedding is not None:
if self.network is not None or self.embedding is not None or self.adapter is not None:
if self.network is not None:
lora_name = self.job.name
if self.adapter_config is not None or self.embedding is not None:
# add _lora to name
lora_name += '_LoRA'
filename = f'{lora_name}{step_num}.safetensors'
file_path = os.path.join(self.save_root, filename)
prev_multiplier = self.network.multiplier
self.network.multiplier = 1.0
if self.network_config.normalize:
@@ -318,15 +335,23 @@ class BaseSDTrainProcess(BaseTrainProcess):
# replace extension
emb_file_path = os.path.splitext(emb_file_path)[0] + ".pt"
self.embedding.save(emb_file_path)
elif self.adapter is not None:
# save adapter
state_dict = self.adapter.state_dict()
save_t2i_from_diffusers(
state_dict,
output_file=file_path,
meta=save_meta,
dtype=get_torch_dtype(self.save_config.dtype)
)
if self.adapter is not None:
adapter_name = self.job.name
if self.network_config is not None or self.embedding is not None:
# add _t2i to name
adapter_name += '_t2i'
filename = f'{adapter_name}{step_num}.safetensors'
file_path = os.path.join(self.save_root, filename)
# save adapter
state_dict = self.adapter.state_dict()
save_t2i_from_diffusers(
state_dict,
output_file=file_path,
meta=save_meta,
dtype=get_torch_dtype(self.save_config.dtype)
)
else:
self.sd.save(
file_path,
@@ -362,14 +387,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
# return loss
return 0.0
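The save hunk above gives the adapter its own `_t2i`-suffixed filename and writes it through `save_t2i_from_diffusers`. A minimal sketch of what that branch does, using safetensors directly; the repo's helper additionally handles key-format conversion:

```python
import torch
from safetensors.torch import save_file

def save_adapter(adapter: torch.nn.Module, file_path: str, meta: dict, dtype: torch.dtype):
    state_dict = {k: v.to(dtype) for k, v in adapter.state_dict().items()}
    # safetensors metadata must be str -> str
    save_file(state_dict, file_path, metadata={k: str(v) for k, v in meta.items()})
```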
def get_latest_save_path(self, name=None):
def get_latest_save_path(self, name=None, post=''):
if name == None:
name = self.job.name
# get latest saved step
if os.path.exists(self.save_root):
latest_file = None
# pattern is {job_name}_{zero_filled_step}.safetensors or {job_name}.safetensors
pattern = f"{name}*.safetensors"
pattern = f"{name}*{post}.safetensors"
files = glob.glob(os.path.join(self.save_root, pattern))
if len(files) > 0:
latest_file = max(files, key=os.path.getctime)
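`get_latest_save_path` gains a `post` suffix so resume logic can find the right component's checkpoint when several components share one job name. The lookup, sketched:

```python
import glob
import os
from typing import Optional

def latest_save_path(save_root: str, name: str, post: str = '') -> Optional[str]:
    # matches files that start with `name` and end with f"{post}.safetensors"
    files = glob.glob(os.path.join(save_root, f"{name}*{post}.safetensors"))
    return max(files, key=os.path.getctime) if files else None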
@@ -399,17 +424,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
print("load_weights not implemented for non-network models")
return None
def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32):
sigmas = self.sd.noise_scheduler.sigmas.to(device=self.device_torch, dtype=dtype)
schedule_timesteps = self.sd.noise_scheduler.timesteps.to(self.device_torch, )
timesteps = timesteps.to(self.device_torch, )
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
# def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32):
# self.sd.noise_scheduler.set_timesteps(1000, device=self.device_torch)
# sigmas = self.sd.noise_scheduler.sigmas.to(device=self.device_torch, dtype=dtype)
# schedule_timesteps = self.sd.noise_scheduler.timesteps.to(self.device_torch, )
# timesteps = timesteps.to(self.device_torch, )
#
# # step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
# step_indices = [t for t in timesteps]
#
# sigma = sigmas[step_indices].flatten()
# while len(sigma.shape) < n_dim:
# sigma = sigma.unsqueeze(-1)
# return sigma
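The deleted `get_sigmas` helper (and its commented replacement) mapped each sampled training timestep to the scheduler's sigma by index, as in diffusers' training examples. Reconstructed as a standalone sketch, assuming a scheduler that exposes `sigmas` and `timesteps`:

```python
import torch

def get_sigmas(noise_scheduler, timesteps, n_dim=4, dtype=torch.float32, device='cuda'):
    sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
    schedule_timesteps = noise_scheduler.timesteps.to(device)
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps.to(device)]
    sigma = sigmas[step_indices].flatten()
    while len(sigma.shape) < n_dim:
        sigma = sigma.unsqueeze(-1)  # reshape for broadcasting, e.g. (B, 1, 1, 1)
    return sigma
```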
def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
with torch.no_grad():
@@ -583,54 +610,52 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.datasets_reg is not None:
self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size,
self.sd)
params = []
if self.network_config is not None or self.adapter_config is not None or self.embed_config is not None:
if self.network_config is not None:
# TODO should we completely switch to LycorisSpecialNetwork?
if self.network_config is not None:
# TODO should we completely switch to LycorisSpecialNetwork?
is_lycoris = False
# default to LoCON if there are any conv layers or if it is named
NetworkClass = LoRASpecialNetwork
if self.network_config.type.lower() == 'locon' or self.network_config.type.lower() == 'lycoris':
NetworkClass = LycorisSpecialNetwork
is_lycoris = True
is_lycoris = False
# default to LoCON if there are any conv layers or if it is named
NetworkClass = LoRASpecialNetwork
if self.network_config.type.lower() == 'locon' or self.network_config.type.lower() == 'lycoris':
NetworkClass = LycorisSpecialNetwork
is_lycoris = True
# if is_lycoris:
# preset = PRESET['full']
# NetworkClass.apply_preset(preset)
# if is_lycoris:
# preset = PRESET['full']
# NetworkClass.apply_preset(preset)
self.network = NetworkClass(
text_encoder=text_encoder,
unet=unet,
lora_dim=self.network_config.linear,
multiplier=1.0,
alpha=self.network_config.linear_alpha,
train_unet=self.train_config.train_unet,
train_text_encoder=self.train_config.train_text_encoder,
conv_lora_dim=self.network_config.conv,
conv_alpha=self.network_config.conv_alpha,
is_sdxl=self.model_config.is_xl,
is_v2=self.model_config.is_v2,
dropout=self.network_config.dropout
)
self.network = NetworkClass(
text_encoder=text_encoder,
unet=unet,
lora_dim=self.network_config.linear,
multiplier=1.0,
alpha=self.network_config.linear_alpha,
train_unet=self.train_config.train_unet,
train_text_encoder=self.train_config.train_text_encoder,
conv_lora_dim=self.network_config.conv,
conv_alpha=self.network_config.conv_alpha,
is_sdxl=self.model_config.is_xl,
is_v2=self.model_config.is_v2,
dropout=self.network_config.dropout
)
self.network.force_to(self.device_torch, dtype=dtype)
# give network to sd so it can use it
self.sd.network = self.network
self.network._update_torch_multiplier()
self.network.force_to(self.device_torch, dtype=dtype)
# give network to sd so it can use it
self.sd.network = self.network
self.network._update_torch_multiplier()
self.network.apply_to(
text_encoder,
unet,
self.train_config.train_text_encoder,
self.train_config.train_unet
)
self.network.apply_to(
text_encoder,
unet,
self.train_config.train_text_encoder,
self.train_config.train_unet
)
self.network.prepare_grad_etc(text_encoder, unet)
flush()
self.network.prepare_grad_etc(text_encoder, unet)
flush()
params = self.get_params()
if not params:
# LyCORIS doesn't have default_lr
config = {
'text_encoder_lr': self.train_config.lr,
@@ -639,23 +664,30 @@ class BaseSDTrainProcess(BaseTrainProcess):
sig = inspect.signature(self.network.prepare_optimizer_params)
if 'default_lr' in sig.parameters:
config['default_lr'] = self.train_config.lr
params = self.network.prepare_optimizer_params(
params_net = self.network.prepare_optimizer_params(
**config
)
if self.train_config.gradient_checkpointing:
self.network.enable_gradient_checkpointing()
params += params_net
# set the network to normalize if we are
self.network.is_normalizing = self.network_config.normalize
if self.train_config.gradient_checkpointing:
self.network.enable_gradient_checkpointing()
latest_save_path = self.get_latest_save_path()
extra_weights = None
if latest_save_path is not None:
self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
self.print(f"Loading from {latest_save_path}")
extra_weights = self.load_weights(latest_save_path)
self.network.multiplier = 1.0
# set the network to normalize if we are
self.network.is_normalizing = self.network_config.normalize
lora_name = self.name
# need to adapt name so they are not mixed up
if self.adapter_config is not None or self.embedding is not None:
lora_name = f"{lora_name}_LoRA"
latest_save_path = self.get_latest_save_path(lora_name)
extra_weights = None
if latest_save_path is not None:
self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
self.print(f"Loading from {latest_save_path}")
extra_weights = self.load_weights(latest_save_path)
self.network.multiplier = 1.0
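When a LoRA trains alongside an adapter or embedding, the resume lookup now suffixes the name first so it never loads the wrong component's file. The naming rule, as a sketch:

```python
def lora_resume_name(job_name: str, has_adapter: bool, has_embedding: bool) -> str:
    # mirror of the save-side suffix so save and resume agree
    return f"{job_name}_LoRA" if (has_adapter or has_embedding) else job_name
```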
if self.embed_config is not None:
# we are doing embedding training as well
@@ -672,68 +704,71 @@ class BaseSDTrainProcess(BaseTrainProcess):
'lr': self.train_config.embedding_lr
})
flush()
elif self.embed_config is not None:
self.embedding = Embedding(
sd=self.sd,
embed_config=self.embed_config
)
latest_save_path = self.get_latest_save_path(self.embed_config.trigger)
# load last saved weights
if latest_save_path is not None:
self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
flush()
# resume state from embedding
self.step_num = self.embedding.step
self.start_step = self.step_num
params = self.get_params()
if not params:
# set trainable params
params = self.embedding.get_trainable_params()
flush()
elif self.adapter_config is not None:
self.adapter = T2IAdapter(
in_channels=self.adapter_config.in_channels,
channels=self.adapter_config.channels,
num_res_blocks=self.adapter_config.num_res_blocks,
downscale_factor=self.adapter_config.downscale_factor,
adapter_type=self.adapter_config.adapter_type,
)
# t2i adapter
latest_save_path = self.get_latest_save_path(self.embed_config.trigger)
if latest_save_path is not None:
# load adapter from path
print(f"Loading adapter from {latest_save_path}")
loaded_state_dict = load_t2i_model(
latest_save_path,
self.device_torch,
dtype=dtype
if self.embed_config is not None:
self.embedding = Embedding(
sd=self.sd,
embed_config=self.embed_config
)
self.adapter.load_state_dict(loaded_state_dict)
self.load_training_state_from_metadata(latest_save_path)
params = self.get_params()
if not params:
latest_save_path = self.get_latest_save_path(self.embed_config.trigger)
# load last saved weights
if latest_save_path is not None:
self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
# resume state from embedding
self.step_num = self.embedding.step
self.start_step = self.step_num
params = self.get_params()
if not params:
# set trainable params
params = self.embedding.get_trainable_params()
flush()
if self.adapter_config is not None:
self.adapter = T2IAdapter(
in_channels=self.adapter_config.in_channels,
channels=self.adapter_config.channels,
num_res_blocks=self.adapter_config.num_res_blocks,
downscale_factor=self.adapter_config.downscale_factor,
adapter_type=self.adapter_config.adapter_type,
)
self.adapter.to(self.device_torch, dtype=dtype)
# t2i adapter
adapter_name = self.name
if self.network_config is not None:
adapter_name = f"{adapter_name}_t2i"
latest_save_path = self.get_latest_save_path(adapter_name)
if latest_save_path is not None:
# load adapter from path
print(f"Loading adapter from {latest_save_path}")
loaded_state_dict = load_t2i_model(
latest_save_path,
self.device,
dtype=dtype
)
self.adapter.load_state_dict(loaded_state_dict)
self.load_training_state_from_metadata(latest_save_path)
# set trainable params
params = self.adapter.parameters()
self.sd.adapter = self.adapter
flush()
else:
params.append({
'params': self.adapter.parameters(),
'lr': self.train_config.adapter_lr
})
self.sd.adapter = self.adapter
flush()
else: # no network, embedding or adapter
# set the device state preset before getting params
self.sd.set_device_state(self.train_device_state_preset)
params = self.get_params()
if params is None:
# will only return savable weights and ones with grad
params = self.sd.prepare_optimizer_params(
unet=self.train_config.train_unet,
text_encoder=self.train_config.train_text_encoder,
text_encoder_lr=self.train_config.lr,
unet_lr=self.train_config.lr,
default_lr=self.train_config.lr
)
# will only return savable weights and ones with grad
params = self.sd.prepare_optimizer_params(
unet=self.train_config.train_unet,
text_encoder=self.train_config.train_text_encoder,
text_encoder_lr=self.train_config.lr,
unet_lr=self.train_config.lr,
default_lr=self.train_config.lr
)
flush()
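Structurally, this refactor replaces mutually exclusive `elif` branches with independent `if` blocks that each append their parameter group, with its own learning rate, to one shared `params` list. A hedged sketch of the resulting pattern; the method names mirror the diff but the wiring here is illustrative:

```python
def build_param_groups(network, embedding, adapter, lr, embedding_lr, adapter_lr):
    groups = []
    if network is not None:
        # the LoRA network contributes its own per-module groups
        groups += network.prepare_optimizer_params(
            text_encoder_lr=lr, unet_lr=lr, default_lr=lr)
    if embedding is not None:
        groups.append({'params': embedding.get_trainable_params(), 'lr': embedding_lr})
    if adapter is not None:
        groups.append({'params': adapter.parameters(), 'lr': adapter_lr})
    return groups
```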
### HOOK ###
params = self.hook_add_extra_train_params(params)
@@ -746,7 +781,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.params.append(param)
optimizer_type = self.train_config.optimizer.lower()
optimizer = get_optimizer(params, optimizer_type, learning_rate=self.train_config.lr,
optimizer = get_optimizer(self.params, optimizer_type, learning_rate=self.train_config.lr,
optimizer_params=self.train_config.optimizer_params)
self.optimizer = optimizer

@@ -92,6 +92,7 @@ class TrainConfig:
self.unet_lr = kwargs.get('unet_lr', self.lr)
self.text_encoder_lr = kwargs.get('text_encoder_lr', self.lr)
self.embedding_lr = kwargs.get('embedding_lr', self.lr)
self.adapter_lr = kwargs.get('adapter_lr', self.lr)
self.optimizer = kwargs.get('optimizer', 'adamw')
self.optimizer_params = kwargs.get('optimizer_params', {})
self.lr_scheduler = kwargs.get('lr_scheduler', 'constant')
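`TrainConfig` follows the same fall-through convention for the new `adapter_lr` as for the other component rates: anything not set explicitly inherits the global `lr`. In miniature (the default value here is illustrative):

```python
class TrainConfig:
    def __init__(self, **kwargs):
        self.lr = kwargs.get('lr', 1e-4)                      # illustrative default
        self.adapter_lr = kwargs.get('adapter_lr', self.lr)   # new in this commit
```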

@@ -184,7 +184,7 @@ def save_t2i_from_diffusers(
def load_t2i_model(
path_to_file,
device: Union[str, torch.device] = 'cpu',
device: Union[str] = 'cpu',
dtype: torch.dtype = torch.float32
):
raw_state_dict = load_file(path_to_file, device)
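The `device` annotation on `load_t2i_model` is narrowed to a string, consistent with `safetensors.torch.load_file` taking a device string rather than a `torch.device`. The load side, sketched as the mirror of the save sketch earlier:

```python
import torch
from safetensors.torch import load_file

def load_t2i_state(path: str, device: str = 'cpu', dtype: torch.dtype = torch.float32):
    state = load_file(path, device=device)  # device string, e.g. 'cpu' or 'cuda:0'
    return {k: v.to(dtype) for k, v in state.items()}
```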

@@ -250,7 +250,7 @@ class StableDiffusion:
# add hacks to unet to help training
# pipe.unet = prepare_unet_for_training(pipe.unet)
self.unet = pipe.unet
self.unet: 'UNet2DConditionModel' = pipe.unet
self.vae: 'AutoencoderKL' = pipe.vae.to(self.device_torch, dtype=dtype)
self.vae.eval()
self.vae.requires_grad_(False)
@@ -360,8 +360,9 @@ class StableDiffusion:
extra = {}
if gen_config.adapter_image_path is not None:
validation_image = Image.open(gen_config.adapter_image_path).convert("RGB")
validation_image = validation_image.resize((gen_config.width, gen_config.height))
validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2))
extra['image'] = validation_image
extra['adapter_conditioning_scale'] = 1.0
if self.network is not None:
self.network.multiplier = gen_config.network_multiplier
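For sampling, the adapter conditioning image is now resized to twice the output size; this reads as an empirical fix for the SDXL T2I pipeline's expected conditioning resolution rather than a documented rule. The prep step, sketched:

```python
from PIL import Image

def prep_adapter_image(path: str, width: int, height: int) -> Image.Image:
    img = Image.open(path).convert("RGB")
    return img.resize((width * 2, height * 2))  # 2x, per the change above
```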
@@ -933,7 +934,7 @@ class StableDiffusion:
self.device_state['adapter'] = {
'training': self.adapter.training,
'device': self.adapter.device,
'requires_grad': self.adapter.requires_grad,
'requires_grad': self.adapter.adapter.conv_in.weight.requires_grad,
}
def restore_device_state(self):
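Finally, the device-state snapshot can't read `requires_grad` off the adapter module itself, since `nn.Module` exposes `requires_grad_()` as a method but not a `requires_grad` attribute, so the commit probes one representative weight. Generalized sketch:

```python
import torch

def snapshot_module_state(module: torch.nn.Module) -> dict:
    probe = next(module.parameters())  # any parameter reflects the shared flags
    return {
        'training': module.training,
        'device': probe.device,
        'requires_grad': probe.requires_grad,
    }
```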