revise structure

This commit is contained in:
layerdiffusion
2024-08-07 20:44:34 -07:00
parent 015587ca59
commit a91a81d8e6
6 changed files with 120 additions and 127 deletions

View File

@@ -1,7 +1,8 @@
import time
import torch
import contextlib
from backend import stream
from backend import stream, memory_management
stash = {}
@@ -304,3 +305,44 @@ def shift_manual_cast(model, enabled):
if hasattr(m, 'parameters_manual_cast'):
m.parameters_manual_cast = enabled
return
@contextlib.contextmanager
def automatic_memory_management():
    """Offload every torch module touched inside the block to CPU on exit.

    On entry, asks `memory_management` to free ~3 GiB on the torch device.
    While the context is active, ``torch.nn.Module.__init__`` and
    ``torch.nn.Module.to`` are monkey-patched so that any module that is
    constructed or moved gets recorded.  The patches are always restored in
    the ``finally`` clause, after which each recorded module (deduplicated)
    is moved to CPU, the allocator cache is emptied, and a timing summary
    is printed.

    NOTE(review): the patch is process-global, so concurrent module
    creation in other threads would also be captured — presumably this is
    only used from the single main thread.
    """
    memory_management.free_memory(
        memory_required=3 * 1024 * 1024 * 1024,
        device=memory_management.get_torch_device()
    )

    seen = []
    init_backup = torch.nn.Module.__init__
    to_backup = torch.nn.Module.to

    def tracking_init(module, *args, **kwargs):
        # Record the module, then defer to the real constructor.
        seen.append(module)
        return init_backup(module, *args, **kwargs)

    def tracking_to(module, *args, **kwargs):
        # Record the module, then defer to the real .to().
        seen.append(module)
        return to_backup(module, *args, **kwargs)

    try:
        torch.nn.Module.__init__ = tracking_init
        torch.nn.Module.to = tracking_to
        yield
    finally:
        # Restore the patches first so the cleanup below is not recorded.
        torch.nn.Module.__init__ = init_backup
        torch.nn.Module.to = to_backup

        t0 = time.perf_counter()
        unique_modules = set(seen)
        for module in unique_modules:
            module.cpu()
        memory_management.soft_empty_cache()
        t1 = time.perf_counter()
        print(f'Automatic Memory Management: {len(unique_modules)} Modules in {(t1 - t0):.2f} seconds.')

View File

