mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-08 03:59:49 +00:00
Make a CFG version of flux pipeline
This commit is contained in:
@@ -36,7 +36,7 @@ from toolkit.sd_device_states_presets import empty_preset
|
||||
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
||||
import torch
|
||||
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
|
||||
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
|
||||
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline, FluxWithCFGPipeline
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
|
||||
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \
|
||||
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
|
||||
@@ -797,16 +797,29 @@ class StableDiffusion:
|
||||
).to(self.device_torch)
|
||||
pipeline.watermark = None
|
||||
elif self.is_flux:
|
||||
pipeline = FluxPipeline(
|
||||
vae=self.vae,
|
||||
transformer=self.unet,
|
||||
text_encoder=self.text_encoder[0],
|
||||
text_encoder_2=self.text_encoder[1],
|
||||
tokenizer=self.tokenizer[0],
|
||||
tokenizer_2=self.tokenizer[1],
|
||||
scheduler=noise_scheduler,
|
||||
**extra_args
|
||||
)
|
||||
if self.model_config.use_flux_cfg:
|
||||
pipeline = FluxWithCFGPipeline(
|
||||
vae=self.vae,
|
||||
transformer=self.unet,
|
||||
text_encoder=self.text_encoder[0],
|
||||
text_encoder_2=self.text_encoder[1],
|
||||
tokenizer=self.tokenizer[0],
|
||||
tokenizer_2=self.tokenizer[1],
|
||||
scheduler=noise_scheduler,
|
||||
**extra_args
|
||||
)
|
||||
|
||||
else:
|
||||
pipeline = FluxPipeline(
|
||||
vae=self.vae,
|
||||
transformer=self.unet,
|
||||
text_encoder=self.text_encoder[0],
|
||||
text_encoder_2=self.text_encoder[1],
|
||||
tokenizer=self.tokenizer[0],
|
||||
tokenizer_2=self.tokenizer[1],
|
||||
scheduler=noise_scheduler,
|
||||
**extra_args
|
||||
)
|
||||
pipeline.watermark = None
|
||||
elif self.is_v3:
|
||||
pipeline = Pipe(
|
||||
@@ -1068,18 +1081,32 @@ class StableDiffusion:
|
||||
**extra
|
||||
).images[0]
|
||||
elif self.is_flux:
|
||||
img = pipeline(
|
||||
prompt_embeds=conditional_embeds.text_embeds,
|
||||
pooled_prompt_embeds=conditional_embeds.pooled_embeds,
|
||||
# negative_prompt_embeds=unconditional_embeds.text_embeds,
|
||||
# negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
|
||||
height=gen_config.height,
|
||||
width=gen_config.width,
|
||||
num_inference_steps=gen_config.num_inference_steps,
|
||||
guidance_scale=gen_config.guidance_scale,
|
||||
latents=gen_config.latents,
|
||||
**extra
|
||||
).images[0]
|
||||
if self.model_config.use_flux_cfg:
|
||||
img = pipeline(
|
||||
prompt_embeds=conditional_embeds.text_embeds,
|
||||
pooled_prompt_embeds=conditional_embeds.pooled_embeds,
|
||||
negative_prompt_embeds=unconditional_embeds.text_embeds,
|
||||
negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
|
||||
height=gen_config.height,
|
||||
width=gen_config.width,
|
||||
num_inference_steps=gen_config.num_inference_steps,
|
||||
guidance_scale=gen_config.guidance_scale,
|
||||
latents=gen_config.latents,
|
||||
**extra
|
||||
).images[0]
|
||||
else:
|
||||
img = pipeline(
|
||||
prompt_embeds=conditional_embeds.text_embeds,
|
||||
pooled_prompt_embeds=conditional_embeds.pooled_embeds,
|
||||
# negative_prompt_embeds=unconditional_embeds.text_embeds,
|
||||
# negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
|
||||
height=gen_config.height,
|
||||
width=gen_config.width,
|
||||
num_inference_steps=gen_config.num_inference_steps,
|
||||
guidance_scale=gen_config.guidance_scale,
|
||||
latents=gen_config.latents,
|
||||
**extra
|
||||
).images[0]
|
||||
elif self.is_pixart:
|
||||
# needs attention masks for some reason
|
||||
img = pipeline(
|
||||
|
||||
Reference in New Issue
Block a user