mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 02:31:17 +00:00
Add support for FLUX.2 klein base models
This commit is contained in:
@@ -17,6 +17,35 @@ class Flux2Params:
|
||||
axes_dim: list[int] = field(default_factory=lambda: [32, 32, 32, 32])
|
||||
theta: int = 2000
|
||||
mlp_ratio: float = 3.0
|
||||
use_guidance_embed: bool = True
|
||||
|
||||
|
||||
@dataclass
class Klein9BParams:
    """Default hyperparameters for the FLUX.2 klein "9B" model variant.

    Exposes the same attribute names that ``Flux2`` reads from its ``params``
    object (e.g. ``context_in_dim`` and ``use_guidance_embed``), so an
    instance can be passed wherever a params object is expected.
    NOTE(review): "9B" is taken from the class name and presumably refers to
    the parameter count -- confirm against the upstream model card.
    """

    # Channel count of the model input.
    in_channels: int = 128
    # Input width of the text-context projection (``Flux2.txt_in``).
    context_in_dim: int = 12288
    # Transformer hidden width.
    hidden_size: int = 4096
    # Attention head count; 4096 / 32 gives 128 per head.
    num_heads: int = 32
    # Number of double-stream blocks.
    depth: int = 8
    # Number of single-stream blocks.
    depth_single_blocks: int = 24
    # Per-axis positional-embedding dims. Built through a factory so each
    # instance owns its own list rather than sharing one mutable default.
    axes_dim: list[int] = field(default_factory=lambda: [32, 32, 32, 32])
    # Positional-embedding base -- presumably the RoPE theta; TODO confirm.
    theta: int = 2000
    # Presumably the MLP width multiplier relative to hidden_size; TODO confirm.
    mlp_ratio: float = 3.0
    # Disabled here: ``Flux2`` only builds ``guidance_in`` when this is True.
    use_guidance_embed: bool = False
|
||||
|
||||
|
||||
@dataclass
class Klein4BParams:
    """Default hyperparameters for the FLUX.2 klein "4B" model variant.

    Exposes the same attribute names that ``Flux2`` reads from its ``params``
    object (e.g. ``context_in_dim`` and ``use_guidance_embed``), so an
    instance can be passed wherever a params object is expected.
    NOTE(review): "4B" is taken from the class name and presumably refers to
    the parameter count -- confirm against the upstream model card.
    """

    # Channel count of the model input.
    in_channels: int = 128
    # Input width of the text-context projection (``Flux2.txt_in``).
    context_in_dim: int = 7680
    # Transformer hidden width.
    hidden_size: int = 3072
    # Attention head count; 3072 / 24 gives 128 per head.
    num_heads: int = 24
    # Number of double-stream blocks.
    depth: int = 5
    # Number of single-stream blocks.
    depth_single_blocks: int = 20
    # Per-axis positional-embedding dims. Built through a factory so each
    # instance owns its own list rather than sharing one mutable default.
    axes_dim: list[int] = field(default_factory=lambda: [32, 32, 32, 32])
    # Positional-embedding base -- presumably the RoPE theta; TODO confirm.
    theta: int = 2000
    # Presumably the MLP width multiplier relative to hidden_size; TODO confirm.
    mlp_ratio: float = 3.0
    # Disabled here: ``Flux2`` only builds ``guidance_in`` when this is True.
    use_guidance_embed: bool = False
|
||||
|
||||
|
||||
class FakeConfig:
|
||||
@@ -50,11 +79,14 @@ class Flux2(nn.Module):
|
||||
self.time_in = MLPEmbedder(
|
||||
in_dim=256, hidden_dim=self.hidden_size, disable_bias=True
|
||||
)
|
||||
self.guidance_in = MLPEmbedder(
|
||||
in_dim=256, hidden_dim=self.hidden_size, disable_bias=True
|
||||
)
|
||||
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size, bias=False)
|
||||
|
||||
self.use_guidance_embed = params.use_guidance_embed
|
||||
if self.use_guidance_embed:
|
||||
self.guidance_in = MLPEmbedder(
|
||||
in_dim=256, hidden_dim=self.hidden_size, disable_bias=True
|
||||
)
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
@@ -116,14 +148,15 @@ class Flux2(nn.Module):
|
||||
timesteps: Tensor,
|
||||
ctx: Tensor,
|
||||
ctx_ids: Tensor,
|
||||
guidance: Tensor,
|
||||
guidance: Tensor | None,
|
||||
):
|
||||
num_txt_tokens = ctx.shape[1]
|
||||
|
||||
timestep_emb = timestep_embedding(timesteps, 256)
|
||||
vec = self.time_in(timestep_emb)
|
||||
guidance_emb = timestep_embedding(guidance, 256)
|
||||
vec = vec + self.guidance_in(guidance_emb)
|
||||
if self.use_guidance_embed:
|
||||
guidance_emb = timestep_embedding(guidance, 256)
|
||||
vec = vec + self.guidance_in(guidance_emb)
|
||||
|
||||
double_block_mod_img = self.double_stream_modulation_img(vec)
|
||||
double_block_mod_txt = self.double_stream_modulation_txt(vec)
|
||||
|
||||
Reference in New Issue
Block a user