WIP. just need to put it here

This commit is contained in:
Jaret Burkett
2023-07-27 01:46:30 -06:00
parent 2305e55c82
commit 6ab8b8b0f1
4 changed files with 279 additions and 63 deletions

View File

@@ -157,33 +157,19 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
torch.randint(0, len(self.prompt_txt_list), (1,)).item()
]
prompt = self.prompt_cache[prompt_txt].to(device=self.device_torch, dtype=dtype)
prompt.text_embeds.to(device=self.device_torch, dtype=dtype)
prompt.pooled_embeds.to(device=self.device_torch, dtype=dtype)
neutral = self.prompt_cache[""].to(device=self.device_torch, dtype=dtype)
neutral.text_embeds.to(device=self.device_torch, dtype=dtype)
neutral.pooled_embeds.to(device=self.device_torch, dtype=dtype)
if prompt is None:
raise ValueError(f"Prompt {prompt_txt} is not in cache")
prompt_batch = train_tools.concat_prompt_embeddings(
prompt,
neutral,
self.train_config.batch_size,
)
noise_scheduler = self.sd.noise_scheduler
optimizer = self.optimizer
lr_scheduler = self.lr_scheduler
loss_function = torch.nn.MSELoss()
# Local helper: forwards to self.predict_noise, passing dn as the latents,
# cts as the timestep, gs as the guidance scale, and text embeddings built by
# train_tools.concat_prompt_embeddings from p, n and
# self.train_config.batch_size.
def get_noise_pred(p, n, gs, cts, dn):
return self.predict_noise(
latents=dn,
text_embeddings=train_tools.concat_prompt_embeddings(
# NOTE(review): the visible call sites pass `prompt` as p and `neutral`
# (the cached "" prompt) as n, which looks inverted relative to the two
# labels below -- confirm against concat_prompt_embeddings' parameter order.
p, # unconditional
n, # positive
self.train_config.batch_size,
),
timestep=cts,
guidance_scale=gs,
)
with torch.no_grad():
self.sd.noise_scheduler.set_timesteps(
self.train_config.max_denoising_steps, device=self.device_torch
@@ -195,52 +181,60 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
timesteps_to = torch.randint(
1, self.train_config.max_denoising_steps, (1,)
).item()
absolute_total_timesteps = 1000
# get noise
noise = self.get_latent_noise(
latents = self.get_latent_noise(
pixel_height=self.rescale_config.from_resolution,
pixel_width=self.rescale_config.from_resolution,
).to(self.device_torch, dtype=dtype)
# get latents
latents = noise * self.sd.noise_scheduler.init_noise_sigma
latents = latents.to(self.device_torch, dtype=dtype)
#
# # predict without network
# assert self.network.is_active is False
# denoised_latents = self.diffuse_some_steps(
# latents, # pass simple noise latents
# prompt_batch,
# start_timesteps=0,
# total_timesteps=timesteps_to,
# guidance_scale=3,
# )
# noise_scheduler.set_timesteps(1000)
#
# current_timestep = noise_scheduler.timesteps[
# int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
# ]
denoised_fraction = timesteps_to / absolute_total_timesteps
current_timestep = 0
denoised_latents = latents
# get noise prediction at full scale
from_prediction = get_noise_pred(
prompt, neutral, 1, current_timestep, denoised_latents
denoised_latents = self.sd.pipeline(
num_inference_steps=1000,
denoising_end=denoised_fraction,
latents=latents,
prompt_embeds=prompt.text_embeds,
negative_prompt_embeds=neutral.text_embeds,
pooled_prompt_embeds=prompt.pooled_embeds,
negative_pooled_prompt_embeds=neutral.pooled_embeds,
output_type="latent",
num_images_per_prompt=self.train_config.batch_size,
guidance_scale=3,
).images.to(self.device_torch, dtype=dtype)
current_timestep = timesteps_to
from_prediction = self.sd.pipeline.predict_noise(
latents=denoised_latents,
prompt_embeds=prompt.text_embeds,
negative_prompt_embeds=neutral.text_embeds,
pooled_prompt_embeds=prompt.pooled_embeds,
negative_pooled_prompt_embeds=neutral.pooled_embeds,
timestep=current_timestep,
guidance_scale=2
)
reduced_from_prediction = self.reduce_size_fn(from_prediction).to("cpu", dtype=torch.float32)
# get noise prediction at reduced scale
to_denoised_latents = self.reduce_size_fn(denoised_latents)
to_denoised_latents = self.reduce_size_fn(denoised_latents).to(self.device_torch, dtype=dtype)
# start gradient
optimizer.zero_grad()
self.network.multiplier = 1.0
with self.network:
assert self.network.is_active is True
to_prediction = get_noise_pred(
prompt, neutral, 1, current_timestep, to_denoised_latents
).to("cpu", dtype=torch.float32)
to_prediction = self.sd.pipeline.predict_noise(
latents=to_denoised_latents,
prompt_embeds=prompt.text_embeds,
negative_prompt_embeds=neutral.text_embeds,
pooled_prompt_embeds=prompt.pooled_embeds,
negative_pooled_prompt_embeds=neutral.pooled_embeds,
timestep=current_timestep,
guidance_scale=2
)
reduced_from_prediction.requires_grad = False
from_prediction.requires_grad = False