This commit is contained in:
layerdiffusion
2024-08-08 16:28:31 -07:00
parent 26b7fea8a1
commit ea65ad6763

View File

@@ -150,10 +150,7 @@ class Modulation(nn.Module):
     def forward(self, vec):
         out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
-        if self.is_double:
-            return out[0], out[1], out[2], out[3], out[4], out[5],
-        else:
-            return out[0], out[1], out[2]
+        return out
 class DoubleStreamBlock(nn.Module):
@@ -295,6 +292,7 @@ class IntegratedFluxTransformer2DModel(nn.Module):
     def __init__(self, in_channels: int, vec_in_dim: int, context_in_dim: int, hidden_size: int, mlp_ratio: float, num_heads: int, depth: int, depth_single_blocks: int, axes_dim: list[int], theta: int, qkv_bias: bool, guidance_embed: bool):
         super().__init__()
+        self.guidance_embed = guidance_embed
         self.in_channels = in_channels * 4
         self.out_channels = self.in_channels
@@ -341,7 +339,7 @@ class IntegratedFluxTransformer2DModel(nn.Module):
             raise ValueError("Input img and txt tensors must have 3 dimensions.")
         img = self.img_in(img)
         vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
-        if self.params.guidance_embed:
+        if self.guidance_embed:
             if guidance is None:
                 raise ValueError("Didn't get guidance strength for guidance distilled model.")
             vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))