This commit is contained in:
lllyasviel
2024-01-18 04:40:29 -08:00
parent 75f5f0da01
commit 3ba5754cc7
15 changed files with 1265 additions and 57 deletions

View File

@@ -23,13 +23,13 @@ class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
x = (x * self.data_std.to(x.device)) + self.data_mean.to(x.device)
return x
def forward(self, x, noise_level=None):
def forward(self, x, noise_level=None, seed=None):
if noise_level is None:
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
else:
assert isinstance(noise_level, torch.Tensor)
x = self.scale(x)
z = self.q_sample(x, noise_level)
z = self.q_sample(x, noise_level, seed=seed)
z = self.unscale(z)
noise_level = self.time_embed(noise_level)
return z, noise_level