t2i training working from what I can tell at least
@@ -54,16 +54,20 @@ class SDTrainer(BaseSDTrainProcess):
if os.path.exists(os.path.join(adapter_folder_path, file_name_no_ext + ext)):
adapter_images.append(os.path.join(adapter_folder_path, file_name_no_ext + ext))
break

width, height = batch.file_items[0].crop_width, batch.file_items[0].crop_height
adapter_tensors = []
# load images with torch transforms
for adapter_image in adapter_images:
for idx, adapter_image in enumerate(adapter_images):
img = Image.open(adapter_image)
# resize to match batch shape
img = img.resize((width, height))
img = adapter_transforms(img)
adapter_tensors.append(img)

# stack them
adapter_tensors = torch.stack(adapter_tensors)
adapter_tensors = torch.stack(adapter_tensors).to(
self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)
)
return adapter_tensors

def hook_train_loop(self, batch):
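For reference, a self-contained sketch of what this hunk builds: control images paired with the batch are resized to the batch's crop size, transformed, and stacked into a single tensor on the training device. The `adapter_transforms` here is an assumption (a torchvision ToTensor-style transform), not confirmed from the full source.

import torch
from PIL import Image
from torchvision import transforms

adapter_transforms = transforms.ToTensor()  # assumed: PIL -> (C, H, W) float tensor in [0, 1]

def stack_adapter_images(paths, width, height, device, dtype):
    tensors = []
    for path in paths:
        img = Image.open(path).convert("RGB")
        img = img.resize((width, height))      # match the latent batch's crop size
        tensors.append(adapter_transforms(img))
    # one (B, C, H, W) tensor, matching the torch.stack(...).to(device, dtype=...) change above
    return torch.stack(tensors).to(device, dtype=dtype)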
@@ -79,8 +83,8 @@ class SDTrainer(BaseSDTrainProcess):
adapter_images = self.get_adapter_images(batch)
# not 100% sure what this does. But they do it here
# https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170
sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
# sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
# noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)

# flush()
self.optimizer.zero_grad()
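The division the diffusers script (and the now-commented line above) performs is the same input scaling EulerDiscreteScheduler.scale_model_input applies: a sigma-parameterized UNet expects its input multiplied by 1 / sqrt(sigma^2 + 1). A minimal sketch, with sigma shapes assumed broadcastable over the latents:

import torch

def scale_noisy_latents(noisy_latents: torch.Tensor, sigmas: torch.Tensor) -> torch.Tensor:
    # sigmas must broadcast over the latent dims, e.g. shape (B, 1, 1, 1) against (B, C, H, W)
    return noisy_latents / ((sigmas ** 2 + 1) ** 0.5)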
@@ -126,39 +130,38 @@ class SDTrainer(BaseSDTrainProcess):
**pred_kwargs
)

if self.adapter:
# todo, diffusers does this on t2i training, is it better approach?
# Denoise the latents
denoised_latents = noise_pred * (-sigmas) + noisy_latents
weighing = sigmas ** -2.0

# Get the target for loss depending on the prediction type
if self.sd.noise_scheduler.config.prediction_type == "epsilon":
target = batch.latents # we are computing loss against denoise latents
elif self.sd.noise_scheduler.config.prediction_type == "v_prediction":
target = self.sd.noise_scheduler.get_velocity(batch.latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}")

# MSE loss
loss = torch.mean(
(weighing.float() * (denoised_latents.float() - target.float()) ** 2).reshape(target.shape[0], -1),
dim=1,
)
# if self.adapter:
# # todo, diffusers does this on t2i training, is it better approach?
# # Denoise the latents
# denoised_latents = noise_pred * (-sigmas) + noisy_latents
# weighing = sigmas ** -2.0
#
# # Get the target for loss depending on the prediction type
# if self.sd.noise_scheduler.config.prediction_type == "epsilon":
# target = batch.latents # we are computing loss against denoise latents
# elif self.sd.noise_scheduler.config.prediction_type == "v_prediction":
# target = self.sd.noise_scheduler.get_velocity(batch.latents, noise, timesteps)
# else:
# raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}")
#
# # MSE loss
# loss = torch.mean(
# (weighing.float() * (denoised_latents.float() - target.float()) ** 2).reshape(target.shape[0], -1),
# dim=1,
# )
# else:
noise = noise.to(self.device_torch, dtype=dtype).detach()
if self.sd.prediction_type == 'v_prediction':
# v-parameterization training
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
else:
noise = noise.to(self.device_torch, dtype=dtype).detach()
if self.sd.prediction_type == 'v_prediction':
# v-parameterization training
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])

# TODO: I think the sigma method does not need this. Check
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)

loss = loss.mean()
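apply_snr_weight is this toolkit's helper, so its exact internals are not shown here; a hedged sketch of the usual Min-SNR-gamma formulation for epsilon prediction (weight = min(SNR, gamma) / SNR, with SNR derived from the scheduler's alphas_cumprod):

import torch

def min_snr_weight(loss, timesteps, noise_scheduler, gamma):
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(loss.device)
    alpha_bar = alphas_cumprod[timesteps]
    snr = alpha_bar / (1.0 - alpha_bar)         # SNR(t) for DDPM-style schedules
    weight = torch.clamp(snr, max=gamma) / snr  # min(SNR, gamma) / SNR
    return loss * weight                        # loss assumed per-sample, shape (B,)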
@@ -160,6 +160,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
train_embedding=self.embed_config is not None,
)

# fine_tuning here is for training actual SD network, not LoRA, embeddings, etc. it is (Dreambooth, etc)
self.is_fine_tuning = True
if self.network_config is not None or self.adapter_config is not None or self.embed_config is not None:
self.is_fine_tuning = False

def sample(self, step=None, is_first=False):
sample_folder = os.path.join(self.save_root, 'samples')
gen_img_config_list = []
@@ -194,6 +199,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
prompt, self.trigger_word, add_if_not_present=False
)

extra_args = {}
if self.adapter_config is not None:
extra_args['adapter_image_path'] = self.adapter_config.test_img_path

gen_img_config_list.append(GenerateImageConfig(
prompt=prompt, # it will autoparse the prompt
width=sample_config.width,
@@ -206,6 +215,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
network_multiplier=sample_config.network_multiplier,
output_path=output_path,
output_ext=sample_config.ext,
**extra_args
))

# send to be generated
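The pattern above is the standard conditional-kwargs splat: optional fields are collected in a dict and unpacked into the constructor, so non-adapter jobs are untouched. A stand-in sketch (SampleConfig is hypothetical, standing in for the repo's GenerateImageConfig):

from dataclasses import dataclass
from typing import Optional

@dataclass
class SampleConfig:                      # hypothetical stand-in
    prompt: str
    adapter_image_path: Optional[str] = None

extra_args = {}
test_img_path = "control.png"            # assumed: adapter_config.test_img_path
if test_img_path is not None:
    extra_args["adapter_image_path"] = test_img_path

cfg = SampleConfig(prompt="a photo of a cat", **extra_args)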
@@ -287,8 +297,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
file_path = os.path.join(self.save_root, filename)
# prepare meta
save_meta = get_meta_for_safetensors(self.meta, self.job.name)
if self.network is not None or self.embedding is not None:
if self.network is not None or self.embedding is not None or self.adapter is not None:
if self.network is not None:
lora_name = self.job.name
if self.adapter_config is not None or self.embedding is not None:
# add _lora to name
lora_name += '_LoRA'

filename = f'{lora_name}{step_num}.safetensors'
file_path = os.path.join(self.save_root, filename)
prev_multiplier = self.network.multiplier
self.network.multiplier = 1.0
if self.network_config.normalize:
@@ -318,15 +335,23 @@ class BaseSDTrainProcess(BaseTrainProcess):
# replace extension
emb_file_path = os.path.splitext(emb_file_path)[0] + ".pt"
self.embedding.save(emb_file_path)
elif self.adapter is not None:
# save adapter
state_dict = self.adapter.state_dict()
save_t2i_from_diffusers(
state_dict,
output_file=file_path,
meta=save_meta,
dtype=get_torch_dtype(self.save_config.dtype)
)

if self.adapter is not None:
adapter_name = self.job.name
if self.network_config is not None or self.embedding is not None:
# add _lora to name
adapter_name += '_t2i'

filename = f'{adapter_name}{step_num}.safetensors'
file_path = os.path.join(self.save_root, filename)
# save adapter
state_dict = self.adapter.state_dict()
save_t2i_from_diffusers(
state_dict,
output_file=file_path,
meta=save_meta,
dtype=get_torch_dtype(self.save_config.dtype)
)
else:
self.sd.save(
file_path,
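save_t2i_from_diffusers is this repo's helper; a hedged sketch of the core operation it likely performs — cast the adapter state dict to the save dtype and write a .safetensors file with metadata attached (safetensors metadata values must be strings):

import torch
from safetensors.torch import save_file

def save_adapter(state_dict, output_file, meta, dtype=torch.float16):
    casted = {k: v.detach().to("cpu", dtype=dtype).contiguous() for k, v in state_dict.items()}
    save_file(casted, output_file, metadata={k: str(v) for k, v in meta.items()})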
@@ -362,14 +387,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
# return loss
return 0.0

def get_latest_save_path(self, name=None):
def get_latest_save_path(self, name=None, post=''):
if name == None:
name = self.job.name
# get latest saved step
if os.path.exists(self.save_root):
latest_file = None
# pattern is {job_name}_{zero_filles_step}.safetensors or {job_name}.safetensors
pattern = f"{name}*.safetensors"
pattern = f"{name}*{post}.safetensors"
files = glob.glob(os.path.join(self.save_root, pattern))
if len(files) > 0:
latest_file = max(files, key=os.path.getctime)
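A self-contained version of the lookup above. The new `post` suffix is what lets LoRAs, embeddings, and t2i adapters saved in the same folder resolve to different "latest" checkpoints:

import glob
import os

def latest_save_path(save_root: str, name: str, post: str = ''):
    pattern = os.path.join(save_root, f"{name}*{post}.safetensors")
    files = glob.glob(pattern)
    # newest by creation time, matching max(..., key=os.path.getctime) above
    return max(files, key=os.path.getctime) if files else None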
@@ -399,17 +424,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
print("load_weights not implemented for non-network models")
return None

def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32):
sigmas = self.sd.noise_scheduler.sigmas.to(device=self.device_torch, dtype=dtype)
schedule_timesteps = self.sd.noise_scheduler.timesteps.to(self.device_torch, )
timesteps = timesteps.to(self.device_torch, )

step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
# def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32):
# self.sd.noise_scheduler.set_timesteps(1000, device=self.device_torch)
# sigmas = self.sd.noise_scheduler.sigmas.to(device=self.device_torch, dtype=dtype)
# schedule_timesteps = self.sd.noise_scheduler.timesteps.to(self.device_torch, )
# timesteps = timesteps.to(self.device_torch, )
#
# # step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
# step_indices = [t for t in timesteps]
#
# sigma = sigmas[step_indices].flatten()
# while len(sigma.shape) < n_dim:
# sigma = sigma.unsqueeze(-1)
# return sigma

def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
with torch.no_grad():
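The active get_sigmas maps each sampled timestep to its position in the scheduler's timestep table before gathering sigma; the commented-out variant indexes sigmas with the raw timestep values, which only coincides when the scheduler's timesteps run exactly 0..N-1 in order. A minimal sketch of the active lookup and the broadcasting loop:

import torch

def gather_sigmas(timesteps, schedule_timesteps, sigmas, n_dim=4):
    # find each timestep's index in the scheduler's own ordering
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
    sigma = sigmas[step_indices].flatten()
    while len(sigma.shape) < n_dim:
        sigma = sigma.unsqueeze(-1)        # (B,) -> (B, 1, 1, 1), broadcastable over latents
    return sigma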
@@ -583,54 +610,52 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.datasets_reg is not None:
self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size,
self.sd)
params = []
if self.network_config is not None or self.adapter_config is not None or self.embed_config is not None:
if self.network_config is not None:
# TODO should we completely switch to LycorisSpecialNetwork?

if self.network_config is not None:
# TODO should we completely switch to LycorisSpecialNetwork?
is_lycoris = False
# default to LoCON if there are any conv layers or if it is named
NetworkClass = LoRASpecialNetwork
if self.network_config.type.lower() == 'locon' or self.network_config.type.lower() == 'lycoris':
NetworkClass = LycorisSpecialNetwork
is_lycoris = True

is_lycoris = False
# default to LoCON if there are any conv layers or if it is named
NetworkClass = LoRASpecialNetwork
if self.network_config.type.lower() == 'locon' or self.network_config.type.lower() == 'lycoris':
NetworkClass = LycorisSpecialNetwork
is_lycoris = True
# if is_lycoris:
# preset = PRESET['full']
# NetworkClass.apply_preset(preset)

# if is_lycoris:
# preset = PRESET['full']
# NetworkClass.apply_preset(preset)
self.network = NetworkClass(
text_encoder=text_encoder,
unet=unet,
lora_dim=self.network_config.linear,
multiplier=1.0,
alpha=self.network_config.linear_alpha,
train_unet=self.train_config.train_unet,
train_text_encoder=self.train_config.train_text_encoder,
conv_lora_dim=self.network_config.conv,
conv_alpha=self.network_config.conv_alpha,
is_sdxl=self.model_config.is_xl,
is_v2=self.model_config.is_v2,
dropout=self.network_config.dropout
)

self.network = NetworkClass(
text_encoder=text_encoder,
unet=unet,
lora_dim=self.network_config.linear,
multiplier=1.0,
alpha=self.network_config.linear_alpha,
train_unet=self.train_config.train_unet,
train_text_encoder=self.train_config.train_text_encoder,
conv_lora_dim=self.network_config.conv,
conv_alpha=self.network_config.conv_alpha,
is_sdxl=self.model_config.is_xl,
is_v2=self.model_config.is_v2,
dropout=self.network_config.dropout
)
self.network.force_to(self.device_torch, dtype=dtype)
# give network to sd so it can use it
self.sd.network = self.network
self.network._update_torch_multiplier()

self.network.force_to(self.device_torch, dtype=dtype)
# give network to sd so it can use it
self.sd.network = self.network
self.network._update_torch_multiplier()
self.network.apply_to(
text_encoder,
unet,
self.train_config.train_text_encoder,
self.train_config.train_unet
)

self.network.apply_to(
text_encoder,
unet,
self.train_config.train_text_encoder,
self.train_config.train_unet
)
self.network.prepare_grad_etc(text_encoder, unet)
flush()

self.network.prepare_grad_etc(text_encoder, unet)
flush()

params = self.get_params()

if not params:
# LyCORIS doesnt have default_lr
config = {
'text_encoder_lr': self.train_config.lr,
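The hunk above mostly re-indents the network setup under the new outer branch; the selection logic itself is unchanged. Condensed, with stand-ins for the repo's classes, it amounts to:

class LoRASpecialNetwork: ...          # stand-ins for this repo's network classes
class LycorisSpecialNetwork: ...

def pick_network_class(network_type: str):
    # default to plain LoRA; 'locon'/'lycoris' configs get the LyCORIS implementation
    if network_type.lower() in ('locon', 'lycoris'):
        return LycorisSpecialNetwork
    return LoRASpecialNetwork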
@@ -639,23 +664,30 @@ class BaseSDTrainProcess(BaseTrainProcess):
sig = inspect.signature(self.network.prepare_optimizer_params)
if 'default_lr' in sig.parameters:
config['default_lr'] = self.train_config.lr
params = self.network.prepare_optimizer_params(
params_net = self.network.prepare_optimizer_params(
**config
)

if self.train_config.gradient_checkpointing:
self.network.enable_gradient_checkpointing()
params += params_net

# set the network to normalize if we are
self.network.is_normalizing = self.network_config.normalize
if self.train_config.gradient_checkpointing:
self.network.enable_gradient_checkpointing()

latest_save_path = self.get_latest_save_path()
extra_weights = None
if latest_save_path is not None:
self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
self.print(f"Loading from {latest_save_path}")
extra_weights = self.load_weights(latest_save_path)
self.network.multiplier = 1.0
# set the network to normalize if we are
self.network.is_normalizing = self.network_config.normalize

lora_name = self.name
# need to adapt name so they are not mixed up
if self.adapter_config is not None or self.embedding is not None:
lora_name = f"{lora_name}_LoRA"

latest_save_path = self.get_latest_save_path(lora_name)
extra_weights = None
if latest_save_path is not None:
self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
self.print(f"Loading from {latest_save_path}")
extra_weights = self.load_weights(latest_save_path)
self.network.multiplier = 1.0

if self.embed_config is not None:
# we are doing embedding training as well
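The signature probe above, isolated: default_lr is only passed when the network's prepare_optimizer_params actually accepts it (per the source comment, LyCORIS's version does not). This is plain stdlib inspect:

import inspect

def build_lr_config(prepare_fn, lr):
    config = {'text_encoder_lr': lr, 'unet_lr': lr}
    # only add default_lr if the target function declares that parameter
    if 'default_lr' in inspect.signature(prepare_fn).parameters:
        config['default_lr'] = lr
    return config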
@@ -672,68 +704,71 @@ class BaseSDTrainProcess(BaseTrainProcess):
'lr': self.train_config.embedding_lr
})

flush()
elif self.embed_config is not None:
self.embedding = Embedding(
sd=self.sd,
embed_config=self.embed_config
)
latest_save_path = self.get_latest_save_path(self.embed_config.trigger)
# load last saved weights
if latest_save_path is not None:
self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
flush()

# resume state from embedding
self.step_num = self.embedding.step
self.start_step = self.step_num

params = self.get_params()
if not params:
# set trainable params
params = self.embedding.get_trainable_params()

flush()
elif self.adapter_config is not None:
self.adapter = T2IAdapter(
in_channels=self.adapter_config.in_channels,
channels=self.adapter_config.channels,
num_res_blocks=self.adapter_config.num_res_blocks,
downscale_factor=self.adapter_config.downscale_factor,
adapter_type=self.adapter_config.adapter_type,
)
# t2i adapter
latest_save_path = self.get_latest_save_path(self.embed_config.trigger)
if latest_save_path is not None:
# load adapter from path
print(f"Loading adapter from {latest_save_path}")
loaded_state_dict = load_t2i_model(
latest_save_path,
self.device_torch,
dtype=dtype
if self.embed_config is not None:
self.embedding = Embedding(
sd=self.sd,
embed_config=self.embed_config
)
self.adapter.load_state_dict(loaded_state_dict)
self.load_training_state_from_metadata(latest_save_path)
params = self.get_params()
if not params:
latest_save_path = self.get_latest_save_path(self.embed_config.trigger)
# load last saved weights
if latest_save_path is not None:
self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)

# resume state from embedding
self.step_num = self.embedding.step
self.start_step = self.step_num

params = self.get_params()
if not params:
# set trainable params
params = self.embedding.get_trainable_params()

flush()

if self.adapter_config is not None:
self.adapter = T2IAdapter(
in_channels=self.adapter_config.in_channels,
channels=self.adapter_config.channels,
num_res_blocks=self.adapter_config.num_res_blocks,
downscale_factor=self.adapter_config.downscale_factor,
adapter_type=self.adapter_config.adapter_type,
)
self.adapter.to(self.device_torch, dtype=dtype)
# t2i adapter
adapter_name = self.name
if self.network_config is not None:
adapter_name = f"{adapter_name}_t2i"
latest_save_path = self.get_latest_save_path(adapter_name)
if latest_save_path is not None:
# load adapter from path
print(f"Loading adapter from {latest_save_path}")
loaded_state_dict = load_t2i_model(
latest_save_path,
self.device,
dtype=dtype
)
self.adapter.load_state_dict(loaded_state_dict)
self.load_training_state_from_metadata(latest_save_path)
# set trainable params
params = self.adapter.parameters()
self.sd.adapter = self.adapter
flush()
else:
params.append({
'params': self.adapter.parameters(),
'lr': self.train_config.adapter_lr
})
self.sd.adapter = self.adapter
flush()
else: # no network, embedding or adapter
# set the device state preset before getting params
self.sd.set_device_state(self.train_device_state_preset)

params = self.get_params()

if params is None:
# will only return savable weights and ones with grad
params = self.sd.prepare_optimizer_params(
unet=self.train_config.train_unet,
text_encoder=self.train_config.train_text_encoder,
text_encoder_lr=self.train_config.lr,
unet_lr=self.train_config.lr,
default_lr=self.train_config.lr
)
# will only return savable weights and ones with grad
params = self.sd.prepare_optimizer_params(
unet=self.train_config.train_unet,
text_encoder=self.train_config.train_text_encoder,
text_encoder_lr=self.train_config.lr,
unet_lr=self.train_config.lr,
default_lr=self.train_config.lr
)
flush()
### HOOK ###
params = self.hook_add_extra_train_params(params)
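A hedged sketch of the adapter branch above, using the diffusers T2IAdapter constructor this hunk calls. The channel values and learning rate are illustrative, not this repo's config; standalone adapter runs take adapter.parameters() directly, while combined runs append a param group trained at its own adapter_lr:

import torch
from diffusers import T2IAdapter

adapter = T2IAdapter(
    in_channels=3,
    channels=[320, 640, 1280, 1280],   # illustrative SDXL-style widths
    num_res_blocks=2,
    downscale_factor=16,
    adapter_type="full_adapter_xl",
)
adapter.to("cpu", dtype=torch.float32)  # mirrors self.adapter.to(self.device_torch, dtype=dtype)

# a param group with its own lr, as in the combined-training branch above
params = [{"params": adapter.parameters(), "lr": 1e-4}]
optimizer = torch.optim.AdamW(params)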
@@ -746,7 +781,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.params.append(param)

optimizer_type = self.train_config.optimizer.lower()
optimizer = get_optimizer(params, optimizer_type, learning_rate=self.train_config.lr,
optimizer = get_optimizer(self.params, optimizer_type, learning_rate=self.train_config.lr,
optimizer_params=self.train_config.optimizer_params)
self.optimizer = optimizer
@@ -92,6 +92,7 @@ class TrainConfig:
self.unet_lr = kwargs.get('unet_lr', self.lr)
self.text_encoder_lr = kwargs.get('text_encoder_lr', self.lr)
self.embedding_lr = kwargs.get('embedding_lr', self.lr)
self.adapter_lr = kwargs.get('adapter_lr', self.lr)
self.optimizer = kwargs.get('optimizer', 'adamw')
self.optimizer_params = kwargs.get('optimizer_params', {})
self.lr_scheduler = kwargs.get('lr_scheduler', 'constant')
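The new adapter_lr follows the same fallback pattern as the other per-module rates: every specialized learning rate defaults to the base lr when the config omits it. A stand-in sketch (the base lr default is assumed, not taken from the source):

class TrainConfigSketch:               # hypothetical stand-in for this repo's TrainConfig
    def __init__(self, **kwargs):
        self.lr = kwargs.get('lr', 1e-6)
        self.unet_lr = kwargs.get('unet_lr', self.lr)
        self.text_encoder_lr = kwargs.get('text_encoder_lr', self.lr)
        self.embedding_lr = kwargs.get('embedding_lr', self.lr)
        self.adapter_lr = kwargs.get('adapter_lr', self.lr)   # new in this commit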
@@ -184,7 +184,7 @@ def save_t2i_from_diffusers(

def load_t2i_model(
path_to_file,
device: Union[str, torch.device] = 'cpu',
device: Union[str] = 'cpu',
dtype: torch.dtype = torch.float32
):
raw_state_dict = load_file(path_to_file, device)
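The narrowed type hint likely reflects the safetensors API: load_file takes a device string ("cpu", "cuda:0") rather than a torch.device. A minimal sketch of the load-then-cast pattern, under that assumption:

import torch
from safetensors.torch import load_file

def load_adapter_state(path: str, device: str = 'cpu', dtype: torch.dtype = torch.float32):
    raw = load_file(path, device=device)
    return {k: v.to(dtype=dtype) for k, v in raw.items()}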
@@ -250,7 +250,7 @@ class StableDiffusion:
# add hacks to unet to help training
# pipe.unet = prepare_unet_for_training(pipe.unet)

self.unet = pipe.unet
self.unet: 'UNet2DConditionModel' = pipe.unet
self.vae: 'AutoencoderKL' = pipe.vae.to(self.device_torch, dtype=dtype)
self.vae.eval()
self.vae.requires_grad_(False)
@@ -360,8 +360,9 @@ class StableDiffusion:
extra = {}
if gen_config.adapter_image_path is not None:
validation_image = Image.open(gen_config.adapter_image_path).convert("RGB")
validation_image = validation_image.resize((gen_config.width, gen_config.height))
validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2))
extra['image'] = validation_image
extra['adapter_conditioning_scale'] = 1.0

if self.network is not None:
self.network.multiplier = gen_config.network_multiplier
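A sketch of the conditioning-image prep above. The 2x factor is this commit's empirical choice (the adapter downsamples its input before injecting features, so the conditioning image can exceed the sample size); the rest is plain PIL:

from PIL import Image

def prep_adapter_image(path: str, width: int, height: int, scale: int = 2):
    img = Image.open(path).convert("RGB")
    return img.resize((width * scale, height * scale))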
@@ -933,7 +934,7 @@ class StableDiffusion:
self.device_state['adapter'] = {
'training': self.adapter.training,
'device': self.adapter.device,
'requires_grad': self.adapter.requires_grad,
'requires_grad': self.adapter.adapter.conv_in.weight.requires_grad,
}

def restore_device_state(self):
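The change above is presumably a bug fix: torch.nn.Module exposes no aggregate .requires_grad attribute, so the commit samples one known weight as a proxy for the whole adapter. A generic hedged sketch of the same idea, probing the first parameter instead of a named one:

import torch.nn as nn

def snapshot_module_state(module: nn.Module) -> dict:
    first_param = next(module.parameters())
    return {
        'training': module.training,
        'device': first_param.device,            # modules live where their params live
        'requires_grad': first_param.requires_grad,
    }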