mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
WIP. just need to put it here
This commit is contained in:
@@ -157,33 +157,19 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
|
||||
torch.randint(0, len(self.prompt_txt_list), (1,)).item()
|
||||
]
|
||||
prompt = self.prompt_cache[prompt_txt].to(device=self.device_torch, dtype=dtype)
|
||||
prompt.text_embeds.to(device=self.device_torch, dtype=dtype)
|
||||
prompt.pooled_embeds.to(device=self.device_torch, dtype=dtype)
|
||||
neutral = self.prompt_cache[""].to(device=self.device_torch, dtype=dtype)
|
||||
neutral.text_embeds.to(device=self.device_torch, dtype=dtype)
|
||||
neutral.pooled_embeds.to(device=self.device_torch, dtype=dtype)
|
||||
if prompt is None:
|
||||
raise ValueError(f"Prompt {prompt_txt} is not in cache")
|
||||
|
||||
prompt_batch = train_tools.concat_prompt_embeddings(
|
||||
prompt,
|
||||
neutral,
|
||||
self.train_config.batch_size,
|
||||
)
|
||||
|
||||
noise_scheduler = self.sd.noise_scheduler
|
||||
optimizer = self.optimizer
|
||||
lr_scheduler = self.lr_scheduler
|
||||
loss_function = torch.nn.MSELoss()
|
||||
|
||||
def get_noise_pred(p, n, gs, cts, dn):
|
||||
return self.predict_noise(
|
||||
latents=dn,
|
||||
text_embeddings=train_tools.concat_prompt_embeddings(
|
||||
p, # unconditional
|
||||
n, # positive
|
||||
self.train_config.batch_size,
|
||||
),
|
||||
timestep=cts,
|
||||
guidance_scale=gs,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
self.train_config.max_denoising_steps, device=self.device_torch
|
||||
@@ -195,52 +181,60 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
|
||||
timesteps_to = torch.randint(
|
||||
1, self.train_config.max_denoising_steps, (1,)
|
||||
).item()
|
||||
absolute_total_timesteps = 1000
|
||||
|
||||
# get noise
|
||||
noise = self.get_latent_noise(
|
||||
latents = self.get_latent_noise(
|
||||
pixel_height=self.rescale_config.from_resolution,
|
||||
pixel_width=self.rescale_config.from_resolution,
|
||||
).to(self.device_torch, dtype=dtype)
|
||||
|
||||
# get latents
|
||||
latents = noise * self.sd.noise_scheduler.init_noise_sigma
|
||||
latents = latents.to(self.device_torch, dtype=dtype)
|
||||
#
|
||||
# # predict without network
|
||||
# assert self.network.is_active is False
|
||||
# denoised_latents = self.diffuse_some_steps(
|
||||
# latents, # pass simple noise latents
|
||||
# prompt_batch,
|
||||
# start_timesteps=0,
|
||||
# total_timesteps=timesteps_to,
|
||||
# guidance_scale=3,
|
||||
# )
|
||||
# noise_scheduler.set_timesteps(1000)
|
||||
#
|
||||
# current_timestep = noise_scheduler.timesteps[
|
||||
# int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
|
||||
# ]
|
||||
denoised_fraction = timesteps_to / absolute_total_timesteps
|
||||
|
||||
current_timestep = 0
|
||||
denoised_latents = latents
|
||||
# get noise prediction at full scale
|
||||
from_prediction = get_noise_pred(
|
||||
prompt, neutral, 1, current_timestep, denoised_latents
|
||||
denoised_latents = self.sd.pipeline(
|
||||
num_inference_steps=1000,
|
||||
denoising_end=denoised_fraction,
|
||||
latents=latents,
|
||||
prompt_embeds=prompt.text_embeds,
|
||||
negative_prompt_embeds=neutral.text_embeds,
|
||||
pooled_prompt_embeds=prompt.pooled_embeds,
|
||||
negative_pooled_prompt_embeds=neutral.pooled_embeds,
|
||||
output_type="latent",
|
||||
num_images_per_prompt=self.train_config.batch_size,
|
||||
guidance_scale=3,
|
||||
).images.to(self.device_torch, dtype=dtype)
|
||||
|
||||
current_timestep = timesteps_to
|
||||
|
||||
from_prediction = self.sd.pipeline.predict_noise(
|
||||
latents=denoised_latents,
|
||||
prompt_embeds=prompt.text_embeds,
|
||||
negative_prompt_embeds=neutral.text_embeds,
|
||||
pooled_prompt_embeds=prompt.pooled_embeds,
|
||||
negative_pooled_prompt_embeds=neutral.pooled_embeds,
|
||||
timestep=current_timestep,
|
||||
guidance_scale=2
|
||||
)
|
||||
|
||||
reduced_from_prediction = self.reduce_size_fn(from_prediction).to("cpu", dtype=torch.float32)
|
||||
|
||||
# get noise prediction at reduced scale
|
||||
to_denoised_latents = self.reduce_size_fn(denoised_latents)
|
||||
to_denoised_latents = self.reduce_size_fn(denoised_latents).to(self.device_torch, dtype=dtype)
|
||||
|
||||
# start gradient
|
||||
optimizer.zero_grad()
|
||||
self.network.multiplier = 1.0
|
||||
with self.network:
|
||||
assert self.network.is_active is True
|
||||
to_prediction = get_noise_pred(
|
||||
prompt, neutral, 1, current_timestep, to_denoised_latents
|
||||
).to("cpu", dtype=torch.float32)
|
||||
to_prediction = self.sd.pipeline.predict_noise(
|
||||
latents=to_denoised_latents,
|
||||
prompt_embeds=prompt.text_embeds,
|
||||
negative_prompt_embeds=neutral.text_embeds,
|
||||
pooled_prompt_embeds=prompt.pooled_embeds,
|
||||
negative_pooled_prompt_embeds=neutral.pooled_embeds,
|
||||
timestep=current_timestep,
|
||||
guidance_scale=2
|
||||
)
|
||||
|
||||
reduced_from_prediction.requires_grad = False
|
||||
from_prediction.requires_grad = False
|
||||
|
||||
Reference in New Issue
Block a user