Various bug fixes. Finalized targeted guidance algo

This commit is contained in:
Jaret Burkett
2023-11-10 12:18:08 -07:00
parent fa6d91ba76
commit 7782caa468
5 changed files with 92 additions and 138 deletions

View File

@@ -1061,6 +1061,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
# load last saved weights
if latest_save_path is not None:
self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
# self.step_num = self.embedding.step
# self.start_step = self.step_num
params.append({
'params': self.embedding.get_trainable_params(),
'lr': self.train_config.embedding_lr
@@ -1068,27 +1071,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
flush()
if self.embed_config is not None:
self.embedding = Embedding(
sd=self.sd,
embed_config=self.embed_config
)
latest_save_path = self.get_latest_save_path(self.embed_config.trigger)
# load last saved weights
if latest_save_path is not None:
self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
# resume state from embedding
self.step_num = self.embedding.step
self.start_step = self.step_num
params = self.get_params()
if not params:
# set trainable params
params = self.embedding.get_trainable_params()
flush()
if self.adapter_config is not None:
self.setup_adapter()
# set trainable params

View File

@@ -327,6 +327,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
with torch.no_grad():
adapter_images = None
self.sd.unet.eval()
# for a complete slider, the batch size is 4 to begin with now
true_batch_size = prompt_pair.target_class.text_embeds.shape[0] * self.train_config.batch_size
@@ -385,21 +386,22 @@ class TrainSliderProcess(BaseSDTrainProcess):
latents = noise * self.sd.noise_scheduler.init_noise_sigma
latents = latents.to(self.device_torch, dtype=dtype)
with self.network:
assert self.network.is_active
# pass the multiplier list to the network
self.network.multiplier = prompt_pair.multiplier_list
denoised_latents = self.sd.diffuse_some_steps(
latents, # pass simple noise latents
train_tools.concat_prompt_embeddings(
prompt_pair.positive_target, # unconditional
prompt_pair.target_class, # target
self.train_config.batch_size,
),
start_timesteps=0,
total_timesteps=timesteps_to,
guidance_scale=3,
)
assert not self.network.is_active
self.sd.unet.eval()
# pass the multiplier list to the network
self.network.multiplier = prompt_pair.multiplier_list
denoised_latents = self.sd.diffuse_some_steps(
latents, # pass simple noise latents
train_tools.concat_prompt_embeddings(
prompt_pair.positive_target, # unconditional
prompt_pair.target_class, # target
self.train_config.batch_size,
),
start_timesteps=0,
total_timesteps=timesteps_to,
guidance_scale=3,
)
noise_scheduler.set_timesteps(1000)
@@ -473,6 +475,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
denoised_latents = denoised_latents.detach()
self.sd.set_device_state(self.train_slider_device_state)
self.sd.unet.train()
# start accumulating gradients
self.optimizer.zero_grad(set_to_none=True)