mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-02 03:29:49 +00:00
Add Chroma (#2925)
This commit is contained in:
committed by
GitHub
parent
ae278f7940
commit
963e7643f0
@@ -22,9 +22,10 @@ from backend.diffusion_engine.sd20 import StableDiffusion2
|
||||
from backend.diffusion_engine.sdxl import StableDiffusionXL, StableDiffusionXLRefiner
|
||||
from backend.diffusion_engine.sd35 import StableDiffusion3
|
||||
from backend.diffusion_engine.flux import Flux
|
||||
from backend.diffusion_engine.chroma import Chroma
|
||||
|
||||
|
||||
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXLRefiner, StableDiffusionXL, StableDiffusion3, Flux]
|
||||
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXLRefiner, StableDiffusionXL, StableDiffusion3, Chroma, Flux]
|
||||
|
||||
|
||||
logging.getLogger("diffusers").setLevel(logging.ERROR)
|
||||
@@ -108,7 +109,7 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight', 'logit_scale'])
|
||||
|
||||
return model
|
||||
if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel', 'SD3Transformer2DModel']:
|
||||
if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel', 'SD3Transformer2DModel', 'ChromaTransformer2DModel']:
|
||||
assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have model state dict!'
|
||||
|
||||
model_loader = None
|
||||
@@ -117,6 +118,9 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
elif cls_name == 'FluxTransformer2DModel':
|
||||
from backend.nn.flux import IntegratedFluxTransformer2DModel
|
||||
model_loader = lambda c: IntegratedFluxTransformer2DModel(**c)
|
||||
elif cls_name == 'ChromaTransformer2DModel':
|
||||
from backend.nn.chroma import IntegratedChromaTransformer2DModel
|
||||
model_loader = lambda c: IntegratedChromaTransformer2DModel(**c)
|
||||
elif cls_name == 'SD3Transformer2DModel':
|
||||
from backend.nn.mmditx import MMDiTX
|
||||
model_loader = lambda c: MMDiTX(**c)
|
||||
@@ -478,7 +482,18 @@ def split_state_dict(sd, additional_state_dicts: list = None):
|
||||
|
||||
return state_dict, guess
|
||||
|
||||
# To be removed once PR merged on huggingface_guess
|
||||
chroma_is_in_huggingface_guess = hasattr(huggingface_guess.model_list, "Chroma")
|
||||
|
||||
if not chroma_is_in_huggingface_guess:
|
||||
class GuessChroma:
|
||||
huggingface_repo = 'Chroma'
|
||||
unet_extra_config = {
|
||||
'guidance_out_dim': 3072,
|
||||
'guidance_hidden_dim': 5120,
|
||||
'guidance_n_layers': 5
|
||||
}
|
||||
unet_remove_config = ['guidance_embed']
|
||||
@torch.inference_mode()
|
||||
def forge_loader(sd, additional_state_dicts=None):
|
||||
try:
|
||||
@@ -486,6 +501,17 @@ def forge_loader(sd, additional_state_dicts=None):
|
||||
except:
|
||||
raise ValueError('Failed to recognize model type!')
|
||||
|
||||
if not chroma_is_in_huggingface_guess \
|
||||
and estimated_config.huggingface_repo == "black-forest-labs/FLUX.1-schnell" \
|
||||
and "transformer" in state_dicts \
|
||||
and "distilled_guidance_layer.layers.0.in_layer.bias" in state_dicts["transformer"]:
|
||||
estimated_config.huggingface_repo = GuessChroma.huggingface_repo
|
||||
for x in GuessChroma.unet_extra_config:
|
||||
estimated_config.unet_config[x] = GuessChroma.unet_extra_config[x]
|
||||
for x in GuessChroma.unet_remove_config:
|
||||
del estimated_config.unet_config[x]
|
||||
state_dicts['text_encoder'] = state_dicts['text_encoder_2']
|
||||
del state_dicts['text_encoder_2']
|
||||
repo_name = estimated_config.huggingface_repo
|
||||
|
||||
local_path = os.path.join(dir_path, 'huggingface', repo_name)
|
||||
@@ -540,6 +566,8 @@ def forge_loader(sd, additional_state_dicts=None):
|
||||
else:
|
||||
huggingface_components['scheduler'].config.prediction_type = prediction_types.get(estimated_config.model_type.name, huggingface_components['scheduler'].config.prediction_type)
|
||||
|
||||
if not chroma_is_in_huggingface_guess and estimated_config.huggingface_repo == "Chroma":
|
||||
return Chroma(estimated_config=estimated_config, huggingface_components=huggingface_components)
|
||||
for M in possible_models:
|
||||
if any(isinstance(estimated_config, x) for x in M.matched_guesses):
|
||||
return M(estimated_config=estimated_config, huggingface_components=huggingface_components)
|
||||
|
||||
Reference in New Issue
Block a user