Add Chroma (#2925)

This commit is contained in:
Mathieu Croquelois
2025-06-23 23:35:06 +01:00
committed by GitHub
parent ae278f7940
commit 963e7643f0
11 changed files with 131222 additions and 2 deletions

View File

@@ -22,9 +22,10 @@ from backend.diffusion_engine.sd20 import StableDiffusion2
from backend.diffusion_engine.sdxl import StableDiffusionXL, StableDiffusionXLRefiner
from backend.diffusion_engine.sd35 import StableDiffusion3
from backend.diffusion_engine.flux import Flux
from backend.diffusion_engine.chroma import Chroma
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXLRefiner, StableDiffusionXL, StableDiffusion3, Flux]
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXLRefiner, StableDiffusionXL, StableDiffusion3, Chroma, Flux]
logging.getLogger("diffusers").setLevel(logging.ERROR)
@@ -108,7 +109,7 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight', 'logit_scale'])
return model
if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel', 'SD3Transformer2DModel']:
if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel', 'SD3Transformer2DModel', 'ChromaTransformer2DModel']:
assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have model state dict!'
model_loader = None
@@ -117,6 +118,9 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
elif cls_name == 'FluxTransformer2DModel':
from backend.nn.flux import IntegratedFluxTransformer2DModel
model_loader = lambda c: IntegratedFluxTransformer2DModel(**c)
elif cls_name == 'ChromaTransformer2DModel':
from backend.nn.chroma import IntegratedChromaTransformer2DModel
model_loader = lambda c: IntegratedChromaTransformer2DModel(**c)
elif cls_name == 'SD3Transformer2DModel':
from backend.nn.mmditx import MMDiTX
model_loader = lambda c: MMDiTX(**c)
@@ -478,7 +482,18 @@ def split_state_dict(sd, additional_state_dicts: list = None):
return state_dict, guess
# To be removed once PR merged on huggingface_guess
# Feature-detect whether the installed huggingface_guess library already
# ships a Chroma entry in its model list; if so, the shim class below is
# unnecessary and the Chroma-specific branches elsewhere are skipped.
chroma_is_in_huggingface_guess = hasattr(huggingface_guess.model_list, "Chroma")
if not chroma_is_in_huggingface_guess:
# Minimal stand-in for a huggingface_guess Chroma model entry. It only
# carries the repo name plus the unet config deltas that forge_loader
# applies on top of a FLUX.1-schnell guess when a Chroma checkpoint is
# detected (identified by its 'distilled_guidance_layer.*' weights).
class GuessChroma:
huggingface_repo = 'Chroma'
# Extra transformer config keys added for Chroma — presumably the
# dimensions/depth of its distilled guidance network; TODO confirm
# against IntegratedChromaTransformer2DModel's constructor.
unet_extra_config = {
'guidance_out_dim': 3072,
'guidance_hidden_dim': 5120,
'guidance_n_layers': 5
}
# Flux config keys that do not apply to Chroma and are deleted from
# the estimated config before model construction.
unet_remove_config = ['guidance_embed']
@torch.inference_mode()
def forge_loader(sd, additional_state_dicts=None):
try:
@@ -486,6 +501,17 @@ def forge_loader(sd, additional_state_dicts=None):
except:
raise ValueError('Failed to recognize model type!')
if not chroma_is_in_huggingface_guess \
and estimated_config.huggingface_repo == "black-forest-labs/FLUX.1-schnell" \
and "transformer" in state_dicts \
and "distilled_guidance_layer.layers.0.in_layer.bias" in state_dicts["transformer"]:
estimated_config.huggingface_repo = GuessChroma.huggingface_repo
for x in GuessChroma.unet_extra_config:
estimated_config.unet_config[x] = GuessChroma.unet_extra_config[x]
for x in GuessChroma.unet_remove_config:
del estimated_config.unet_config[x]
state_dicts['text_encoder'] = state_dicts['text_encoder_2']
del state_dicts['text_encoder_2']
repo_name = estimated_config.huggingface_repo
local_path = os.path.join(dir_path, 'huggingface', repo_name)
@@ -540,6 +566,8 @@ def forge_loader(sd, additional_state_dicts=None):
else:
huggingface_components['scheduler'].config.prediction_type = prediction_types.get(estimated_config.model_type.name, huggingface_components['scheduler'].config.prediction_type)
if not chroma_is_in_huggingface_guess and estimated_config.huggingface_repo == "Chroma":
return Chroma(estimated_config=estimated_config, huggingface_components=huggingface_components)
for M in possible_models:
if any(isinstance(estimated_config, x) for x in M.matched_guesses):
return M(estimated_config=estimated_config, huggingface_components=huggingface_components)