Allow fine-tuning pruned versions of Chroma. Allow Flash Attention 2 for Chroma if it is installed.

Jaret Burkett
2025-05-21 07:02:50 -06:00
parent 48e11cf843
commit 79499fa795
3 changed files with 43 additions and 6 deletions


@@ -141,19 +141,38 @@ class ChromaModel(BaseModel):
        extras_path = 'ostris/Flex.1-alpha'
        self.print_and_status_update("Loading transformer")
        chroma_state_dict = load_file(model_path, 'cpu')
        # determine the number of double and single blocks from the state dict keys
        double_blocks = 0
        single_blocks = 0
        for key in chroma_state_dict.keys():
            if "double_blocks" in key:
                block_num = int(key.split(".")[1]) + 1
                if block_num > double_blocks:
                    double_blocks = block_num
            elif "single_blocks" in key:
                block_num = int(key.split(".")[1]) + 1
                if block_num > single_blocks:
                    single_blocks = block_num
        print(f"Double Blocks: {double_blocks}")
        print(f"Single Blocks: {single_blocks}")
        chroma_params.depth = double_blocks
        chroma_params.depth_single_blocks = single_blocks
        transformer = Chroma(chroma_params)
        # add dtype, not sure why it doesn't have it
        transformer.dtype = dtype
        # load the state dict into the model (already loaded above to count blocks)
        transformer.load_state_dict(chroma_state_dict)
        transformer.to(self.quantize_device, dtype=dtype)
        transformer.config = FakeConfig()
        transformer.config.num_layers = double_blocks
        transformer.config.num_single_layers = single_blocks
        if self.model_config.quantize:
            # patch the state dict method

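A note on the block-count inference above: it depends only on the key names in the checkpoint, so a pruned Chroma model with fewer transformer blocks is picked up automatically. A minimal, self-contained sketch of the same parsing logic, using a made-up set of checkpoint keys (the key layout mirrors the double_blocks.N.* / single_blocks.N.* convention the loader parses; the specific keys below are illustrative, not from a real checkpoint):

# Sketch: infer block counts from checkpoint key names, the same way the loader does.
# The keys below stand in for a hypothetical pruned checkpoint, not real weights.
fake_keys = [
    "double_blocks.0.img_attn.qkv.weight",
    "double_blocks.7.img_attn.qkv.weight",   # highest double-block index is 7
    "single_blocks.0.linear1.weight",
    "single_blocks.25.linear1.weight",       # highest single-block index is 25
    "final_layer.linear.weight",
]

double_blocks = 0
single_blocks = 0
for key in fake_keys:
    if "double_blocks" in key:
        double_blocks = max(double_blocks, int(key.split(".")[1]) + 1)
    elif "single_blocks" in key:
        single_blocks = max(single_blocks, int(key.split(".")[1]) + 1)

print(double_blocks, single_blocks)  # -> 8 26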

@@ -2,14 +2,32 @@ import torch
from einops import rearrange
from torch import Tensor

# Flash-Attention 2 (optional)
try:
    from flash_attn.flash_attn_interface import flash_attn_func  # type: ignore
    _HAS_FLASH = True
except (ImportError, ModuleNotFoundError):
    _HAS_FLASH = False


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask: Tensor) -> Tensor:
    q, k = apply_rope(q, k, pe)

    if _HAS_FLASH and mask is None and q.is_cuda:
        # flash_attn_func expects [B, L, H, D] inputs
        x = flash_attn_func(
            rearrange(q, "B H L D -> B L H D").contiguous(),
            rearrange(k, "B H L D -> B L H D").contiguous(),
            rearrange(v, "B H L D -> B L H D").contiguous(),
            dropout_p=0.0,
            softmax_scale=None,
            causal=False,
        )
        x = rearrange(x, "B L H D -> B H L D")
    else:
        # mask should have shape [B, H, L, L]
        x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)

    x = rearrange(x, "B H L D -> B L (H D)")
    return x

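The guarded import keeps Flash Attention 2 strictly optional: flash_attn_func is only used when the package is installed, no attention mask is supplied, and the tensors are on CUDA; otherwise the original SDPA path runs. The rearranges exist because the two backends expect different layouts. A small runnable sketch of just the shape handling, using SDPA with toy shapes (the shapes are illustrative; flash_attn_func would take the [B, L, H, D] layout instead):

import torch
from einops import rearrange

# Toy shapes, purely illustrative.
B, H, L, D = 2, 4, 16, 64
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)

# SDPA consumes [B, H, L, D] directly; flash_attn_func would instead want
# [B, L, H, D], which is what the rearranges in the patch convert to and from.
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)

# Both branches end in [B, H, L, D] before the heads are merged.
x = rearrange(x, "B H L D -> B L (H D)")
print(x.shape)  # torch.Size([2, 16, 256])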

@@ -253,13 +253,13 @@ class StableDiffusion:
    def get_bucket_divisibility(self):
        if self.vae is None:
            return 16
        divisibility = 2 ** (len(self.vae.config['block_out_channels']) - 1)
        # flux packs this again
        if self.is_flux or self.is_v3:
            divisibility = divisibility * 2
        return divisibility * 2  # todo remove this

    def load_model(self):
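Worked example of the divisibility arithmetic, assuming a VAE config with four block_out_channels entries (the config values and the standalone flag are assumptions for illustration, not taken from the repo):

# Assumed example VAE config with four resolution levels.
block_out_channels = [128, 256, 512, 512]

divisibility = 2 ** (len(block_out_channels) - 1)   # 2**3 = 8 (VAE downscale factor)

is_flux = True  # hypothetical flag standing in for self.is_flux / self.is_v3
if is_flux:
    divisibility = divisibility * 2                 # flux packs 2x2 latent patches -> 16

print(divisibility * 2)                             # value returned by the patched method -> 32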