diff --git a/backend/operations.py b/backend/operations.py index 16288e3e..250285be 100644 --- a/backend/operations.py +++ b/backend/operations.py @@ -1,7 +1,8 @@ +import time import torch import contextlib -from backend import stream +from backend import stream, memory_management stash = {} @@ -304,3 +305,44 @@ def shift_manual_cast(model, enabled): if hasattr(m, 'parameters_manual_cast'): m.parameters_manual_cast = enabled return + + +@contextlib.contextmanager +def automatic_memory_management(): + memory_management.free_memory( + memory_required=3 * 1024 * 1024 * 1024, + device=memory_management.get_torch_device() + ) + + module_list = [] + + original_init = torch.nn.Module.__init__ + original_to = torch.nn.Module.to + + def patched_init(self, *args, **kwargs): + module_list.append(self) + return original_init(self, *args, **kwargs) + + def patched_to(self, *args, **kwargs): + module_list.append(self) + return original_to(self, *args, **kwargs) + + try: + torch.nn.Module.__init__ = patched_init + torch.nn.Module.to = patched_to + yield + finally: + torch.nn.Module.__init__ = original_init + torch.nn.Module.to = original_to + + start = time.perf_counter() + module_list = set(module_list) + + for module in module_list: + module.cpu() + + memory_management.soft_empty_cache() + end = time.perf_counter() + + print(f'Automatic Memory Management: {len(module_list)} Modules in {(end - start):.2f} seconds.') + return diff --git a/backend/patcher/controlnet.py b/backend/patcher/controlnet.py index 7ea5721b..ebbbf4fe 100644 --- a/backend/patcher/controlnet.py +++ b/backend/patcher/controlnet.py @@ -8,6 +8,81 @@ from backend.patcher.base import ModelPatcher from backend.operations import using_forge_operations, ForgeOperations, main_stream_worker, weights_manual_cast +def apply_controlnet_advanced( + unet, + controlnet, + image_bchw, + strength, + start_percent, + end_percent, + positive_advanced_weighting=None, + negative_advanced_weighting=None, + advanced_frame_weighting=None, + advanced_sigma_weighting=None, + advanced_mask_weighting=None +): + """ + + # positive_advanced_weighting or negative_advanced_weighting + + Unet has input, middle, output blocks, and we can give different weights to each layers in all blocks. + Below is an example for stronger control in middle block. + This is helpful for some high-res fix passes. + + positive_advanced_weighting = { + 'input': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2], + 'middle': [1.0], + 'output': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2] + } + negative_advanced_weighting = { + 'input': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2], + 'middle': [1.0], + 'output': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2] + } + + # advanced_frame_weighting + + The advanced_frame_weighting is a weight applied to each image in a batch. + The length of this list must be same with batch size + For example, if batch size is 5, you can use advanced_frame_weighting = [0, 0.25, 0.5, 0.75, 1.0] + If you view the 5 images as 5 frames in a video, this will lead to progressively stronger control over time. + + # advanced_sigma_weighting + + The advanced_sigma_weighting allows you to dynamically compute control + weights given diffusion timestep (sigma). + For example below code can softly make beginning steps stronger than ending steps. + + sigma_max = unet.model.model_sampling.sigma_max + sigma_min = unet.model.model_sampling.sigma_min + advanced_sigma_weighting = lambda s: (s - sigma_min) / (sigma_max - sigma_min) + + # advanced_mask_weighting + + A mask can be applied to control signals. + This should be a tensor with shape B 1 H W where the H and W can be arbitrary. + This mask will be resized automatically to match the shape of all injection layers. + + """ + + cnet = controlnet.copy().set_cond_hint(image_bchw, strength, (start_percent, end_percent)) + cnet.positive_advanced_weighting = positive_advanced_weighting + cnet.negative_advanced_weighting = negative_advanced_weighting + cnet.advanced_frame_weighting = advanced_frame_weighting + cnet.advanced_sigma_weighting = advanced_sigma_weighting + + if advanced_mask_weighting is not None: + assert isinstance(advanced_mask_weighting, torch.Tensor) + B, C, H, W = advanced_mask_weighting.shape + assert B > 0 and C == 1 and H > 0 and W > 0 + + cnet.advanced_mask_weighting = advanced_mask_weighting + + m = unet.clone() + m.add_patched_controlnet(cnet) + return m + + def compute_controlnet_weighting(control, cnet): positive_advanced_weighting = getattr(cnet, 'positive_advanced_weighting', None) negative_advanced_weighting = getattr(cnet, 'negative_advanced_weighting', None) diff --git a/extensions-builtin/forge_legacy_preprocessors/scripts/legacy_preprocessors.py b/extensions-builtin/forge_legacy_preprocessors/scripts/legacy_preprocessors.py index b90fb61f..74164cf9 100644 --- a/extensions-builtin/forge_legacy_preprocessors/scripts/legacy_preprocessors.py +++ b/extensions-builtin/forge_legacy_preprocessors/scripts/legacy_preprocessors.py @@ -13,7 +13,7 @@ import contextlib from annotator.util import HWC3 -from modules_forge.ops import automatic_memory_management +from backend.operations import automatic_memory_management from legacy_preprocessors.preprocessor_compiled import legacy_preprocessors from modules_forge.supported_preprocessor import Preprocessor, PreprocessorParameter from modules_forge.shared import add_supported_preprocessor diff --git a/modules_forge/controlnet.py b/modules_forge/controlnet.py deleted file mode 100644 index e8585c53..00000000 --- a/modules_forge/controlnet.py +++ /dev/null @@ -1,77 +0,0 @@ -import torch - - -def apply_controlnet_advanced( - unet, - controlnet, - image_bchw, - strength, - start_percent, - end_percent, - positive_advanced_weighting=None, - negative_advanced_weighting=None, - advanced_frame_weighting=None, - advanced_sigma_weighting=None, - advanced_mask_weighting=None -): - """ - - # positive_advanced_weighting or negative_advanced_weighting - - Unet has input, middle, output blocks, and we can give different weights to each layers in all blocks. - Below is an example for stronger control in middle block. - This is helpful for some high-res fix passes. - - positive_advanced_weighting = { - 'input': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2], - 'middle': [1.0], - 'output': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2] - } - negative_advanced_weighting = { - 'input': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2], - 'middle': [1.0], - 'output': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2] - } - - # advanced_frame_weighting - - The advanced_frame_weighting is a weight applied to each image in a batch. - The length of this list must be same with batch size - For example, if batch size is 5, you can use advanced_frame_weighting = [0, 0.25, 0.5, 0.75, 1.0] - If you view the 5 images as 5 frames in a video, this will lead to progressively stronger control over time. - - # advanced_sigma_weighting - - The advanced_sigma_weighting allows you to dynamically compute control - weights given diffusion timestep (sigma). - For example below code can softly make beginning steps stronger than ending steps. - - sigma_max = unet.model.model_sampling.sigma_max - sigma_min = unet.model.model_sampling.sigma_min - advanced_sigma_weighting = lambda s: (s - sigma_min) / (sigma_max - sigma_min) - - # advanced_mask_weighting - - A mask can be applied to control signals. - This should be a tensor with shape B 1 H W where the H and W can be arbitrary. - This mask will be resized automatically to match the shape of all injection layers. - - """ - - cnet = controlnet.copy().set_cond_hint(image_bchw, strength, (start_percent, end_percent)) - cnet.positive_advanced_weighting = positive_advanced_weighting - cnet.negative_advanced_weighting = negative_advanced_weighting - cnet.advanced_frame_weighting = advanced_frame_weighting - cnet.advanced_sigma_weighting = advanced_sigma_weighting - - if advanced_mask_weighting is not None: - assert isinstance(advanced_mask_weighting, torch.Tensor) - B, C, H, W = advanced_mask_weighting.shape - assert B > 0 and C == 1 and H > 0 and W > 0 - - cnet.advanced_mask_weighting = advanced_mask_weighting - - m = unet.clone() - m.add_patched_controlnet(cnet) - return m - diff --git a/modules_forge/ops.py b/modules_forge/ops.py deleted file mode 100644 index aac2f259..00000000 --- a/modules_forge/ops.py +++ /dev/null @@ -1,46 +0,0 @@ -import time -import torch -import contextlib - -from backend import memory_management - - -@contextlib.contextmanager -def automatic_memory_management(): - memory_management.free_memory( - memory_required=3 * 1024 * 1024 * 1024, - device=memory_management.get_torch_device() - ) - - module_list = [] - - original_init = torch.nn.Module.__init__ - original_to = torch.nn.Module.to - - def patched_init(self, *args, **kwargs): - module_list.append(self) - return original_init(self, *args, **kwargs) - - def patched_to(self, *args, **kwargs): - module_list.append(self) - return original_to(self, *args, **kwargs) - - try: - torch.nn.Module.__init__ = patched_init - torch.nn.Module.to = patched_to - yield - finally: - torch.nn.Module.__init__ = original_init - torch.nn.Module.to = original_to - - start = time.perf_counter() - module_list = set(module_list) - - for module in module_list: - module.cpu() - - memory_management.soft_empty_cache() - end = time.perf_counter() - - print(f'Automatic Memory Management: {len(module_list)} Modules in {(end - start):.2f} seconds.') - return diff --git a/modules_forge/supported_controlnet.py b/modules_forge/supported_controlnet.py index 4ccea55e..1d199934 100644 --- a/modules_forge/supported_controlnet.py +++ b/modules_forge/supported_controlnet.py @@ -6,8 +6,7 @@ from huggingface_guess.utils import unet_to_diffusers from backend import memory_management from backend.operations import using_forge_operations from backend.nn.cnets import cldm -from backend.patcher.controlnet import ControlLora, ControlNet, load_t2i_adapter -from modules_forge.controlnet import apply_controlnet_advanced +from backend.patcher.controlnet import ControlLora, ControlNet, load_t2i_adapter, apply_controlnet_advanced from modules_forge.shared import add_supported_control_model