mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-14 01:19:49 +00:00
upload Schnell support files
This commit is contained in:
@@ -13,7 +13,7 @@ from backend import memory_management
|
||||
|
||||
|
||||
class Flux(ForgeDiffusionEngine):
|
||||
matched_guesses = [model_list.Flux]
|
||||
matched_guesses = [model_list.Flux, model_list.FluxSchnell]
|
||||
|
||||
def __init__(self, estimated_config, huggingface_components):
|
||||
super().__init__(estimated_config, huggingface_components)
|
||||
@@ -32,10 +32,16 @@ class Flux(ForgeDiffusionEngine):
|
||||
|
||||
vae = VAE(model=huggingface_components['vae'])
|
||||
|
||||
if 'schnell' in estimated_config.huggingface_repo.lower():
|
||||
k_predictor = PredictionFlux(sigma_data=1.0, prediction_type='const', shift=1.0, timesteps=10000)
|
||||
else:
|
||||
k_predictor = PredictionFlux(sigma_data=1.0, prediction_type='const', shift=1.15, timesteps=10000)
|
||||
self.use_distilled_cfg_scale = True
|
||||
|
||||
unet = UnetPatcher.from_model(
|
||||
model=huggingface_components['transformer'],
|
||||
diffusers_scheduler=None,
|
||||
k_predictor=PredictionFlux(sigma_data=1.0, prediction_type='const', shift=1.15, timesteps=10000),
|
||||
k_predictor=k_predictor,
|
||||
config=estimated_config
|
||||
)
|
||||
|
||||
@@ -63,8 +69,6 @@ class Flux(ForgeDiffusionEngine):
|
||||
self.forge_objects_original = self.forge_objects.shallow_copy()
|
||||
self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()
|
||||
|
||||
self.use_distilled_cfg_scale = True
|
||||
|
||||
def set_clip_skip(self, clip_skip):
|
||||
self.text_processing_engine_l.clip_skip = clip_skip
|
||||
|
||||
@@ -77,11 +81,12 @@ class Flux(ForgeDiffusionEngine):
|
||||
distilled_cfg_scale = getattr(prompt, 'distilled_cfg_scale', 3.5) or 3.5
|
||||
print(f'distilled_cfg_scale = {distilled_cfg_scale}')
|
||||
|
||||
cond = dict(
|
||||
crossattn=cond_t5,
|
||||
vector=pooled_l,
|
||||
guidance=torch.FloatTensor([distilled_cfg_scale] * len(prompt))
|
||||
)
|
||||
cond = dict(crossattn=cond_t5, vector=pooled_l)
|
||||
|
||||
if self.use_distilled_cfg_scale:
|
||||
cond['guidance'] = torch.FloatTensor([distilled_cfg_scale] * len(prompt))
|
||||
else:
|
||||
print('Distilled CFG Scale will be ignored for Schnell')
|
||||
|
||||
return cond
|
||||
|
||||
|
||||
Reference in New Issue
Block a user