@@ -8,6 +8,81 @@ from backend.patcher.base import ModelPatcher
from backend.operations import using_forge_operations, ForgeOperations, main_stream_worker, weights_manual_cast
def apply_controlnet_advanced(
        unet,
        controlnet,
        image_bchw,
        strength,
        start_percent,
        end_percent,
        positive_advanced_weighting=None,
        negative_advanced_weighting=None,
        advanced_frame_weighting=None,
        advanced_sigma_weighting=None,
        advanced_mask_weighting=None
):
    """Attach a ControlNet to a cloned UNet patcher with advanced weighting options.

    Returns a clone of `unet` with `controlnet` (copied, with its cond hint
    set from `image_bchw`, `strength` and the `(start_percent, end_percent)`
    range) registered as a patched controlnet.  The original `unet` and
    `controlnet` objects are not modified.

    # positive_advanced_weighting or negative_advanced_weighting

    Unet has input, middle, output blocks, and we can give different weights
    to each layers in all blocks.
    Below is an example for stronger control in middle block.
    This is helpful for some high-res fix passes.

        positive_advanced_weighting = {
            'input': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2],
            'middle': [1.0],
            'output': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2]
        }
        negative_advanced_weighting = {
            'input': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2],
            'middle': [1.0],
            'output': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2]
        }

    # advanced_frame_weighting

    The advanced_frame_weighting is a weight applied to each image in a batch.
    The length of this list must be same with batch size.
    For example, if batch size is 5, you can use
    advanced_frame_weighting = [0, 0.25, 0.5, 0.75, 1.0].
    If you view the 5 images as 5 frames in a video, this will lead to
    progressively stronger control over time.

    # advanced_sigma_weighting

    The advanced_sigma_weighting allows you to dynamically compute control
    weights given diffusion timestep (sigma).
    For example below code can softly make beginning steps stronger than
    ending steps.

        sigma_max = unet.model.model_sampling.sigma_max
        sigma_min = unet.model.model_sampling.sigma_min
        advanced_sigma_weighting = lambda s: (s - sigma_min) / (sigma_max - sigma_min)

    # advanced_mask_weighting

    A mask can be applied to control signals.
    This should be a tensor with shape (B, 1, H, W) where H and W can be
    arbitrary.  This mask will be resized automatically to match the shape
    of all injection layers.

    Raises:
        TypeError: if `advanced_mask_weighting` is given but is not a tensor.
        ValueError: if `advanced_mask_weighting` does not have a valid
            (B, 1, H, W) shape.
    """
    cnet = controlnet.copy().set_cond_hint(image_bchw, strength, (start_percent, end_percent))
    cnet.positive_advanced_weighting = positive_advanced_weighting
    cnet.negative_advanced_weighting = negative_advanced_weighting
    cnet.advanced_frame_weighting = advanced_frame_weighting
    cnet.advanced_sigma_weighting = advanced_sigma_weighting

    if advanced_mask_weighting is not None:
        # Validate with real exceptions instead of `assert`: asserts are
        # stripped under `python -O`, silently skipping this check.
        if not isinstance(advanced_mask_weighting, torch.Tensor):
            raise TypeError('advanced_mask_weighting must be a torch.Tensor of shape (B, 1, H, W).')
        if advanced_mask_weighting.dim() != 4:
            raise ValueError('advanced_mask_weighting must be 4-dimensional (B, 1, H, W).')
        B, C, H, W = advanced_mask_weighting.shape
        if not (B > 0 and C == 1 and H > 0 and W > 0):
            raise ValueError('advanced_mask_weighting must have shape (B, 1, H, W) with B, H, W > 0.')
        cnet.advanced_mask_weighting = advanced_mask_weighting

    m = unet.clone()
    m.add_patched_controlnet(cnet)
    return m
def compute_controlnet_weighting(control, cnet):
positive_advanced_weighting = getattr(cnet, 'positive_advanced_weighting', None)
negative_advanced_weighting = getattr(cnet, 'negative_advanced_weighting', None)

View File

@@ -13,7 +13,7 @@
import contextlib
from annotator.util import HWC3
from modules_forge.ops import automatic_memory_management
from backend.operations import automatic_memory_management
from legacy_preprocessors.preprocessor_compiled import legacy_preprocessors
from modules_forge.supported_preprocessor import Preprocessor, PreprocessorParameter
from modules_forge.shared import add_supported_preprocessor

View File

