diff --git a/backend/nn/flux.py b/backend/nn/flux.py index 942d9c33..097e6622 100644 --- a/backend/nn/flux.py +++ b/backend/nn/flux.py @@ -372,7 +372,7 @@ class IntegratedFluxTransformer2DModel(nn.Module): del vec return img - def forward(self, x, timestep, context, y, guidance, **kwargs): + def forward(self, x, timestep, context, y, guidance=None, **kwargs): bs, c, h, w = x.shape input_device = x.device input_dtype = x.dtype