diff --git a/modules_forge/patch_basic.py b/modules_forge/patch_basic.py index d4bff6ed..504c9ddb 100644 --- a/modules_forge/patch_basic.py +++ b/modules_forge/patch_basic.py @@ -130,7 +130,6 @@ def patched_calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_op timestep_ = torch.cat([timestep] * batch_chunks) if control is not None: - control.current_cond_or_uncond = cond_or_uncond c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond)) transformer_options = {} @@ -153,6 +152,9 @@ def patched_calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_op c['transformer_options'] = transformer_options + if control is not None: + control.transformer_options = transformer_options + if 'model_function_wrapper' in model_options: output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) else: