Update controlnet.py

This commit is contained in:
lllyasviel
2024-01-29 18:23:21 -08:00
parent abb96b822c
commit b611a58edf

View File

@@ -459,15 +459,34 @@ class ControlNetForForgeOfficial(scripts.Script):
params.preprocessor = preprocessor
model_filename = global_state.get_controlnet_filename(unit.model)
controlnet_model = cached_controlnet_loader(model_filename)
params.model = cached_controlnet_loader(model_filename)
controlnet_model.strength = float(unit.weight)
controlnet_model.start_percent = float(unit.guidance_start)
controlnet_model.end_percent = float(unit.guidance_end)
controlnet_model.positive_advanced_weighting = None
controlnet_model.negative_advanced_weighting = None
controlnet_model.advanced_frame_weighting = None
controlnet_model.advanced_sigma_weighting = None
logger.info(f"Current ControlNet: {model_filename}")
return
def process_unit_before_every_sampling(self,
p: StableDiffusionProcessing,
unit: external_code.ControlNetUnit,
params: ControlNetCachedParameters,
*args, **kwargs):
is_hr_pass = getattr(p, 'is_hr_pass', False)
if is_hr_pass:
cond = params.control_cond_for_hr_fix
else:
cond = params.control_cond
kwargs.update(dict(unit=unit, params=params))
params.model.strength = float(unit.weight)
params.model.start_percent = float(unit.guidance_start)
params.model.end_percent = float(unit.guidance_end)
params.model.positive_advanced_weighting = None
params.model.negative_advanced_weighting = None
params.model.advanced_frame_weighting = None
params.model.advanced_sigma_weighting = None
soft_weighting = {
'input': [0.09941396206337118, 0.12050177219802567, 0.14606275417942507, 0.17704576264172736,
@@ -488,30 +507,13 @@ class ControlNetForForgeOfficial(scripts.Script):
}
if unit.control_mode == external_code.ControlMode.CONTROL.value:
controlnet_model.positive_advanced_weighting = soft_weighting.copy()
controlnet_model.negative_advanced_weighting = zero_weighting.copy()
params.model.positive_advanced_weighting = soft_weighting.copy()
params.model.negative_advanced_weighting = zero_weighting.copy()
if unit.control_mode == external_code.ControlMode.PROMPT.value:
controlnet_model.positive_advanced_weighting = soft_weighting.copy()
controlnet_model.negative_advanced_weighting = soft_weighting.copy()
params.model = controlnet_model
logger.info(f"Current ControlNet: {model_filename}")
return
def process_unit_before_every_sampling(self,
p: StableDiffusionProcessing,
unit: external_code.ControlNetUnit,
params: ControlNetCachedParameters,
*args, **kwargs):
if getattr(p, 'is_hr_pass', False):
cond = params.control_cond_for_hr_fix
else:
cond = params.control_cond
kwargs.update(dict(unit=unit, params=params))
# high-ref fix pass always use softer injections
if is_hr_pass or unit.control_mode == external_code.ControlMode.PROMPT.value:
params.model.positive_advanced_weighting = soft_weighting.copy()
params.model.negative_advanced_weighting = soft_weighting.copy()
params.preprocessor.process_before_every_sampling(process=p, cond=cond, **kwargs)
params.model.process_before_every_sampling(process=p, cond=cond, **kwargs)