Upload Schnell support files

This commit is contained in:
layerdiffusion
2024-08-11 17:31:12 -07:00
parent 19b41b9438
commit 12a02b88ac
17 changed files with 230402 additions and 9 deletions

View File

@@ -13,7 +13,7 @@ from backend import memory_management
class Flux(ForgeDiffusionEngine):
matched_guesses = [model_list.Flux]
matched_guesses = [model_list.Flux, model_list.FluxSchnell]
def __init__(self, estimated_config, huggingface_components):
super().__init__(estimated_config, huggingface_components)
@@ -32,10 +32,16 @@ class Flux(ForgeDiffusionEngine):
vae = VAE(model=huggingface_components['vae'])
if 'schnell' in estimated_config.huggingface_repo.lower():
k_predictor = PredictionFlux(sigma_data=1.0, prediction_type='const', shift=1.0, timesteps=10000)
else:
k_predictor = PredictionFlux(sigma_data=1.0, prediction_type='const', shift=1.15, timesteps=10000)
self.use_distilled_cfg_scale = True
unet = UnetPatcher.from_model(
model=huggingface_components['transformer'],
diffusers_scheduler=None,
k_predictor=PredictionFlux(sigma_data=1.0, prediction_type='const', shift=1.15, timesteps=10000),
k_predictor=k_predictor,
config=estimated_config
)
@@ -63,8 +69,6 @@ class Flux(ForgeDiffusionEngine):
self.forge_objects_original = self.forge_objects.shallow_copy()
self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()
self.use_distilled_cfg_scale = True
def set_clip_skip(self, clip_skip):
self.text_processing_engine_l.clip_skip = clip_skip
@@ -77,11 +81,12 @@ class Flux(ForgeDiffusionEngine):
distilled_cfg_scale = getattr(prompt, 'distilled_cfg_scale', 3.5) or 3.5
print(f'distilled_cfg_scale = {distilled_cfg_scale}')
cond = dict(
crossattn=cond_t5,
vector=pooled_l,
guidance=torch.FloatTensor([distilled_cfg_scale] * len(prompt))
)
cond = dict(crossattn=cond_t5, vector=pooled_l)
if self.use_distilled_cfg_scale:
cond['guidance'] = torch.FloatTensor([distilled_cfg_scale] * len(prompt))
else:
print('Distilled CFG Scale will be ignored for Schnell')
return cond