mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added initial support for f-lite model
This commit is contained in:
@@ -249,7 +249,8 @@ class DiffusionFeatureExtractor3(nn.Module):
|
||||
# lpips_weight=1.0,
|
||||
lpips_weight=10.0,
|
||||
clip_weight=0.1,
|
||||
pixel_weight=0.1
|
||||
pixel_weight=0.1,
|
||||
model=None
|
||||
):
|
||||
dtype = torch.bfloat16
|
||||
device = self.vae.device
|
||||
@@ -274,7 +275,10 @@ class DiffusionFeatureExtractor3(nn.Module):
|
||||
|
||||
# stepped_latents = torch.cat(stepped_chunks, dim=0)
|
||||
|
||||
stepped_latents = noise - noise_pred
|
||||
if model is not None and hasattr(model, 'get_stepped_pred'):
|
||||
stepped_latents = model.get_stepped_pred(noise_pred, noise)
|
||||
else:
|
||||
stepped_latents = noise - noise_pred
|
||||
|
||||
latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)
|
||||
|
||||
|
||||
@@ -2283,6 +2283,7 @@ class StableDiffusion:
|
||||
bleed_latents: torch.FloatTensor = None,
|
||||
is_input_scaled=False,
|
||||
return_first_prediction=False,
|
||||
bypass_guidance_embedding=False,
|
||||
**kwargs,
|
||||
):
|
||||
timesteps_to_run = self.noise_scheduler.timesteps[start_timesteps:total_timesteps]
|
||||
@@ -2299,6 +2300,7 @@ class StableDiffusion:
|
||||
add_time_ids=add_time_ids,
|
||||
is_input_scaled=is_input_scaled,
|
||||
return_conditional_pred=True,
|
||||
bypass_guidance_embedding=bypass_guidance_embedding,
|
||||
**kwargs,
|
||||
)
|
||||
# some schedulers need to run separately, so do that. (euler for example)
|
||||
|
||||
@@ -145,7 +145,7 @@ if TYPE_CHECKING:
|
||||
def concat_prompt_embeddings(
|
||||
unconditional: 'PromptEmbeds',
|
||||
conditional: 'PromptEmbeds',
|
||||
n_imgs: int,
|
||||
n_imgs: int=0,
|
||||
):
|
||||
from toolkit.stable_diffusion_model import PromptEmbeds
|
||||
text_embeds = torch.cat(
|
||||
|
||||
Reference in New Issue
Block a user