Mirror of https://github.com/ostris/ai-toolkit.git
Allow fine-tuning pruned versions of Chroma. Allow Flash Attention 2 for Chroma if it is installed.
@@ -141,19 +141,38 @@ class ChromaModel(BaseModel):
         extras_path = 'ostris/Flex.1-alpha'

         self.print_and_status_update("Loading transformer")

+        chroma_state_dict = load_file(model_path, 'cpu')
+
+        # determine number of double and single blocks
+        double_blocks = 0
+        single_blocks = 0
+        for key in chroma_state_dict.keys():
+            if "double_blocks" in key:
+                block_num = int(key.split(".")[1]) + 1
+                if block_num > double_blocks:
+                    double_blocks = block_num
+            elif "single_blocks" in key:
+                block_num = int(key.split(".")[1]) + 1
+                if block_num > single_blocks:
+                    single_blocks = block_num
+        print(f"Double Blocks: {double_blocks}")
+        print(f"Single Blocks: {single_blocks}")
+
+        chroma_params.depth = double_blocks
+        chroma_params.depth_single_blocks = single_blocks
         transformer = Chroma(chroma_params)

         # add dtype, not sure why it doesnt have it
         transformer.dtype = dtype
-
-        chroma_state_dict = load_file(model_path, 'cpu')
         # load the state dict into the model
         transformer.load_state_dict(chroma_state_dict)

         transformer.to(self.quantize_device, dtype=dtype)

         transformer.config = FakeConfig()
+        transformer.config.num_layers = double_blocks
+        transformer.config.num_single_layers = single_blocks

         if self.model_config.quantize:
             # patch the state dict method
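The block-counting loop is what makes pruned checkpoints loadable: instead of assuming a fixed depth, the transformer is sized from the highest block index found in the checkpoint's state dict. As an illustration only (the key list below is fabricated and far shorter than a real Chroma state dict), the same inference can be sketched in isolation:

def count_blocks(keys):
    # highest block index + 1 per block family, mirroring the loop in the hunk above
    double_blocks = 0
    single_blocks = 0
    for key in keys:
        if "double_blocks" in key:
            double_blocks = max(double_blocks, int(key.split(".")[1]) + 1)
        elif "single_blocks" in key:
            single_blocks = max(single_blocks, int(key.split(".")[1]) + 1)
    return double_blocks, single_blocks

fake_keys = [
    "double_blocks.0.img_attn.qkv.weight",
    "double_blocks.11.img_attn.qkv.weight",
    "single_blocks.0.linear1.weight",
    "single_blocks.25.linear1.weight",
]
print(count_blocks(fake_keys))  # (12, 26) -> a pruned model with 12 double and 26 single blocks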
@@ -2,14 +2,32 @@ import torch
 from einops import rearrange
 from torch import Tensor

+# Flash-Attention 2 (optional)
+try:
+    from flash_attn.flash_attn_interface import flash_attn_func # type: ignore
+    _HAS_FLASH = True
+except (ImportError, ModuleNotFoundError):
+    _HAS_FLASH = False
+

 def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask: Tensor) -> Tensor:
     q, k = apply_rope(q, k, pe)

-    # mask should have shape [B, H, L, D]
-    x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
-    x = rearrange(x, "B H L D -> B L (H D)")
+    if _HAS_FLASH and mask is None and q.is_cuda:
+        x = flash_attn_func(
+            rearrange(q, "B H L D -> B L H D").contiguous(),
+            rearrange(k, "B H L D -> B L H D").contiguous(),
+            rearrange(v, "B H L D -> B L H D").contiguous(),
+            dropout_p=0.0,
+            softmax_scale=None,
+            causal=False,
+        )
+        x = rearrange(x, "B L H D -> B H L D")
+    else:
+        x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
+
+    x = rearrange(x, "B H L D -> B L (H D)")
     return x
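The try/except makes Flash Attention 2 strictly optional: if flash_attn is not installed, an attention mask is supplied, or the tensors are not on CUDA, the code falls back to PyTorch's scaled_dot_product_attention. The rearranges bridge a layout difference: scaled_dot_product_attention works in [B, H, L, D] while flash_attn_func expects [B, L, H, D]. A quick shape check of the fallback path and the layout round trip, using only torch and einops (the tensor sizes are arbitrary, chosen just for the check):

import torch
from einops import rearrange

B, H, L, D = 2, 4, 16, 32  # arbitrary sizes for the check
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))

# fallback path: SDPA in [B, H, L, D], then merge heads into the hidden dimension
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None)
x = rearrange(x, "B H L D -> B L (H D)")
print(x.shape)  # torch.Size([2, 16, 128])

# layout round trip used around flash_attn_func: [B, H, L, D] <-> [B, L, H, D]
q_flash = rearrange(q, "B H L D -> B L H D").contiguous()
assert torch.equal(rearrange(q_flash, "B L H D -> B H L D"), q)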
@@ -253,13 +253,13 @@ class StableDiffusion:

     def get_bucket_divisibility(self):
         if self.vae is None:
-            return 8
+            return 16
         divisibility = 2 ** (len(self.vae.config['block_out_channels']) - 1)

         # flux packs this again,
         if self.is_flux or self.is_v3:
             divisibility = divisibility * 2
-        return divisibility
+        return divisibility * 2 # todo remove this


     def load_model(self):
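For context on the numbers in this hunk: divisibility is derived from the VAE depth, so a typical 4-level VAE gives 2 ** (4 - 1) = 8, flux-style 2x2 latent packing doubles that to 16, and the new trailing "* 2" (flagged "todo remove this") doubles it again. A small arithmetic sketch with a hypothetical block_out_channels list:

block_out_channels = [128, 256, 512, 512]  # hypothetical 4-level VAE config

divisibility = 2 ** (len(block_out_channels) - 1)  # 8: each extra level halves the latent once
is_flux = True
if is_flux:
    divisibility = divisibility * 2  # 16: flux packs 2x2 latent patches
print(divisibility * 2)  # 32, with the extra "* 2" this commit adds (and the todo to remove it)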