Added caching for image sizes so we don't recompute them every time.

Author: Jaret Burkett
Date: 2024-07-15 19:07:41 -06:00
parent e4558dff4b
commit 58dffd43a8
7 changed files with 90 additions and 34 deletions
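The size-caching change itself lands in the dataloader files, which are not part of this excerpt. As a rough illustration of the idea in the commit message only, here is a minimal sketch, assuming a hypothetical get_image_size helper and PIL for reading image headers (names and storage are illustrative, not the repo's actual code):

    from PIL import Image

    # hypothetical module-level cache: path -> (width, height); the real
    # commit may store sizes elsewhere (e.g. on the dataset config)
    _image_size_cache = {}

    def get_image_size(path: str) -> tuple:
        # Image.open only reads the header here, so a single call is cheap,
        # but repeating it for thousands of files on every pass adds up
        if path not in _image_size_cache:
            with Image.open(path) as img:
                _image_size_cache[path] = img.size  # (width, height)
        return _image_size_cache[path]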

View File

@@ -714,6 +714,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
         sigma = sigma.unsqueeze(-1)
         return sigma
 
+    def get_noise(self, latents, batch_size, dtype=torch.float32):
+        # sample latent noise sized to the latents, honoring the configured noise offset
+        noise = self.sd.get_latent_noise(
+            height=latents.shape[2],
+            width=latents.shape[3],
+            batch_size=batch_size,
+            noise_offset=self.train_config.noise_offset,
+        ).to(self.device_torch, dtype=dtype)
+        return noise
+
     def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
         with torch.no_grad():
             with self.timer('prepare_prompt'):
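The new get_noise helper wraps the existing self.sd.get_latent_noise call so both call sites share one implementation (the second hunk below switches the training loop over to it). For readers unfamiliar with noise_offset, a hedged sketch of what such a generator conventionally computes, illustrative only and not the repo's actual get_latent_noise:

    import torch

    def latent_noise_sketch(batch_size: int, channels: int,
                            height: int, width: int,
                            noise_offset: float = 0.0) -> torch.Tensor:
        # base Gaussian noise in latent space
        noise = torch.randn(batch_size, channels, height, width)
        if noise_offset > 0:
            # the common "offset noise" trick: add a per-sample, per-channel
            # constant so the model can learn large brightness shifts
            noise = noise + noise_offset * torch.randn(batch_size, channels, 1, 1)
        return noise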
@@ -927,12 +938,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 timesteps = torch.stack(timesteps, dim=0)
 
                 # get noise
-                noise = self.sd.get_latent_noise(
-                    height=latents.shape[2],
-                    width=latents.shape[3],
-                    batch_size=batch_size,
-                    noise_offset=self.train_config.noise_offset,
-                ).to(self.device_torch, dtype=dtype)
+                noise = self.get_noise(latents, batch_size, dtype=dtype)
 
                 # add dynamic noise offset: shift the noise to the same channelwise mean as the latents,
                 # which negates any fixed noise offset
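The "dynamic noise offset" the trailing comment describes, shifting the noise onto the latents' channelwise mean so any fixed offset is cancelled, could look roughly like this. A minimal sketch with assumed names, not the repo's actual implementation:

    import torch

    def apply_dynamic_noise_offset(noise: torch.Tensor,
                                   latents: torch.Tensor) -> torch.Tensor:
        # per-sample, per-channel means over the spatial dims -> shape (B, C, 1, 1)
        latents_mean = latents.mean(dim=(2, 3), keepdim=True)
        noise_mean = noise.mean(dim=(2, 3), keepdim=True)
        # shift the noise channelwise onto the latents' mean; this removes
        # whatever constant offset (e.g. noise_offset) was baked in earlier
        return noise - noise_mean + latents_mean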