mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 22:49:48 +00:00
Added initial support for f-lite model
This commit is contained in:
@@ -249,7 +249,8 @@ class DiffusionFeatureExtractor3(nn.Module):
|
||||
# lpips_weight=1.0,
|
||||
lpips_weight=10.0,
|
||||
clip_weight=0.1,
|
||||
pixel_weight=0.1
|
||||
pixel_weight=0.1,
|
||||
model=None
|
||||
):
|
||||
dtype = torch.bfloat16
|
||||
device = self.vae.device
|
||||
@@ -274,7 +275,10 @@ class DiffusionFeatureExtractor3(nn.Module):
|
||||
|
||||
# stepped_latents = torch.cat(stepped_chunks, dim=0)
|
||||
|
||||
stepped_latents = noise - noise_pred
|
||||
if model is not None and hasattr(model, 'get_stepped_pred'):
|
||||
stepped_latents = model.get_stepped_pred(noise_pred, noise)
|
||||
else:
|
||||
stepped_latents = noise - noise_pred
|
||||
|
||||
latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user