mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-05-02 12:11:34 +00:00
upload a cn
This commit is contained in:
@@ -0,0 +1,337 @@
|
||||
import unittest.mock
|
||||
import importlib
|
||||
from typing import Any
|
||||
|
||||
utils = importlib.import_module('extensions.sd-webui-controlnet.tests.utils', 'utils')
|
||||
|
||||
|
||||
from modules import processing, scripts, shared
|
||||
from scripts import controlnet, external_code, batch_hijack
|
||||
|
||||
|
||||
batch_hijack.instance.undo_hijack()
|
||||
original_process_images_inner = processing.process_images_inner
|
||||
|
||||
|
||||
class TestBatchHijack(unittest.TestCase):
|
||||
@unittest.mock.patch('modules.script_callbacks.on_script_unloaded')
|
||||
def setUp(self, on_script_unloaded_mock):
|
||||
self.on_script_unloaded_mock = on_script_unloaded_mock
|
||||
|
||||
self.batch_hijack_object = batch_hijack.BatchHijack()
|
||||
self.batch_hijack_object.do_hijack()
|
||||
|
||||
def tearDown(self):
|
||||
self.batch_hijack_object.undo_hijack()
|
||||
|
||||
def test_do_hijack__registers_on_script_unloaded(self):
|
||||
self.on_script_unloaded_mock.assert_called_once_with(self.batch_hijack_object.undo_hijack)
|
||||
|
||||
def test_do_hijack__call_once__hijacks_once(self):
|
||||
self.assertEqual(getattr(processing, '__controlnet_original_process_images_inner'), original_process_images_inner)
|
||||
self.assertEqual(processing.process_images_inner, self.batch_hijack_object.processing_process_images_hijack)
|
||||
|
||||
@unittest.mock.patch('modules.processing.__controlnet_original_process_images_inner')
|
||||
def test_do_hijack__multiple_times__hijacks_once(self, process_images_inner_mock):
|
||||
self.batch_hijack_object.do_hijack()
|
||||
self.batch_hijack_object.do_hijack()
|
||||
self.batch_hijack_object.do_hijack()
|
||||
self.assertEqual(process_images_inner_mock, getattr(processing, '__controlnet_original_process_images_inner'))
|
||||
|
||||
|
||||
class TestGetControlNetBatchesWorks(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.p = unittest.mock.MagicMock()
|
||||
assert scripts.scripts_txt2img is not None
|
||||
self.p.scripts = scripts.scripts_txt2img
|
||||
self.cn_script = controlnet.Script()
|
||||
self.p.scripts.alwayson_scripts = [self.cn_script]
|
||||
self.p.script_args = []
|
||||
|
||||
def tearDown(self):
|
||||
batch_hijack.instance.dispatch_callbacks(batch_hijack.instance.postprocess_batch_callbacks, self.p)
|
||||
|
||||
def assert_get_cn_batches_works(self, batch_images_list):
|
||||
self.cn_script.args_from = 0
|
||||
self.cn_script.args_to = self.cn_script.args_from + len(self.p.script_args)
|
||||
|
||||
is_cn_batch, batches, output_dir, _ = batch_hijack.get_cn_batches(self.p)
|
||||
batch_hijack.instance.dispatch_callbacks(batch_hijack.instance.process_batch_callbacks, self.p, batches, output_dir)
|
||||
|
||||
batch_units = [unit for unit in self.p.script_args if getattr(unit, 'input_mode', batch_hijack.InputMode.SIMPLE) == batch_hijack.InputMode.BATCH]
|
||||
if batch_units:
|
||||
self.assertEqual(min(len(list(unit.batch_images)) for unit in batch_units), len(batches))
|
||||
else:
|
||||
self.assertEqual(1, len(batches))
|
||||
|
||||
for i, unit in enumerate(self.cn_script.enabled_units):
|
||||
self.assertListEqual(batch_images_list[i], list(unit.batch_images))
|
||||
|
||||
def test_get_cn_batches__empty(self):
|
||||
is_batch, batches, _, _ = batch_hijack.get_cn_batches(self.p)
|
||||
self.assertEqual(1, len(batches))
|
||||
self.assertEqual(is_batch, False)
|
||||
|
||||
def test_get_cn_batches__1_simple(self):
|
||||
self.p.script_args.append(external_code.ControlNetUnit(image=get_dummy_image()))
|
||||
self.assert_get_cn_batches_works([
|
||||
[self.p.script_args[0].image],
|
||||
])
|
||||
|
||||
def test_get_cn_batches__2_simples(self):
|
||||
self.p.script_args.extend([
|
||||
external_code.ControlNetUnit(image=get_dummy_image(0)),
|
||||
external_code.ControlNetUnit(image=get_dummy_image(1)),
|
||||
])
|
||||
self.assert_get_cn_batches_works([
|
||||
[get_dummy_image(0)],
|
||||
[get_dummy_image(1)],
|
||||
])
|
||||
|
||||
def test_get_cn_batches__1_batch(self):
|
||||
self.p.script_args.extend([
|
||||
controlnet.UiControlNetUnit(
|
||||
input_mode=batch_hijack.InputMode.BATCH,
|
||||
batch_images=[
|
||||
get_dummy_image(0),
|
||||
get_dummy_image(1),
|
||||
],
|
||||
),
|
||||
])
|
||||
self.assert_get_cn_batches_works([
|
||||
[
|
||||
get_dummy_image(0),
|
||||
get_dummy_image(1),
|
||||
],
|
||||
])
|
||||
|
||||
def test_get_cn_batches__2_batches(self):
|
||||
self.p.script_args.extend([
|
||||
controlnet.UiControlNetUnit(
|
||||
input_mode=batch_hijack.InputMode.BATCH,
|
||||
batch_images=[
|
||||
get_dummy_image(0),
|
||||
get_dummy_image(1),
|
||||
],
|
||||
),
|
||||
controlnet.UiControlNetUnit(
|
||||
input_mode=batch_hijack.InputMode.BATCH,
|
||||
batch_images=[
|
||||
get_dummy_image(2),
|
||||
get_dummy_image(3),
|
||||
],
|
||||
),
|
||||
])
|
||||
self.assert_get_cn_batches_works([
|
||||
[
|
||||
get_dummy_image(0),
|
||||
get_dummy_image(1),
|
||||
],
|
||||
[
|
||||
get_dummy_image(2),
|
||||
get_dummy_image(3),
|
||||
],
|
||||
])
|
||||
|
||||
def test_get_cn_batches__2_mixed(self):
|
||||
self.p.script_args.extend([
|
||||
external_code.ControlNetUnit(image=get_dummy_image(0)),
|
||||
controlnet.UiControlNetUnit(
|
||||
input_mode=batch_hijack.InputMode.BATCH,
|
||||
batch_images=[
|
||||
get_dummy_image(1),
|
||||
get_dummy_image(2),
|
||||
],
|
||||
),
|
||||
])
|
||||
self.assert_get_cn_batches_works([
|
||||
[
|
||||
get_dummy_image(0),
|
||||
get_dummy_image(0),
|
||||
],
|
||||
[
|
||||
get_dummy_image(1),
|
||||
get_dummy_image(2),
|
||||
],
|
||||
])
|
||||
|
||||
def test_get_cn_batches__3_mixed(self):
|
||||
self.p.script_args.extend([
|
||||
external_code.ControlNetUnit(image=get_dummy_image(0)),
|
||||
controlnet.UiControlNetUnit(
|
||||
input_mode=batch_hijack.InputMode.BATCH,
|
||||
batch_images=[
|
||||
get_dummy_image(1),
|
||||
get_dummy_image(2),
|
||||
get_dummy_image(3),
|
||||
],
|
||||
),
|
||||
controlnet.UiControlNetUnit(
|
||||
input_mode=batch_hijack.InputMode.BATCH,
|
||||
batch_images=[
|
||||
get_dummy_image(4),
|
||||
get_dummy_image(5),
|
||||
],
|
||||
),
|
||||
])
|
||||
self.assert_get_cn_batches_works([
|
||||
[
|
||||
get_dummy_image(0),
|
||||
get_dummy_image(0),
|
||||
],
|
||||
[
|
||||
get_dummy_image(1),
|
||||
get_dummy_image(2),
|
||||
],
|
||||
[
|
||||
get_dummy_image(4),
|
||||
get_dummy_image(5),
|
||||
],
|
||||
])
|
||||
|
||||
class TestProcessImagesPatchWorks(unittest.TestCase):
|
||||
@unittest.mock.patch('modules.script_callbacks.on_script_unloaded')
|
||||
def setUp(self, on_script_unloaded_mock):
|
||||
self.on_script_unloaded_mock = on_script_unloaded_mock
|
||||
self.p = unittest.mock.MagicMock()
|
||||
assert scripts.scripts_txt2img is not None
|
||||
self.p.scripts = scripts.scripts_txt2img
|
||||
self.cn_script = controlnet.Script()
|
||||
self.p.scripts.alwayson_scripts = [self.cn_script]
|
||||
self.p.script_args = []
|
||||
self.p.all_seeds = [0]
|
||||
self.p.all_subseeds = [0]
|
||||
self.old_model, shared.sd_model = shared.sd_model, unittest.mock.MagicMock()
|
||||
|
||||
self.batch_hijack_object = batch_hijack.BatchHijack()
|
||||
self.callbacks_mock = unittest.mock.MagicMock()
|
||||
self.batch_hijack_object.process_batch_callbacks.append(self.callbacks_mock.process)
|
||||
self.batch_hijack_object.process_batch_each_callbacks.append(self.callbacks_mock.process_each)
|
||||
self.batch_hijack_object.postprocess_batch_each_callbacks.insert(0, self.callbacks_mock.postprocess_each)
|
||||
self.batch_hijack_object.postprocess_batch_callbacks.insert(0, self.callbacks_mock.postprocess)
|
||||
self.batch_hijack_object.do_hijack()
|
||||
shared.state.begin()
|
||||
|
||||
def tearDown(self):
|
||||
shared.state.end()
|
||||
self.batch_hijack_object.undo_hijack()
|
||||
shared.sd_model = self.old_model
|
||||
|
||||
@unittest.mock.patch('modules.processing.__controlnet_original_process_images_inner')
|
||||
def assert_process_images_hijack_called(self, process_images_mock, batch_count):
|
||||
process_images_mock.return_value = processing.Processed(self.p, [get_dummy_image('output')])
|
||||
with unittest.mock.patch.dict(shared.opts.data, {
|
||||
'controlnet_show_batch_images_in_ui': True,
|
||||
}):
|
||||
res = processing.process_images_inner(self.p)
|
||||
|
||||
self.assertEqual(res, process_images_mock.return_value)
|
||||
|
||||
if batch_count > 0:
|
||||
self.callbacks_mock.process.assert_called()
|
||||
self.callbacks_mock.postprocess.assert_called()
|
||||
else:
|
||||
self.callbacks_mock.process.assert_not_called()
|
||||
self.callbacks_mock.postprocess.assert_not_called()
|
||||
|
||||
self.assertEqual(self.callbacks_mock.process_each.call_count, batch_count)
|
||||
self.assertEqual(self.callbacks_mock.postprocess_each.call_count, batch_count)
|
||||
|
||||
def test_process_images_no_units_forwards(self):
|
||||
self.assert_process_images_hijack_called(batch_count=0)
|
||||
|
||||
def test_process_images__only_simple_units__forwards(self):
|
||||
self.p.script_args = [
|
||||
external_code.ControlNetUnit(image=get_dummy_image()),
|
||||
external_code.ControlNetUnit(image=get_dummy_image()),
|
||||
]
|
||||
self.assert_process_images_hijack_called(batch_count=0)
|
||||
|
||||
def test_process_images__1_batch_1_unit__runs_1_batch(self):
|
||||
self.p.script_args = [
|
||||
controlnet.UiControlNetUnit(
|
||||
input_mode=batch_hijack.InputMode.BATCH,
|
||||
batch_images=[
|
||||
get_dummy_image(),
|
||||
],
|
||||
),
|
||||
]
|
||||
self.assert_process_images_hijack_called(batch_count=1)
|
||||
|
||||
def test_process_images__2_batches_1_unit__runs_2_batches(self):
|
||||
self.p.script_args = [
|
||||
controlnet.UiControlNetUnit(
|
||||
input_mode=batch_hijack.InputMode.BATCH,
|
||||
batch_images=[
|
||||
get_dummy_image(0),
|
||||
get_dummy_image(1),
|
||||
],
|
||||
),
|
||||
]
|
||||
self.assert_process_images_hijack_called(batch_count=2)
|
||||
|
||||
def test_process_images__8_batches_1_unit__runs_8_batches(self):
|
||||
batch_count = 8
|
||||
self.p.script_args = [
|
||||
controlnet.UiControlNetUnit(
|
||||
input_mode=batch_hijack.InputMode.BATCH,
|
||||
batch_images=[get_dummy_image(i) for i in range(batch_count)]
|
||||
),
|
||||
]
|
||||
self.assert_process_images_hijack_called(batch_count=batch_count)
|
||||
|
||||
def test_process_images__1_batch_2_units__runs_1_batch(self):
|
||||
self.p.script_args = [
|
||||
controlnet.UiControlNetUnit(
|
||||
input_mode=batch_hijack.InputMode.BATCH,
|
||||
batch_images=[get_dummy_image(0)]
|
||||
),
|
||||
controlnet.UiControlNetUnit(
|
||||
input_mode=batch_hijack.InputMode.BATCH,
|
||||
batch_images=[get_dummy_image(1)]
|
||||
),
|
||||
]
|
||||
self.assert_process_images_hijack_called(batch_count=1)
|
||||
|
||||
def test_process_images__2_batches_2_units__runs_2_batches(self):
|
||||
self.p.script_args = [
|
||||
controlnet.UiControlNetUnit(
|
||||
input_mode=batch_hijack.InputMode.BATCH,
|
||||
batch_images=[
|
||||
get_dummy_image(0),
|
||||
get_dummy_image(1),
|
||||
],
|
||||
),
|
||||
controlnet.UiControlNetUnit(
|
||||
input_mode=batch_hijack.InputMode.BATCH,
|
||||
batch_images=[
|
||||
get_dummy_image(2),
|
||||
get_dummy_image(3),
|
||||
],
|
||||
),
|
||||
]
|
||||
self.assert_process_images_hijack_called(batch_count=2)
|
||||
|
||||
def test_process_images__3_batches_2_mixed_units__runs_3_batches(self):
|
||||
self.p.script_args = [
|
||||
controlnet.UiControlNetUnit(
|
||||
input_mode=batch_hijack.InputMode.BATCH,
|
||||
batch_images=[
|
||||
get_dummy_image(0),
|
||||
get_dummy_image(1),
|
||||
get_dummy_image(2),
|
||||
],
|
||||
),
|
||||
controlnet.UiControlNetUnit(
|
||||
input_mode=batch_hijack.InputMode.SIMPLE,
|
||||
image=get_dummy_image(3),
|
||||
),
|
||||
]
|
||||
self.assert_process_images_hijack_called(batch_count=3)
|
||||
|
||||
|
||||
def get_dummy_image(name: Any = 0):
|
||||
return f'base64#{name}...'
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -0,0 +1,184 @@
|
||||
from typing import Any, Dict, List
|
||||
import unittest
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
import importlib
|
||||
|
||||
utils = importlib.import_module("extensions.sd-webui-controlnet.tests.utils", "utils")
|
||||
|
||||
|
||||
from scripts import external_code, processor
|
||||
from scripts.controlnet import prepare_mask, Script, set_numpy_seed
|
||||
from modules import processing
|
||||
|
||||
|
||||
class TestPrepareMask(unittest.TestCase):
|
||||
def test_prepare_mask(self):
|
||||
p = processing.StableDiffusionProcessing()
|
||||
p.inpainting_mask_invert = True
|
||||
p.mask_blur = 5
|
||||
|
||||
mask = Image.new("RGB", (10, 10), color="white")
|
||||
|
||||
processed_mask = prepare_mask(mask, p)
|
||||
|
||||
# Check that mask is correctly converted to grayscale
|
||||
self.assertTrue(processed_mask.mode, "L")
|
||||
|
||||
# Check that mask colors are correctly inverted
|
||||
self.assertEqual(
|
||||
processed_mask.getpixel((0, 0)), 0
|
||||
) # inverted white should be black
|
||||
|
||||
p.inpainting_mask_invert = False
|
||||
processed_mask = prepare_mask(mask, p)
|
||||
|
||||
# Check that mask colors are not inverted when 'inpainting_mask_invert' is False
|
||||
self.assertEqual(
|
||||
processed_mask.getpixel((0, 0)), 255
|
||||
) # white should remain white
|
||||
|
||||
p.mask_blur = 0
|
||||
mask = Image.new("RGB", (10, 10), color="black")
|
||||
processed_mask = prepare_mask(mask, p)
|
||||
|
||||
# Check that mask is not blurred when 'mask_blur' is 0
|
||||
self.assertEqual(
|
||||
processed_mask.getpixel((0, 0)), 0
|
||||
) # black should remain black
|
||||
|
||||
|
||||
class TestSetNumpySeed(unittest.TestCase):
|
||||
def test_seed_subseed_minus_one(self):
|
||||
p = processing.StableDiffusionProcessing()
|
||||
p.seed = -1
|
||||
p.subseed = -1
|
||||
p.all_seeds = [123, 456]
|
||||
expected_seed = (123 + 123) & 0xFFFFFFFF
|
||||
self.assertEqual(set_numpy_seed(p), expected_seed)
|
||||
|
||||
def test_valid_seed_subseed(self):
|
||||
p = processing.StableDiffusionProcessing()
|
||||
p.seed = 50
|
||||
p.subseed = 100
|
||||
p.all_seeds = [123, 456]
|
||||
expected_seed = (50 + 100) & 0xFFFFFFFF
|
||||
self.assertEqual(set_numpy_seed(p), expected_seed)
|
||||
|
||||
def test_invalid_seed_subseed(self):
|
||||
p = processing.StableDiffusionProcessing()
|
||||
p.seed = "invalid"
|
||||
p.subseed = 2.5
|
||||
p.all_seeds = [123, 456]
|
||||
self.assertEqual(set_numpy_seed(p), None)
|
||||
|
||||
def test_empty_all_seeds(self):
|
||||
p = processing.StableDiffusionProcessing()
|
||||
p.seed = -1
|
||||
p.subseed = 2
|
||||
p.all_seeds = []
|
||||
self.assertEqual(set_numpy_seed(p), None)
|
||||
|
||||
def test_random_state_change(self):
|
||||
p = processing.StableDiffusionProcessing()
|
||||
p.seed = 50
|
||||
p.subseed = 100
|
||||
p.all_seeds = [123, 456]
|
||||
expected_seed = (50 + 100) & 0xFFFFFFFF
|
||||
|
||||
np.random.seed(0) # set a known seed
|
||||
before_random = np.random.randint(0, 1000) # get a random integer
|
||||
|
||||
seed = set_numpy_seed(p)
|
||||
self.assertEqual(seed, expected_seed)
|
||||
|
||||
after_random = np.random.randint(0, 1000) # get another random integer
|
||||
|
||||
self.assertNotEqual(before_random, after_random)
|
||||
|
||||
|
||||
class MockImg2ImgProcessing(processing.StableDiffusionProcessing):
|
||||
"""Mock the Img2Img processing as the WebUI version have dependency on
|
||||
`sd_model`."""
|
||||
|
||||
def __init__(self, init_images, resize_mode, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.init_images = init_images
|
||||
self.resize_mode = resize_mode
|
||||
|
||||
|
||||
class TestScript(unittest.TestCase):
|
||||
sample_base64_image = (
|
||||
"data:image/png;base64,"
|
||||
"iVBORw0KGgoAAAANSUhEUgAAARMAAAC3CAIAAAC+MS2jAAAAqUlEQVR4nO3BAQ"
|
||||
"0AAADCoPdPbQ8HFAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
|
||||
"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
|
||||
"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
|
||||
"AAAAAAAAAAAAAAAAAAAAAAAA/wZOlAAB5tU+nAAAAABJRU5ErkJggg=="
|
||||
)
|
||||
|
||||
sample_np_image = np.array(
|
||||
[[100, 200, 50], [150, 75, 225], [30, 120, 180]], dtype=np.uint8
|
||||
)
|
||||
|
||||
def test_bound_check_params(self):
|
||||
def param_required(module: str, param: str) -> bool:
|
||||
configs = processor.preprocessor_sliders_config[module]
|
||||
config_index = ("processor_res", "threshold_a", "threshold_b").index(param)
|
||||
return config_index < len(configs) and configs[config_index] is not None
|
||||
|
||||
for module in processor.preprocessor_sliders_config.keys():
|
||||
for param in ("processor_res", "threshold_a", "threshold_b"):
|
||||
with self.subTest(param=param, module=module):
|
||||
unit = external_code.ControlNetUnit(
|
||||
module=module,
|
||||
**{param: -100},
|
||||
)
|
||||
Script.bound_check_params(unit)
|
||||
if param_required(module, param):
|
||||
self.assertGreaterEqual(getattr(unit, param), 0)
|
||||
else:
|
||||
self.assertEqual(getattr(unit, param), -100)
|
||||
|
||||
def test_choose_input_image(self):
|
||||
with self.subTest(name="no image"):
|
||||
with self.assertRaises(ValueError):
|
||||
Script.choose_input_image(
|
||||
p=processing.StableDiffusionProcessing(),
|
||||
unit=external_code.ControlNetUnit(),
|
||||
idx=0,
|
||||
)
|
||||
|
||||
with self.subTest(name="control net input"):
|
||||
_, resize_mode = Script.choose_input_image(
|
||||
p=MockImg2ImgProcessing(
|
||||
init_images=[TestScript.sample_np_image],
|
||||
resize_mode=external_code.ResizeMode.OUTER_FIT,
|
||||
),
|
||||
unit=external_code.ControlNetUnit(
|
||||
image=TestScript.sample_base64_image,
|
||||
module="none",
|
||||
resize_mode=external_code.ResizeMode.INNER_FIT,
|
||||
),
|
||||
idx=0,
|
||||
)
|
||||
self.assertEqual(resize_mode, external_code.ResizeMode.INNER_FIT)
|
||||
|
||||
with self.subTest(name="A1111 input"):
|
||||
_, resize_mode = Script.choose_input_image(
|
||||
p=MockImg2ImgProcessing(
|
||||
init_images=[TestScript.sample_np_image],
|
||||
resize_mode=external_code.ResizeMode.OUTER_FIT,
|
||||
),
|
||||
unit=external_code.ControlNetUnit(
|
||||
module="none",
|
||||
resize_mode=external_code.ResizeMode.INNER_FIT,
|
||||
),
|
||||
idx=0,
|
||||
)
|
||||
self.assertEqual(resize_mode, external_code.ResizeMode.OUTER_FIT)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -0,0 +1,67 @@
|
||||
import importlib
|
||||
utils = importlib.import_module("extensions.sd-webui-controlnet.tests.utils", "utils")
|
||||
|
||||
from scripts.global_state import select_control_type, ui_preprocessor_keys
|
||||
from scripts.enums import StableDiffusionVersion
|
||||
|
||||
|
||||
dummy_value = "dummy"
|
||||
cn_models = {
|
||||
"None": dummy_value,
|
||||
"canny_sd15": dummy_value,
|
||||
"canny_sdxl": dummy_value,
|
||||
}
|
||||
|
||||
|
||||
# Tests for the select_control_type function
|
||||
class TestSelectControlType:
|
||||
def test_all_control_type(self):
|
||||
result = select_control_type("All", cn_models=cn_models)
|
||||
assert result == (
|
||||
[ui_preprocessor_keys, list(cn_models.keys()), "none", "None"]
|
||||
), "Expected all preprocessors and models"
|
||||
|
||||
def test_sd_version(self):
|
||||
(_, filtered_model_list, _, default_model) = select_control_type(
|
||||
"Canny", sd_version=StableDiffusionVersion.UNKNOWN, cn_models=cn_models
|
||||
)
|
||||
assert filtered_model_list == [
|
||||
"None",
|
||||
"canny_sd15",
|
||||
"canny_sdxl",
|
||||
], "UNKNOWN sd version should match all models"
|
||||
assert default_model == "canny_sd15"
|
||||
|
||||
(_, filtered_model_list, _, default_model) = select_control_type(
|
||||
"Canny", sd_version=StableDiffusionVersion.SD1x, cn_models=cn_models
|
||||
)
|
||||
assert filtered_model_list == [
|
||||
"None",
|
||||
"canny_sd15",
|
||||
], "sd1x version should only sd1x"
|
||||
assert default_model == "canny_sd15"
|
||||
|
||||
(_, filtered_model_list, _, default_model) = select_control_type(
|
||||
"Canny", sd_version=StableDiffusionVersion.SDXL, cn_models=cn_models
|
||||
)
|
||||
assert filtered_model_list == [
|
||||
"None",
|
||||
"canny_sdxl",
|
||||
], "sdxl version should only sdxl"
|
||||
assert default_model == "canny_sdxl"
|
||||
|
||||
def test_invert_preprocessor(self):
|
||||
for control_type in ("Canny", "Lineart", "Scribble/Sketch", "MLSD"):
|
||||
filtered_preprocessor_list, _, _, _ = select_control_type(
|
||||
control_type, cn_models=cn_models
|
||||
)
|
||||
assert any(
|
||||
"invert" in module.lower() for module in filtered_preprocessor_list
|
||||
)
|
||||
|
||||
def test_no_module_available(self):
|
||||
(_, filtered_model_list, _, default_model) = select_control_type(
|
||||
"Depth", cn_models=cn_models
|
||||
)
|
||||
assert filtered_model_list == ["None"]
|
||||
assert default_model == "None"
|
||||
@@ -0,0 +1,34 @@
|
||||
import unittest
|
||||
import importlib
|
||||
|
||||
utils = importlib.import_module("extensions.sd-webui-controlnet.tests.utils", "utils")
|
||||
|
||||
from scripts.infotext import parse_unit
|
||||
from scripts.external_code import ControlNetUnit
|
||||
|
||||
|
||||
class TestInfotext(unittest.TestCase):
|
||||
def test_parsing(self):
|
||||
infotext = (
|
||||
"Module: inpaint_only+lama, Model: control_v11p_sd15_inpaint [ebff9138], Weight: 1, "
|
||||
"Resize Mode: Resize and Fill, Low Vram: False, Guidance Start: 0, Guidance End: 1, "
|
||||
"Pixel Perfect: True, Control Mode: Balanced, Hr Option: Both, Save Detected Map: True"
|
||||
)
|
||||
self.assertEqual(
|
||||
vars(
|
||||
ControlNetUnit(
|
||||
module="inpaint_only+lama",
|
||||
model="control_v11p_sd15_inpaint [ebff9138]",
|
||||
weight=1,
|
||||
resize_mode="Resize and Fill",
|
||||
low_vram=False,
|
||||
guidance_start=0,
|
||||
guidance_end=1,
|
||||
pixel_perfect=True,
|
||||
control_mode="Balanced",
|
||||
hr_option="Both",
|
||||
save_detected_map=True,
|
||||
)
|
||||
),
|
||||
vars(parse_unit(infotext)),
|
||||
)
|
||||
@@ -0,0 +1,75 @@
|
||||
import importlib
|
||||
utils = importlib.import_module('extensions.sd-webui-controlnet.tests.utils', 'utils')
|
||||
|
||||
|
||||
from scripts.utils import ndarray_lru_cache, get_unique_axis0
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
|
||||
class TestNumpyLruCache(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.arr1 = np.array([1, 2, 3, 4, 5])
|
||||
self.arr2 = np.array([1, 2, 3, 4, 5])
|
||||
|
||||
@ndarray_lru_cache(max_size=128)
|
||||
def add_one(self, arr):
|
||||
return arr + 1
|
||||
|
||||
def test_same_array(self):
|
||||
# Test that the decorator works with numpy arrays.
|
||||
result1 = self.add_one(self.arr1)
|
||||
result2 = self.add_one(self.arr1)
|
||||
|
||||
# If caching is working correctly, these should be the same object.
|
||||
self.assertIs(result1, result2)
|
||||
|
||||
def test_different_array_same_data(self):
|
||||
# Test that the decorator works with different numpy arrays with the same data.
|
||||
result1 = self.add_one(self.arr1)
|
||||
result2 = self.add_one(self.arr2)
|
||||
|
||||
# If caching is working correctly, these should be the same object.
|
||||
self.assertIs(result1, result2)
|
||||
|
||||
def test_cache_size(self):
|
||||
# Test that the cache size limit is respected.
|
||||
arrs = [np.array([i]) for i in range(150)]
|
||||
|
||||
# Add all arrays to the cache.
|
||||
|
||||
result1 = self.add_one(arrs[0])
|
||||
for arr in arrs[1:]:
|
||||
self.add_one(arr)
|
||||
|
||||
# Check that the first array is no longer in the cache.
|
||||
result2 = self.add_one(arrs[0])
|
||||
|
||||
# If the cache size limit is working correctly, these should not be the same object.
|
||||
self.assertIsNot(result1, result2)
|
||||
|
||||
def test_large_array(self):
|
||||
# Create two large arrays with the same elements in the beginning and end, but one different element in the middle.
|
||||
arr1 = np.ones(10000)
|
||||
arr2 = np.ones(10000)
|
||||
arr2[len(arr2)//2] = 0
|
||||
|
||||
result1 = self.add_one(arr1)
|
||||
result2 = self.add_one(arr2)
|
||||
|
||||
# If hashing is working correctly, these should not be the same object because the input arrays are not equal.
|
||||
self.assertIsNot(result1, result2)
|
||||
|
||||
class TestUniqueFunctions(unittest.TestCase):
|
||||
def test_get_unique_axis0(self):
|
||||
data = np.random.randint(0, 100, size=(100000, 3))
|
||||
data = np.concatenate((data, data))
|
||||
numpy_unique_res = np.unique(data, axis=0)
|
||||
get_unique_axis0_res = get_unique_axis0(data)
|
||||
self.assertEqual(np.array_equal(
|
||||
np.sort(numpy_unique_res, axis=0), np.sort(get_unique_axis0_res, axis=0),
|
||||
), True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user