diff --git a/extensions_built_in/diffusion_models/chroma/chroma_model.py b/extensions_built_in/diffusion_models/chroma/chroma_model.py
index 0072e2e5..31eb6986 100644
--- a/extensions_built_in/diffusion_models/chroma/chroma_model.py
+++ b/extensions_built_in/diffusion_models/chroma/chroma_model.py
@@ -141,19 +141,38 @@ class ChromaModel(BaseModel):
         extras_path = 'ostris/Flex.1-alpha'
 
         self.print_and_status_update("Loading transformer")
+
+        chroma_state_dict = load_file(model_path, 'cpu')
+
+        # determine number of double and single blocks
+        double_blocks = 0
+        single_blocks = 0
+        for key in chroma_state_dict.keys():
+            if "double_blocks" in key:
+                block_num = int(key.split(".")[1]) + 1
+                if block_num > double_blocks:
+                    double_blocks = block_num
+            elif "single_blocks" in key:
+                block_num = int(key.split(".")[1]) + 1
+                if block_num > single_blocks:
+                    single_blocks = block_num
+        print(f"Double Blocks: {double_blocks}")
+        print(f"Single Blocks: {single_blocks}")
+        chroma_params.depth = double_blocks
+        chroma_params.depth_single_blocks = single_blocks
         transformer = Chroma(chroma_params)
         # add dtype, not sure why it doesnt have it
         transformer.dtype = dtype
-
-        chroma_state_dict = load_file(model_path, 'cpu')
         # load the state dict into the model
         transformer.load_state_dict(chroma_state_dict)
 
         transformer.to(self.quantize_device, dtype=dtype)
 
         transformer.config = FakeConfig()
+        transformer.config.num_layers = double_blocks
+        transformer.config.num_single_layers = single_blocks
 
         if self.model_config.quantize:
             # patch the state dict method
diff --git a/extensions_built_in/diffusion_models/chroma/src/math.py b/extensions_built_in/diffusion_models/chroma/src/math.py
index b46bca57..31205341 100644
--- a/extensions_built_in/diffusion_models/chroma/src/math.py
+++ b/extensions_built_in/diffusion_models/chroma/src/math.py
@@ -2,14 +2,32 @@ import torch
 from einops import rearrange
 from torch import Tensor
 
+# Flash-Attention 2 (optional)
+try:
+    from flash_attn.flash_attn_interface import flash_attn_func  # type: ignore
+    _HAS_FLASH = True
+except (ImportError, ModuleNotFoundError):
+    _HAS_FLASH = False
+
 
 def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask: Tensor) -> Tensor:
     q, k = apply_rope(q, k, pe)
     # mask should have shape [B, H, L, D]
-    x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
-    x = rearrange(x, "B H L D -> B L (H D)")
+    if _HAS_FLASH and mask is None and q.is_cuda:
+        x = flash_attn_func(
+            rearrange(q, "B H L D -> B L H D").contiguous(),
+            rearrange(k, "B H L D -> B L H D").contiguous(),
+            rearrange(v, "B H L D -> B L H D").contiguous(),
+            dropout_p=0.0,
+            softmax_scale=None,
+            causal=False,
+        )
+        x = rearrange(x, "B L H D -> B L (H D)")
+    else:
+        x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
+        x = rearrange(x, "B H L D -> B L (H D)")
 
     return x
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index b9e1b71d..068e747e 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -253,13 +253,13 @@ class StableDiffusion:
 
     def get_bucket_divisibility(self):
         if self.vae is None:
-            return 8
+            return 16
         divisibility = 2 ** (len(self.vae.config['block_out_channels']) - 1)
 
         # flux packs this again,
         if self.is_flux or self.is_v3:
             divisibility = divisibility * 2
-        return divisibility
+        return divisibility * 2 # todo remove this
 
     def load_model(self):
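
Review note on the `chroma_model.py` hunk: the transformer depth is now inferred from the checkpoint itself instead of relying on the defaults in `chroma_params`. A minimal standalone sketch of that detection logic, using hypothetical key names (real Chroma checkpoints may use different parameter suffixes):

```python
# Sketch only: toy state-dict keys standing in for a real Chroma checkpoint.
toy_state_dict = {
    "double_blocks.0.img_attn.qkv.weight": None,
    "double_blocks.18.img_attn.qkv.weight": None,
    "single_blocks.0.linear1.weight": None,
    "single_blocks.37.linear1.weight": None,
}

double_blocks = 0
single_blocks = 0
for key in toy_state_dict:
    # block index is the second dotted component; depth = highest index + 1
    if "double_blocks" in key:
        double_blocks = max(double_blocks, int(key.split(".")[1]) + 1)
    elif "single_blocks" in key:
        single_blocks = max(single_blocks, int(key.split(".")[1]) + 1)

print(double_blocks, single_blocks)  # 19 38
```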
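Review note on the `src/math.py` hunk: `flash_attn_func` consumes and returns tensors in [B, L, H, D] layout, while `scaled_dot_product_attention` works in [B, H, L, D], so both branches must end in the flattened [B, L, H*D] shape that callers of `attention()` expect. A quick shape sanity check that runs without flash-attn installed (random tensors, SDPA only; the flash output layout is simulated from its documented interface):

```python
import torch
from einops import rearrange

B, H, L, D = 2, 4, 16, 32
q = k = v = torch.randn(B, H, L, D)

# SDPA branch: attend in [B, H, L, D], then flatten heads.
x_sdpa = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x_sdpa = rearrange(x_sdpa, "B H L D -> B L (H D)")

# Flash branch (simulated): flash_attn_func returns [B, L, H, D],
# which likewise has to be flattened to [B, L, (H D)] before returning.
x_flash_like = rearrange(q, "B H L D -> B L H D")  # stand-in for the flash output
x_flash_like = rearrange(x_flash_like, "B L H D -> B L (H D)")

assert x_sdpa.shape == x_flash_like.shape == (B, L, H * D)
```

One caveat: flash-attn 2 only accepts fp16/bf16 CUDA tensors, which the `q.is_cuda` guard alone does not check.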
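Review note on the `stable_diffusion_model.py` hunk: with a typical 4-entry `block_out_channels` VAE config the effective bucket divisibility goes from 16 to 32 (and the no-VAE fallback from 8 to 16); the trailing `* 2` is flagged `# todo remove this`, so it reads as a temporary workaround. The arithmetic, with a hypothetical config:

```python
# Hypothetical Flux-style VAE config with 4 downsampling stages.
block_out_channels = [128, 256, 512, 512]

divisibility = 2 ** (len(block_out_channels) - 1)  # 8: the VAE downsamples 8x
divisibility *= 2   # 16: flux packs 2x2 latent patches
divisibility *= 2   # 32: the new trailing "* 2" marked "todo remove this"
print(divisibility)  # 32
```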