mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Various bug fixes. Finalized targeted guidance algorithm
This commit is contained in:
@@ -1061,6 +1061,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# load last saved weights
|
||||
if latest_save_path is not None:
|
||||
self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
|
||||
|
||||
# self.step_num = self.embedding.step
|
||||
# self.start_step = self.step_num
|
||||
params.append({
|
||||
'params': self.embedding.get_trainable_params(),
|
||||
'lr': self.train_config.embedding_lr
|
||||
@@ -1068,27 +1071,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
flush()
|
||||
|
||||
if self.embed_config is not None:
|
||||
self.embedding = Embedding(
|
||||
sd=self.sd,
|
||||
embed_config=self.embed_config
|
||||
)
|
||||
latest_save_path = self.get_latest_save_path(self.embed_config.trigger)
|
||||
# load last saved weights
|
||||
if latest_save_path is not None:
|
||||
self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
|
||||
|
||||
# resume state from embedding
|
||||
self.step_num = self.embedding.step
|
||||
self.start_step = self.step_num
|
||||
|
||||
params = self.get_params()
|
||||
if not params:
|
||||
# set trainable params
|
||||
params = self.embedding.get_trainable_params()
|
||||
|
||||
flush()
|
||||
|
||||
if self.adapter_config is not None:
|
||||
self.setup_adapter()
|
||||
# set trainable params
|
||||
|
||||
@@ -327,6 +327,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
|
||||
with torch.no_grad():
|
||||
adapter_images = None
|
||||
self.sd.unet.eval()
|
||||
|
||||
# for a complete slider, the batch size is 4 to begin with now
|
||||
true_batch_size = prompt_pair.target_class.text_embeds.shape[0] * self.train_config.batch_size
|
||||
@@ -385,21 +386,22 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
latents = noise * self.sd.noise_scheduler.init_noise_sigma
|
||||
latents = latents.to(self.device_torch, dtype=dtype)
|
||||
|
||||
with self.network:
|
||||
assert self.network.is_active
|
||||
# pass the multiplier list to the network
|
||||
self.network.multiplier = prompt_pair.multiplier_list
|
||||
denoised_latents = self.sd.diffuse_some_steps(
|
||||
latents, # pass simple noise latents
|
||||
train_tools.concat_prompt_embeddings(
|
||||
prompt_pair.positive_target, # unconditional
|
||||
prompt_pair.target_class, # target
|
||||
self.train_config.batch_size,
|
||||
),
|
||||
start_timesteps=0,
|
||||
total_timesteps=timesteps_to,
|
||||
guidance_scale=3,
|
||||
)
|
||||
assert not self.network.is_active
|
||||
self.sd.unet.eval()
|
||||
# pass the multiplier list to the network
|
||||
self.network.multiplier = prompt_pair.multiplier_list
|
||||
denoised_latents = self.sd.diffuse_some_steps(
|
||||
latents, # pass simple noise latents
|
||||
train_tools.concat_prompt_embeddings(
|
||||
prompt_pair.positive_target, # unconditional
|
||||
prompt_pair.target_class, # target
|
||||
self.train_config.batch_size,
|
||||
),
|
||||
start_timesteps=0,
|
||||
total_timesteps=timesteps_to,
|
||||
guidance_scale=3,
|
||||
)
|
||||
|
||||
|
||||
noise_scheduler.set_timesteps(1000)
|
||||
|
||||
@@ -473,6 +475,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
denoised_latents = denoised_latents.detach()
|
||||
|
||||
self.sd.set_device_state(self.train_slider_device_state)
|
||||
self.sd.unet.train()
|
||||
# start accumulating gradients
|
||||
self.optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user