diff --git a/backend/nn/flux.py b/backend/nn/flux.py index ae5144f6..3423ca95 100644 --- a/backend/nn/flux.py +++ b/backend/nn/flux.py @@ -150,10 +150,7 @@ class Modulation(nn.Module): def forward(self, vec): out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) - if self.is_double: - return out[0], out[1], out[2], out[3], out[4], out[5], - else: - return out[0], out[1], out[2] + return out class DoubleStreamBlock(nn.Module): @@ -295,6 +292,7 @@ class IntegratedFluxTransformer2DModel(nn.Module): def __init__(self, in_channels: int, vec_in_dim: int, context_in_dim: int, hidden_size: int, mlp_ratio: float, num_heads: int, depth: int, depth_single_blocks: int, axes_dim: list[int], theta: int, qkv_bias: bool, guidance_embed: bool): super().__init__() + self.guidance_embed = guidance_embed self.in_channels = in_channels * 4 self.out_channels = self.in_channels @@ -341,7 +339,7 @@ class IntegratedFluxTransformer2DModel(nn.Module): raise ValueError("Input img and txt tensors must have 3 dimensions.") img = self.img_in(img) vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype)) - if self.params.guidance_embed: + if self.guidance_embed: if guidance is None: raise ValueError("Didn't get guidance strength for guidance distilled model.") vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))