@@ -1,77 +0,0 @@
import torch
def apply_controlnet_advanced(
        unet,
        controlnet,
        image_bchw,
        strength,
        start_percent,
        end_percent,
        positive_advanced_weighting=None,
        negative_advanced_weighting=None,
        advanced_frame_weighting=None,
        advanced_sigma_weighting=None,
        advanced_mask_weighting=None
):
    """Attach a ControlNet to a cloned UNet patcher with advanced weighting options.

    Returns a clone of `unet` with `controlnet` (copied, with its cond hint
    set from `image_bchw`, `strength` and the `(start_percent, end_percent)`
    range) registered as a patched controlnet.  The original `unet` and
    `controlnet` objects are not modified.

    # positive_advanced_weighting or negative_advanced_weighting

    Unet has input, middle, output blocks, and we can give different weights
    to each layers in all blocks.
    Below is an example for stronger control in middle block.
    This is helpful for some high-res fix passes.

        positive_advanced_weighting = {
            'input': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2],
            'middle': [1.0],
            'output': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2]
        }
        negative_advanced_weighting = {
            'input': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2],
            'middle': [1.0],
            'output': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2]
        }

    # advanced_frame_weighting

    The advanced_frame_weighting is a weight applied to each image in a batch.
    The length of this list must be same with batch size.
    For example, if batch size is 5, you can use
    advanced_frame_weighting = [0, 0.25, 0.5, 0.75, 1.0].
    If you view the 5 images as 5 frames in a video, this will lead to
    progressively stronger control over time.

    # advanced_sigma_weighting

    The advanced_sigma_weighting allows you to dynamically compute control
    weights given diffusion timestep (sigma).
    For example below code can softly make beginning steps stronger than
    ending steps.

        sigma_max = unet.model.model_sampling.sigma_max
        sigma_min = unet.model.model_sampling.sigma_min
        advanced_sigma_weighting = lambda s: (s - sigma_min) / (sigma_max - sigma_min)

    # advanced_mask_weighting

    A mask can be applied to control signals.
    This should be a tensor with shape (B, 1, H, W) where H and W can be
    arbitrary.  This mask will be resized automatically to match the shape
    of all injection layers.

    Raises:
        TypeError: if `advanced_mask_weighting` is given but is not a tensor.
        ValueError: if `advanced_mask_weighting` does not have a valid
            (B, 1, H, W) shape.
    """
    cnet = controlnet.copy().set_cond_hint(image_bchw, strength, (start_percent, end_percent))
    cnet.positive_advanced_weighting = positive_advanced_weighting
    cnet.negative_advanced_weighting = negative_advanced_weighting
    cnet.advanced_frame_weighting = advanced_frame_weighting
    cnet.advanced_sigma_weighting = advanced_sigma_weighting

    if advanced_mask_weighting is not None:
        # Validate with real exceptions instead of `assert`: asserts are
        # stripped under `python -O`, silently skipping this check.
        if not isinstance(advanced_mask_weighting, torch.Tensor):
            raise TypeError('advanced_mask_weighting must be a torch.Tensor of shape (B, 1, H, W).')
        if advanced_mask_weighting.dim() != 4:
            raise ValueError('advanced_mask_weighting must be 4-dimensional (B, 1, H, W).')
        B, C, H, W = advanced_mask_weighting.shape
        if not (B > 0 and C == 1 and H > 0 and W > 0):
            raise ValueError('advanced_mask_weighting must have shape (B, 1, H, W) with B, H, W > 0.')
        cnet.advanced_mask_weighting = advanced_mask_weighting

    m = unet.clone()
    m.add_patched_controlnet(cnet)
    return m

View File

@@ -1,46 +0,0 @@
import time
import torch
import contextlib
from backend import memory_management
@contextlib.contextmanager
def automatic_memory_management():
    """Move every torch module created or `.to()`-ed inside the block to CPU on exit.

    First frees ~3 GiB on the current torch device via `memory_management`.
    Then ``torch.nn.Module.__init__`` and ``torch.nn.Module.to`` are
    temporarily replaced with wrappers that record each module they see.
    The originals are unconditionally restored in the ``finally`` clause,
    and every recorded module (deduplicated via a set) is moved to CPU
    before the allocator cache is emptied and a timing line is printed.
    """
    memory_management.free_memory(
        memory_required=3 * 1024 * 1024 * 1024,
        device=memory_management.get_torch_device()
    )

    recorded = []
    saved_init = torch.nn.Module.__init__
    saved_to = torch.nn.Module.to

    def recording_init(module, *args, **kwargs):
        # Note the module, then run the real constructor.
        recorded.append(module)
        return saved_init(module, *args, **kwargs)

    def recording_to(module, *args, **kwargs):
        # Note the module, then run the real .to().
        recorded.append(module)
        return saved_to(module, *args, **kwargs)

    try:
        torch.nn.Module.__init__ = recording_init
        torch.nn.Module.to = recording_to
        yield
    finally:
        # Undo the monkey-patches before touching the recorded modules so
        # the .cpu() calls below are not themselves recorded.
        torch.nn.Module.__init__ = saved_init
        torch.nn.Module.to = saved_to

        began = time.perf_counter()
        distinct = set(recorded)
        for module in distinct:
            module.cpu()
        memory_management.soft_empty_cache()
        ended = time.perf_counter()
        print(f'Automatic Memory Management: {len(distinct)} Modules in {(ended - began):.2f} seconds.')

View File

@@ -6,8 +6,7 @@ from huggingface_guess.utils import unet_to_diffusers
from backend import memory_management
from backend.operations import using_forge_operations
from backend.nn.cnets import cldm
from backend.patcher.controlnet import ControlLora, ControlNet, load_t2i_adapter
from modules_forge.controlnet import apply_controlnet_advanced
from backend.patcher.controlnet import ControlLora, ControlNet, load_t2i_adapter, apply_controlnet_advanced
from modules_forge.shared import add_supported_control_model