mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-26 01:09:19 +00:00
Added caching to image sizes so we dont do it every time.
This commit is contained in:
@@ -714,6 +714,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
return sigma
|
||||
|
||||
def get_noise(self, latents, batch_size, dtype=torch.float32):
|
||||
# get noise
|
||||
noise = self.sd.get_latent_noise(
|
||||
height=latents.shape[2],
|
||||
width=latents.shape[3],
|
||||
batch_size=batch_size,
|
||||
noise_offset=self.train_config.noise_offset,
|
||||
).to(self.device_torch, dtype=dtype)
|
||||
|
||||
return noise
|
||||
|
||||
def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
|
||||
with torch.no_grad():
|
||||
with self.timer('prepare_prompt'):
|
||||
@@ -927,12 +938,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
timesteps = torch.stack(timesteps, dim=0)
|
||||
|
||||
# get noise
|
||||
noise = self.sd.get_latent_noise(
|
||||
height=latents.shape[2],
|
||||
width=latents.shape[3],
|
||||
batch_size=batch_size,
|
||||
noise_offset=self.train_config.noise_offset,
|
||||
).to(self.device_torch, dtype=dtype)
|
||||
noise = self.get_noise(latents, batch_size, dtype=dtype)
|
||||
|
||||
# add dynamic noise offset. Dynamic noise is offsetting the noise to the same channelwise mean as the latents
|
||||
# this will negate any noise offsets
|
||||
|
||||
Reference in New Issue
Block a user