diff --git a/README.md b/README.md index 4ddb8603..3434e344 100644 --- a/README.md +++ b/README.md @@ -350,14 +350,12 @@ The memory optimization in this example is fully automatic. You do not need to c ```python # Use --show-controlnet-example to see this extension. -import os import cv2 import gradio as gr -import numpy as np from modules import scripts from modules.shared_cmd_options import cmd_opts -from modules_forge.shared import shared_preprocessors +from modules_forge.shared import supported_preprocessors from modules.modelloader import load_file_from_url from ldm_patched.modules.controlnet import load_controlnet from modules_forge.controlnet import apply_controlnet_advanced @@ -425,7 +423,7 @@ class ControlNetExampleForge(scripts.Script): width = W * 8 batch_size = p.batch_size - preprocessor = shared_preprocessors['canny'] + preprocessor = supported_preprocessors['canny'] # detect control at certain resolution control_image = preprocessor( @@ -518,7 +516,8 @@ Your preprocessor will be read by all other extensions using `modules_forge.shar Below codes are in `extensions-builtin\forge_preprocessor_normalbae\scripts\preprocessor_normalbae.py` ```python -from modules_forge.shared import Preprocessor, PreprocessorParameter, preprocessor_dir, add_preprocessor +from modules_forge.supported_preprocessor import Preprocessor, PreprocessorParameter +from modules_forge.shared import preprocessor_dir, add_supported_preprocessor from modules_forge.forge_util import resize_image_with_pad from modules.modelloader import load_file_from_url @@ -537,13 +536,15 @@ class PreprocessorNormalBae(Preprocessor): super().__init__() self.name = 'normalbae' self.tags = ['NormalMap'] - self.slider_resolution = PreprocessorParameter(label='Resolution', minimum=128, maximum=2048, value=512, step=8, visible=True) + self.model_filename_filers = ['normal'] + self.slider_resolution = PreprocessorParameter( + label='Resolution', minimum=128, maximum=2048, value=512, step=8, visible=True) self.slider_1 = PreprocessorParameter(visible=False) self.slider_2 = PreprocessorParameter(visible=False) self.slider_3 = PreprocessorParameter(visible=False) self.show_control_mode = True self.do_not_need_model = False - self.sorting_priority = 0.0 # higher goes to top in the list + self.sorting_priority = 100 # higher goes to top in the list def load_model(self): if self.model_patcher is not None: @@ -591,7 +592,7 @@ class PreprocessorNormalBae(Preprocessor): return remove_pad(normal_image) -add_preprocessor(PreprocessorNormalBae) +add_supported_preprocessor(PreprocessorNormalBae()) ``` diff --git a/extensions-builtin/forge_legacy_preprocessors/scripts/legacy_preprocessors.py b/extensions-builtin/forge_legacy_preprocessors/scripts/legacy_preprocessors.py index f4a9bb88..1db8b9eb 100644 --- a/extensions-builtin/forge_legacy_preprocessors/scripts/legacy_preprocessors.py +++ b/extensions-builtin/forge_legacy_preprocessors/scripts/legacy_preprocessors.py @@ -15,7 +15,26 @@ import contextlib from annotator.util import HWC3 from modules_forge.ops import automatic_memory_management from legacy_preprocessors.preprocessor_compiled import legacy_preprocessors -from modules_forge.shared import Preprocessor, PreprocessorParameter, add_preprocessor +from modules_forge.supported_preprocessor import Preprocessor, PreprocessorParameter +from modules_forge.shared import add_supported_preprocessor + + +### + +# This file has lots of unreasonable historical designs and should be viewed as a frozen blackbox library + +# If you want to add preprocessor, +# please instead look at `extensions-builtin/forge_preprocessor_normalbae/scripts/preprocessor_normalbae` +# If you want to use preprocessor, +# please instead use `from modules_forge.shared import supported_preprocessors` +# and then use any preprocessor like: depth_midas = supported_preprocessors['depth_midas'] + +# Please do not hack/edit/modify/rely-on any codes in this file. + +# Never use methods in this file to add anything! +# This file will be eventually removed but the workload is super high and we need more time to do this. + +### class LegacyPreprocessor(Preprocessor): @@ -30,6 +49,21 @@ class LegacyPreprocessor(Preprocessor): self.sorting_priority = legacy_dict['priority'] self.tags = legacy_dict['tags'] + filters_aliases = { + 'instructp2p': ['ip2p'], + 'segmentation': ['seg'], + 'normalmap': ['normal'], + 't2i-adapter': ['t2i_adapter', 't2iadapter', 't2ia'], + 'ip-adapter': ['ip_adapter', 'ipadapter'], + 'openpose': ['openpose', 'densepose'], + } + + self.model_filename_filers = [] + for tag in self.tags: + tag_lower = tag.lower() + self.model_filename_filers.append(tag_lower) + self.model_filename_filers += filters_aliases.get(tag_lower, []) + if legacy_dict['resolution'] is None: self.resolution = PreprocessorParameter(visible=False) else: @@ -76,4 +110,4 @@ class LegacyPreprocessor(Preprocessor): for k, v in legacy_preprocessors.items(): p = LegacyPreprocessor(v) p.name = k - add_preprocessor(p) + add_supported_preprocessor(p) diff --git a/extensions-builtin/forge_preprocessor_normalbae/scripts/preprocessor_normalbae.py b/extensions-builtin/forge_preprocessor_normalbae/scripts/preprocessor_normalbae.py index 496b5546..2c99e396 100644 --- a/extensions-builtin/forge_preprocessor_normalbae/scripts/preprocessor_normalbae.py +++ b/extensions-builtin/forge_preprocessor_normalbae/scripts/preprocessor_normalbae.py @@ -1,4 +1,5 @@ -from modules_forge.shared import Preprocessor, PreprocessorParameter, preprocessor_dir, add_preprocessor +from modules_forge.supported_preprocessor import Preprocessor, PreprocessorParameter +from modules_forge.shared import preprocessor_dir, add_supported_preprocessor from modules_forge.forge_util import resize_image_with_pad from modules.modelloader import load_file_from_url @@ -17,7 +18,9 @@ class PreprocessorNormalBae(Preprocessor): super().__init__() self.name = 'normalbae' self.tags = ['NormalMap'] - self.slider_resolution = PreprocessorParameter(label='Resolution', minimum=128, maximum=2048, value=512, step=8, visible=True) + self.model_filename_filers = ['normal'] + self.slider_resolution = PreprocessorParameter( + label='Resolution', minimum=128, maximum=2048, value=512, step=8, visible=True) self.slider_1 = PreprocessorParameter(visible=False) self.slider_2 = PreprocessorParameter(visible=False) self.slider_3 = PreprocessorParameter(visible=False) @@ -71,4 +74,4 @@ class PreprocessorNormalBae(Preprocessor): return remove_pad(normal_image) -add_preprocessor(PreprocessorNormalBae()) +add_supported_preprocessor(PreprocessorNormalBae()) diff --git a/extensions-builtin/sd_forge_controlnet/.gitignore b/extensions-builtin/sd_forge_controlnet/.gitignore new file mode 100644 index 00000000..60d06e51 --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/.gitignore @@ -0,0 +1,185 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea +*.pt +*.pth +*.ckpt +*.bin +*.safetensors + +# Editor setting metadata +.idea/ +.vscode/ +detected_maps/ +annotator/downloads/ + +# test results and expectations +web_tests/results/ +web_tests/expectations/ +tests/web_api/full_coverage/results/ +tests/web_api/full_coverage/expectations/ + +*_diff.png + +# Presets +presets/ + +# Ignore existing dir of hand refiner if exists. +annotator/hand_refiner_portable \ No newline at end of file diff --git a/extensions-builtin/sd_forge_controlnet/LICENSE b/extensions-builtin/sd_forge_controlnet/LICENSE new file mode 100644 index 00000000..f288702d --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/LICENSE @@ -0,0 +1,674 @@ + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + Copyright (C) + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +. diff --git a/extensions-builtin/sd_forge_controlnet/README.md b/extensions-builtin/sd_forge_controlnet/README.md new file mode 100644 index 00000000..38460eca --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/README.md @@ -0,0 +1,243 @@ +# ControlNet for Stable Diffusion WebUI + +The WebUI extension for ControlNet and other injection-based SD controls. + +![image](https://github.com/Mikubill/sd-webui-controlnet/assets/20929282/51172d20-606b-4b9f-aba5-db2f2417cb0b) + +This extension is for AUTOMATIC1111's [Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui), allows the Web UI to add [ControlNet](https://github.com/lllyasviel/ControlNet) to the original Stable Diffusion model to generate images. The addition is on-the-fly, the merging is not required. + +# Installation + +1. Open "Extensions" tab. +2. Open "Install from URL" tab in the tab. +3. Enter `https://github.com/Mikubill/sd-webui-controlnet.git` to "URL for extension's git repository". +4. Press "Install" button. +5. Wait for 5 seconds, and you will see the message "Installed into stable-diffusion-webui\extensions\sd-webui-controlnet. Use Installed tab to restart". +6. Go to "Installed" tab, click "Check for updates", and then click "Apply and restart UI". (The next time you can also use these buttons to update ControlNet.) +7. Completely restart A1111 webui including your terminal. (If you do not know what is a "terminal", you can reboot your computer to achieve the same effect.) +8. Download models (see below). +9. After you put models in the correct folder, you may need to refresh to see the models. The refresh button is right to your "Model" dropdown. + +# Download Models + +Right now all the 14 models of ControlNet 1.1 are in the beta test. + +Download the models from ControlNet 1.1: https://huggingface.co/lllyasviel/ControlNet-v1-1/tree/main + +You need to download model files ending with ".pth" . + +Put models in your "stable-diffusion-webui\extensions\sd-webui-controlnet\models". You only need to download "pth" files. + +Do not right-click the filenames in HuggingFace website to download. Some users right-clicked those HuggingFace HTML websites and saved those HTML pages as PTH/YAML files. They are not downloading correct files. Instead, please click the small download arrow “↓” icon in HuggingFace to download. + +# Download Models for SDXL + +See instructions [here](https://github.com/Mikubill/sd-webui-controlnet/discussions/2039). + +# Features in ControlNet 1.1 + +### Perfect Support for All ControlNet 1.0/1.1 and T2I Adapter Models. + +Now we have perfect support all available models and preprocessors, including perfect support for T2I style adapter and ControlNet 1.1 Shuffle. (Make sure that your YAML file names and model file names are same, see also YAML files in "stable-diffusion-webui\extensions\sd-webui-controlnet\models".) + +### Perfect Support for A1111 High-Res. Fix + +Now if you turn on High-Res Fix in A1111, each controlnet will output two different control images: a small one and a large one. The small one is for your basic generating, and the big one is for your High-Res Fix generating. The two control images are computed by a smart algorithm called "super high-quality control image resampling". This is turned on by default, and you do not need to change any setting. + +### Perfect Support for All A1111 Img2Img or Inpaint Settings and All Mask Types + +Now ControlNet is extensively tested with A1111's different types of masks, including "Inpaint masked"/"Inpaint not masked", and "Whole picture"/"Only masked", and "Only masked padding"&"Mask blur". The resizing perfectly matches A1111's "Just resize"/"Crop and resize"/"Resize and fill". This means you can use ControlNet in nearly everywhere in your A1111 UI without difficulty! + +### The New "Pixel-Perfect" Mode + +Now if you turn on pixel-perfect mode, you do not need to set preprocessor (annotator) resolutions manually. The ControlNet will automatically compute the best annotator resolution for you so that each pixel perfectly matches Stable Diffusion. + +### User-Friendly GUI and Preprocessor Preview + +We reorganized some previously confusing UI like "canvas width/height for new canvas" and it is in the 📝 button now. Now the preview GUI is controlled by the "allow preview" option and the trigger button 💥. The preview image size is better than before, and you do not need to scroll up and down - your a1111 GUI will not be messed up anymore! + +### Support for Almost All Upscaling Scripts + +Now ControlNet 1.1 can support almost all Upscaling/Tile methods. ControlNet 1.1 support the script "Ultimate SD upscale" and almost all other tile-based extensions. Please do not confuse ["Ultimate SD upscale"](https://github.com/Coyote-A/ultimate-upscale-for-automatic1111) with "SD upscale" - they are different scripts. Note that the most recommended upscaling method is ["Tiled VAE/Diffusion"](https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111) but we test as many methods/extensions as possible. Note that "SD upscale" is supported since 1.1.117, and if you use it, you need to leave all ControlNet images as blank (We do not recommend "SD upscale" since it is somewhat buggy and cannot be maintained - use the "Ultimate SD upscale" instead). + +### More Control Modes (previously called Guess Mode) + +We have fixed many bugs in previous 1.0’s Guess Mode and now it is called Control Mode + +![image](https://user-images.githubusercontent.com/19834515/236641759-6c44ddf6-c7ad-4bda-92be-e90a52911d75.png) + +Now you can control which aspect is more important (your prompt or your ControlNet): + +* "Balanced": ControlNet on both sides of CFG scale, same as turning off "Guess Mode" in ControlNet 1.0 + +* "My prompt is more important": ControlNet on both sides of CFG scale, with progressively reduced SD U-Net injections (layer_weight*=0.825**I, where 0<=I <13, and the 13 means ControlNet injected SD 13 times). In this way, you can make sure that your prompts are perfectly displayed in your generated images. + +* "ControlNet is more important": ControlNet only on the Conditional Side of CFG scale (the cond in A1111's batch-cond-uncond). This means the ControlNet will be X times stronger if your cfg-scale is X. For example, if your cfg-scale is 7, then ControlNet is 7 times stronger. Note that here the X times stronger is different from "Control Weights" since your weights are not modified. This "stronger" effect usually has less artifact and give ControlNet more room to guess what is missing from your prompts (and in the previous 1.0, it is called "Guess Mode"). + + + + + + + + + + + + + + +
Input (depth+canny+hed)"Balanced""My prompt is more important""ControlNet is more important"
+ +### Reference-Only Control + +Now we have a `reference-only` preprocessor that does not require any control models. It can guide the diffusion directly using images as references. + +(Prompt "a dog running on grassland, best quality, ...") + +![image](samples/ref.png) + +This method is similar to inpaint-based reference but it does not make your image disordered. + +Many professional A1111 users know a trick to diffuse image with references by inpaint. For example, if you have a 512x512 image of a dog, and want to generate another 512x512 image with the same dog, some users will connect the 512x512 dog image and a 512x512 blank image into a 1024x512 image, send to inpaint, and mask out the blank 512x512 part to diffuse a dog with similar appearance. However, that method is usually not very satisfying since images are connected and many distortions will appear. + +This `reference-only` ControlNet can directly link the attention layers of your SD to any independent images, so that your SD will read arbitrary images for reference. You need at least ControlNet 1.1.153 to use it. + +To use, just select `reference-only` as preprocessor and put an image. Your SD will just use the image as reference. + +*Note that this method is as "non-opinioned" as possible. It only contains very basic connection codes, without any personal preferences, to connect the attention layers with your reference images. However, even if we tried best to not include any opinioned codes, we still need to write some subjective implementations to deal with weighting, cfg-scale, etc - tech report is on the way.* + +More examples [here](https://github.com/Mikubill/sd-webui-controlnet/discussions/1236). + +# Technical Documents + +See also the documents of ControlNet 1.1: + +https://github.com/lllyasviel/ControlNet-v1-1-nightly#model-specification + +# Default Setting + +This is my setting. If you run into any problem, you can use this setting as a sanity check + +![image](https://user-images.githubusercontent.com/19834515/235620638-17937171-8ac1-45bc-a3cb-3aebf605b4ef.png) + +# Use Previous Models + +### Use ControlNet 1.0 Models + +https://huggingface.co/lllyasviel/ControlNet/tree/main/models + +You can still use all previous models in the previous ControlNet 1.0. Now, the previous "depth" is now called "depth_midas", the previous "normal" is called "normal_midas", the previous "hed" is called "softedge_hed". And starting from 1.1, all line maps, edge maps, lineart maps, boundary maps will have black background and white lines. + +### Use T2I-Adapter Models + +(From TencentARC/T2I-Adapter) + +To use T2I-Adapter models: + +1. Download files from https://huggingface.co/TencentARC/T2I-Adapter/tree/main/models +2. Put them in "stable-diffusion-webui\extensions\sd-webui-controlnet\models". +3. Make sure that the file names of pth files and yaml files are consistent. + +*Note that "CoAdapter" is not implemented yet.* + +# Gallery + +The below results are from ControlNet 1.0. + +| Source | Input | Output | +|:-------------------------:|:-------------------------:|:-------------------------:| +| (no preprocessor) | | | +| (no preprocessor) | | | +| | | | +| | | | +| | | | +| | | | + +The below examples are from T2I-Adapter. + +From `t2iadapter_color_sd14v1.pth` : + +| Source | Input | Output | +|:-------------------------:|:-------------------------:|:-------------------------:| +| | | | + +From `t2iadapter_style_sd14v1.pth` : + +| Source | Input | Output | +|:-------------------------:|:-------------------------:|:-------------------------:| +| | (clip, non-image) | | + +# Minimum Requirements + +* (Windows) (NVIDIA: Ampere) 4gb - with `--xformers` enabled, and `Low VRAM` mode ticked in the UI, goes up to 768x832 + +# Multi-ControlNet + +This option allows multiple ControlNet inputs for a single generation. To enable this option, change `Multi ControlNet: Max models amount (requires restart)` in the settings. Note that you will need to restart the WebUI for changes to take effect. + + + + + + + + + + + + +
Source ASource BOutput
+ +# Control Weight/Start/End + +Weight is the weight of the controlnet "influence". It's analogous to prompt attention/emphasis. E.g. (myprompt: 1.2). Technically, it's the factor by which to multiply the ControlNet outputs before merging them with original SD Unet. + +Guidance Start/End is the percentage of total steps the controlnet applies (guidance strength = guidance end). It's analogous to prompt editing/shifting. E.g. \[myprompt::0.8\] (It applies from the beginning until 80% of total steps) + +# Batch Mode + +Put any unit into batch mode to activate batch mode for all units. Specify a batch directory for each unit, or use the new textbox in the img2img batch tab as a fallback. Although the textbox is located in the img2img batch tab, you can use it to generate images in the txt2img tab as well. + +Note that this feature is only available in the gradio user interface. Call the APIs as many times as you want for custom batch scheduling. + +# API and Script Access + +This extension can accept txt2img or img2img tasks via API or external extension call. Note that you may need to enable `Allow other scripts to control this extension` in settings for external calls. + +To use the API: start WebUI with argument `--api` and go to `http://webui-address/docs` for documents or checkout [examples](https://github.com/Mikubill/sd-webui-controlnet/blob/main/example/txt2img_example/api_txt2img.py). + +To use external call: Checkout [Wiki](https://github.com/Mikubill/sd-webui-controlnet/wiki/API) + +# Command Line Arguments + +This extension adds these command line arguments to the webui: + +``` + --controlnet-dir ADD a controlnet models directory + --controlnet-annotator-models-path SET the directory for annotator models + --no-half-controlnet load controlnet models in full precision + --controlnet-preprocessor-cache-size Cache size for controlnet preprocessor results + --controlnet-loglevel Log level for the controlnet extension + --controlnet-tracemalloc Enable malloc memory tracing +``` + +# MacOS Support + +Tested with pytorch nightly: https://github.com/Mikubill/sd-webui-controlnet/pull/143#issuecomment-1435058285 + +To use this extension with mps and normal pytorch, currently you may need to start WebUI with `--no-half`. + +# Archive of Deprecated Versions + +The previous version (sd-webui-controlnet 1.0) is archived in + +https://github.com/lllyasviel/webui-controlnet-v1-archived + +Using this version is not a temporary stop of updates. You will stop all updates forever. + +Please consider this version if you work with professional studios that requires 100% reproducing of all previous results pixel by pixel. + +# Thanks + +This implementation is inspired by kohya-ss/sd-webui-additional-networks diff --git a/extensions-builtin/sd_forge_controlnet/install.py b/extensions-builtin/sd_forge_controlnet/install.py new file mode 100644 index 00000000..5370d221 --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/install.py @@ -0,0 +1,66 @@ +import launch +import pkg_resources +import sys +import os +import shutil +import platform +from pathlib import Path +from typing import Tuple, Optional + + +repo_root = Path(__file__).parent +main_req_file = repo_root / "requirements.txt" + + +def comparable_version(version: str) -> Tuple: + return tuple(version.split(".")) + + +def get_installed_version(package: str) -> Optional[str]: + try: + return pkg_resources.get_distribution(package).version + except Exception: + return None + + +def extract_base_package(package_string: str) -> str: + base_package = package_string.split("@git")[0] + return base_package + + +def install_requirements(req_file): + with open(req_file) as file: + for package in file: + try: + package = package.strip() + if "==" in package: + package_name, package_version = package.split("==") + installed_version = get_installed_version(package_name) + if installed_version != package_version: + launch.run_pip( + f"install -U {package}", + f"sd-forge-controlnet requirement: changing {package_name} version from {installed_version} to {package_version}", + ) + elif ">=" in package: + package_name, package_version = package.split(">=") + installed_version = get_installed_version(package_name) + if not installed_version or comparable_version( + installed_version + ) < comparable_version(package_version): + launch.run_pip( + f"install -U {package}", + f"sd-forge-controlnet requirement: changing {package_name} version from {installed_version} to {package_version}", + ) + elif not launch.is_installed(extract_base_package(package)): + launch.run_pip( + f"install {package}", + f"sd-forge-controlnet requirement: {package}", + ) + except Exception as e: + print(e) + print( + f"Warning: Failed to install {package}, some preprocessors may not work." + ) + + +install_requirements(main_req_file) diff --git a/extensions-builtin/sd_forge_controlnet/javascript/active_units.js b/extensions-builtin/sd_forge_controlnet/javascript/active_units.js new file mode 100644 index 00000000..a2662055 --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/javascript/active_units.js @@ -0,0 +1,311 @@ +/** + * Give a badge on ControlNet Accordion indicating total number of active + * units. + * Make active unit's tab name green. + * Append control type to tab name. + * Disable resize mode selection when A1111 img2img input is used. + */ +(function () { + const cnetAllAccordions = new Set(); + onUiUpdate(() => { + const ImgChangeType = { + NO_CHANGE: 0, + REMOVE: 1, + ADD: 2, + SRC_CHANGE: 3, + }; + + function imgChangeObserved(mutationsList) { + // Iterate over all mutations that just occured + for (let mutation of mutationsList) { + // Check if the mutation is an addition or removal of a node + if (mutation.type === 'childList') { + // Check if nodes were added + if (mutation.addedNodes.length > 0) { + for (const node of mutation.addedNodes) { + if (node.tagName === 'IMG') { + return ImgChangeType.ADD; + } + } + } + + // Check if nodes were removed + if (mutation.removedNodes.length > 0) { + for (const node of mutation.removedNodes) { + if (node.tagName === 'IMG') { + return ImgChangeType.REMOVE; + } + } + } + } + // Check if the mutation is a change of an attribute + else if (mutation.type === 'attributes') { + if (mutation.target.tagName === 'IMG' && mutation.attributeName === 'src') { + return ImgChangeType.SRC_CHANGE; + } + } + } + return ImgChangeType.NO_CHANGE; + } + + function childIndex(element) { + // Get all child nodes of the parent + let children = Array.from(element.parentNode.childNodes); + + // Filter out non-element nodes (like text nodes and comments) + children = children.filter(child => child.nodeType === Node.ELEMENT_NODE); + + return children.indexOf(element); + } + + function imageInputDisabledAlert() { + alert('Inpaint control type must use a1111 input in img2img mode.'); + } + + class ControlNetUnitTab { + constructor(tab, accordion) { + this.tab = tab; + this.accordion = accordion; + this.isImg2Img = tab.querySelector('.cnet-unit-enabled').id.includes('img2img'); + + this.enabledCheckbox = tab.querySelector('.cnet-unit-enabled input'); + this.inputImage = tab.querySelector('.cnet-input-image-group .cnet-image input[type="file"]'); + this.inputImageContainer = tab.querySelector('.cnet-input-image-group .cnet-image'); + this.controlTypeRadios = tab.querySelectorAll('.controlnet_control_type_filter_group input[type="radio"]'); + this.resizeModeRadios = tab.querySelectorAll('.controlnet_resize_mode_radio input[type="radio"]'); + this.runPreprocessorButton = tab.querySelector('.cnet-run-preprocessor'); + + const tabs = tab.parentNode; + this.tabNav = tabs.querySelector('.tab-nav'); + this.tabIndex = childIndex(tab) - 1; // -1 because tab-nav is also at the same level. + + this.attachEnabledButtonListener(); + this.attachControlTypeRadioListener(); + this.attachTabNavChangeObserver(); + this.attachImageUploadListener(); + this.attachImageStateChangeObserver(); + this.attachA1111SendInfoObserver(); + this.attachPresetDropdownObserver(); + } + + getTabNavButton() { + return this.tabNav.querySelector(`:nth-child(${this.tabIndex + 1})`); + } + + getActiveControlType() { + for (let radio of this.controlTypeRadios) { + if (radio.checked) { + return radio.value; + } + } + return undefined; + } + + updateActiveState() { + const tabNavButton = this.getTabNavButton(); + if (!tabNavButton) return; + + if (this.enabledCheckbox.checked) { + tabNavButton.classList.add('cnet-unit-active'); + } else { + tabNavButton.classList.remove('cnet-unit-active'); + } + } + + updateActiveUnitCount() { + function getActiveUnitCount(checkboxes) { + let activeUnitCount = 0; + for (const checkbox of checkboxes) { + if (checkbox.checked) + activeUnitCount++; + } + return activeUnitCount; + } + + const checkboxes = this.accordion.querySelectorAll('.cnet-unit-enabled input'); + const span = this.accordion.querySelector('.label-wrap span'); + + // Remove existing badge. + if (span.childNodes.length !== 1) { + span.removeChild(span.lastChild); + } + // Add new badge if necessary. + const activeUnitCount = getActiveUnitCount(checkboxes); + if (activeUnitCount > 0) { + const div = document.createElement('div'); + div.classList.add('cnet-badge'); + div.classList.add('primary'); + div.innerHTML = `${activeUnitCount} unit${activeUnitCount > 1 ? 's' : ''}`; + span.appendChild(div); + } + } + + /** + * Add the active control type to tab displayed text. + */ + updateActiveControlType() { + const tabNavButton = this.getTabNavButton(); + if (!tabNavButton) return; + + // Remove the control if exists + const controlTypeSuffix = tabNavButton.querySelector('.control-type-suffix'); + if (controlTypeSuffix) controlTypeSuffix.remove(); + + // Add new suffix. + const controlType = this.getActiveControlType(); + if (controlType === 'All') return; + + const span = document.createElement('span'); + span.innerHTML = `[${controlType}]`; + span.classList.add('control-type-suffix'); + tabNavButton.appendChild(span); + } + + /** + * When 'Inpaint' control type is selected in img2img: + * - Make image input disabled + * - Clear existing image input + */ + updateImageInputState() { + if (!this.isImg2Img) return; + + const tabNavButton = this.getTabNavButton(); + if (!tabNavButton) return; + + const controlType = this.getActiveControlType(); + if (controlType.toLowerCase() === 'inpaint') { + this.inputImage.disabled = true; + this.inputImage.parentNode.addEventListener('click', imageInputDisabledAlert); + const removeButton = this.tab.querySelector( + '.cnet-input-image-group .cnet-image button[aria-label="Remove Image"]'); + if (removeButton) removeButton.click(); + } else { + this.inputImage.disabled = false; + this.inputImage.parentNode.removeEventListener('click', imageInputDisabledAlert); + } + } + + attachEnabledButtonListener() { + this.enabledCheckbox.addEventListener('change', () => { + this.updateActiveState(); + this.updateActiveUnitCount(); + }); + } + + attachControlTypeRadioListener() { + for (const radio of this.controlTypeRadios) { + radio.addEventListener('change', () => { + this.updateActiveControlType(); + }); + } + } + + /** + * Each time the active tab change, all tab nav buttons are cleared and + * regenerated by gradio. So we need to reapply the active states on + * them. + */ + attachTabNavChangeObserver() { + new MutationObserver((mutationsList) => { + for (const mutation of mutationsList) { + if (mutation.type === 'childList') { + this.updateActiveState(); + this.updateActiveControlType(); + } + } + }).observe(this.tabNav, { childList: true }); + } + + attachImageUploadListener() { + // Automatically check `enable` checkbox when image is uploaded. + this.inputImage.addEventListener('change', (event) => { + if (!event.target.files) return; + if (!this.enabledCheckbox.checked) + this.enabledCheckbox.click(); + }); + + // Automatically check `enable` checkbox when JSON pose file is uploaded. + this.tab.querySelector('.cnet-upload-pose input').addEventListener('change', (event) => { + if (!event.target.files) return; + if (!this.enabledCheckbox.checked) + this.enabledCheckbox.click(); + }); + } + + attachImageStateChangeObserver() { + new MutationObserver((mutationsList) => { + const changeObserved = imgChangeObserved(mutationsList); + + if (changeObserved === ImgChangeType.ADD) { + // enabling the run preprocessor button + this.runPreprocessorButton.removeAttribute("disabled"); + this.runPreprocessorButton.title = 'Run preprocessor'; + } + + if (changeObserved === ImgChangeType.REMOVE) { + // disabling the run preprocessor button + this.runPreprocessorButton.setAttribute("disabled", true); + this.runPreprocessorButton.title = "No ControlNet input image available"; + } + }).observe(this.inputImageContainer, { + childList: true, + subtree: true, + }); + } + + /** + * Observe send PNG info buttons in A1111, as they can also directly + * set states of ControlNetUnit. + */ + attachA1111SendInfoObserver() { + const pasteButtons = gradioApp().querySelectorAll('#paste'); + const pngButtons = gradioApp().querySelectorAll( + this.isImg2Img ? + '#img2img_tab, #inpaint_tab' : + '#txt2img_tab' + ); + + for (const button of [...pasteButtons, ...pngButtons]) { + button.addEventListener('click', () => { + // The paste/send img generation info feature goes + // though gradio, which is pretty slow. Ideally we should + // observe the event when gradio has done the job, but + // that is not an easy task. + // Here we just do a 2 second delay until the refresh. + setTimeout(() => { + this.updateActiveState(); + this.updateActiveUnitCount(); + }, 2000); + }); + } + } + + attachPresetDropdownObserver() { + const presetDropDown = this.tab.querySelector('.cnet-preset-dropdown'); + + new MutationObserver((mutationsList) => { + for (const mutation of mutationsList) { + if (mutation.removedNodes.length > 0) { + setTimeout(() => { + this.updateActiveState(); + this.updateActiveUnitCount(); + this.updateActiveControlType(); + }, 1000); + return; + } + } + }).observe(presetDropDown, { + childList: true, + subtree: true, + }); + } + } + + gradioApp().querySelectorAll('#controlnet').forEach(accordion => { + if (cnetAllAccordions.has(accordion)) return; + accordion.querySelectorAll('.cnet-unit-tab') + .forEach(tab => new ControlNetUnitTab(tab, accordion)); + cnetAllAccordions.add(accordion); + }); + }); +})(); \ No newline at end of file diff --git a/extensions-builtin/sd_forge_controlnet/javascript/canvas.js b/extensions-builtin/sd_forge_controlnet/javascript/canvas.js new file mode 100644 index 00000000..a122c9fa --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/javascript/canvas.js @@ -0,0 +1,17 @@ +(function () { + var hasApplied = false; + onUiUpdate(function () { + if (!hasApplied) { + if (typeof window.applyZoomAndPanIntegration === "function") { + hasApplied = true; + window.applyZoomAndPanIntegration("#txt2img_controlnet", Array.from({ length: 20 }, (_, i) => `#txt2img_controlnet_ControlNet-${i}_input_image`)); + window.applyZoomAndPanIntegration("#img2img_controlnet", Array.from({ length: 20 }, (_, i) => `#img2img_controlnet_ControlNet-${i}_input_image`)); + window.applyZoomAndPanIntegration("#txt2img_controlnet", ["#txt2img_controlnet_ControlNet_input_image"]); + window.applyZoomAndPanIntegration("#img2img_controlnet", ["#img2img_controlnet_ControlNet_input_image"]); + //console.log("window.applyZoomAndPanIntegration applied."); + } else { + //console.log("window.applyZoomAndPanIntegration is not available."); + } + } + }); +})(); diff --git a/extensions-builtin/sd_forge_controlnet/javascript/modal.js b/extensions-builtin/sd_forge_controlnet/javascript/modal.js new file mode 100644 index 00000000..dc6190de --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/javascript/modal.js @@ -0,0 +1,33 @@ +(function () { + const cnetModalRegisteredElements = new Set(); + onUiUpdate(() => { + // Get all the buttons that open a modal + const btns = gradioApp().querySelectorAll(".cnet-modal-open"); + + // Get all the elements that close a modal + const spans = document.querySelectorAll(".cnet-modal-close"); + + // For each button, add a click event listener that opens the corresponding modal + btns.forEach((btn) => { + if (cnetModalRegisteredElements.has(btn)) return; + cnetModalRegisteredElements.add(btn); + + const modalId = btn.id.replace('cnet-modal-open-', ''); + const modal = document.getElementById("cnet-modal-" + modalId); + btn.addEventListener('click', () => { + modal.style.display = "block"; + }); + }); + + // For each element, add a click event listener that closes the corresponding modal + spans.forEach((span) => { + if (cnetModalRegisteredElements.has(span)) return; + cnetModalRegisteredElements.add(span); + + const modal = span.parentNode; + span.addEventListener('click', () => { + modal.style.display = "none"; + }); + }); + }); +})(); diff --git a/extensions-builtin/sd_forge_controlnet/javascript/openpose_editor.js b/extensions-builtin/sd_forge_controlnet/javascript/openpose_editor.js new file mode 100644 index 00000000..1c7b570a --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/javascript/openpose_editor.js @@ -0,0 +1,152 @@ +(function () { + async function checkEditorAvailable() { + const LOCAL_EDITOR_PATH = '/openpose_editor_index'; + const REMOTE_EDITOR_PATH = 'https://huchenlei.github.io/sd-webui-openpose-editor/'; + + async function testEditorPath(path) { + const res = await fetch(path); + return res.status === 200 ? path : null; + } + + // Use local editor if the user has the extension installed. Fallback + // onto remote editor if the local editor is not ready yet. + // See https://github.com/huchenlei/sd-webui-openpose-editor/issues/53 + // for more details. + return await testEditorPath(LOCAL_EDITOR_PATH) || await testEditorPath(REMOTE_EDITOR_PATH); + } + + const cnetOpenposeEditorRegisteredElements = new Set(); + let editorURL = null; + function loadOpenposeEditor() { + // Simulate an `input` DOM event for Gradio Textbox component. Needed after you edit its contents in javascript, otherwise your edits + // will only visible on web page and not sent to python. + function updateInput(target) { + let e = new Event("input", { bubbles: true }) + Object.defineProperty(e, "target", { value: target }) + target.dispatchEvent(e); + } + + function navigateIframe(iframe, editorURL) { + function getPathname(rawURL) { + try { + return new URL(rawURL).pathname; + } catch (e) { + return rawURL; + } + } + + return new Promise((resolve) => { + const darkThemeParam = document.body.classList.contains('dark') ? + new URLSearchParams({ theme: 'dark' }).toString() : + ''; + + window.addEventListener('message', (event) => { + const message = event.data; + if (message['ready']) resolve(); + }, { once: true }); + + if ((editorURL.startsWith("http") ? iframe.src : getPathname(iframe.src)) !== editorURL) { + iframe.src = `${editorURL}?${darkThemeParam}`; + // By default assume 5 second is enough for the openpose editor + // to load. + setTimeout(resolve, 5000); + } else { + // If no navigation is required, immediately return. + resolve(); + } + }); + } + const tabs = gradioApp().querySelectorAll('.cnet-unit-tab'); + tabs.forEach(tab => { + if (cnetOpenposeEditorRegisteredElements.has(tab)) return; + cnetOpenposeEditorRegisteredElements.add(tab); + + const generatedImageGroup = tab.querySelector('.cnet-generated-image-group'); + const editButton = generatedImageGroup.querySelector('.cnet-edit-pose'); + + editButton.addEventListener('click', async () => { + const inputImageGroup = tab.querySelector('.cnet-input-image-group'); + const inputImage = inputImageGroup.querySelector('.cnet-image img'); + const downloadLink = generatedImageGroup.querySelector('.cnet-download-pose a'); + const modalId = editButton.id.replace('cnet-modal-open-', ''); + const modalIframe = generatedImageGroup.querySelector('.cnet-modal iframe'); + + if (!editorURL) { + editorURL = await checkEditorAvailable(); + if (!editorURL) { + alert("No openpose editor available.") + } + } + + await navigateIframe(modalIframe, editorURL); + modalIframe.contentWindow.postMessage({ + modalId, + imageURL: inputImage ? inputImage.src : undefined, + poseURL: downloadLink.href, + }, '*'); + // Focus the iframe so that the focus is no longer on the `Edit` button. + // Pressing space when the focus is on `Edit` button will trigger + // the click again to resend the frame message. + modalIframe.contentWindow.focus(); + }); + /* + * Writes the pose data URL to an link element on input image group. + * Click a hidden button to trigger a backend rendering of the pose JSON. + * + * The backend should: + * - Set the rendered pose image as preprocessor generated image. + */ + function updatePreviewPose(poseURL) { + const downloadLink = generatedImageGroup.querySelector('.cnet-download-pose a'); + const renderButton = generatedImageGroup.querySelector('.cnet-render-pose'); + const poseTextbox = generatedImageGroup.querySelector('.cnet-pose-json textarea'); + const allowPreviewCheckbox = tab.querySelector('.cnet-allow-preview input'); + + if (!allowPreviewCheckbox.checked) + allowPreviewCheckbox.click(); + + // Only set href when download link exists and needs an update. `downloadLink` + // can be null when user closes preview and click `Upload JSON` button again. + // https://github.com/Mikubill/sd-webui-controlnet/issues/2308 + if (downloadLink !== null) + downloadLink.href = poseURL; + + poseTextbox.value = poseURL; + updateInput(poseTextbox); + renderButton.click(); + } + + // Updates preview image when edit is done. + window.addEventListener('message', (event) => { + const message = event.data; + const modalId = editButton.id.replace('cnet-modal-open-', ''); + if (message.modalId !== modalId) return; + updatePreviewPose(message.poseURL); + + const closeModalButton = generatedImageGroup.querySelector('.cnet-modal .cnet-modal-close'); + closeModalButton.click(); + }); + + const inputImageGroup = tab.querySelector('.cnet-input-image-group'); + const uploadButton = inputImageGroup.querySelector('.cnet-upload-pose input'); + // Updates preview image when JSON file is uploaded. + uploadButton.addEventListener('change', (event) => { + const file = event.target.files[0]; + if (!file) + return; + + const reader = new FileReader(); + reader.onload = function (e) { + const contents = e.target.result; + const poseURL = `data:application/json;base64,${btoa(contents)}`; + updatePreviewPose(poseURL); + }; + reader.readAsText(file); + // Reset the file input value so that uploading the same file still triggers callback. + event.target.value = ''; + }); + }); + } + + onUiUpdate(loadOpenposeEditor); +})(); \ No newline at end of file diff --git a/extensions-builtin/sd_forge_controlnet/javascript/photopea.js b/extensions-builtin/sd_forge_controlnet/javascript/photopea.js new file mode 100644 index 00000000..d2b1ebc9 --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/javascript/photopea.js @@ -0,0 +1,435 @@ +(function () { + /* + MIT LICENSE + Copyright 2011 Jon Leighton + Permission is hereby granted, free of charge, to any person obtaining a copy of this software and + associated documentation files (the "Software"), to deal in the Software without restriction, + including without limitation the rights to use, copy, modify, merge, publish, distribute, + sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + The above copyright notice and this permission notice shall be included in all copies or substantial + portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + // From: https://gist.github.com/jonleighton/958841 + function base64ArrayBuffer(arrayBuffer) { + var base64 = '' + var encodings = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/' + + var bytes = new Uint8Array(arrayBuffer) + var byteLength = bytes.byteLength + var byteRemainder = byteLength % 3 + var mainLength = byteLength - byteRemainder + + var a, b, c, d + var chunk + + // Main loop deals with bytes in chunks of 3 + for (var i = 0; i < mainLength; i = i + 3) { + // Combine the three bytes into a single integer + chunk = (bytes[i] << 16) | (bytes[i + 1] << 8) | bytes[i + 2] + + // Use bitmasks to extract 6-bit segments from the triplet + a = (chunk & 16515072) >> 18 // 16515072 = (2^6 - 1) << 18 + b = (chunk & 258048) >> 12 // 258048 = (2^6 - 1) << 12 + c = (chunk & 4032) >> 6 // 4032 = (2^6 - 1) << 6 + d = chunk & 63 // 63 = 2^6 - 1 + + // Convert the raw binary segments to the appropriate ASCII encoding + base64 += encodings[a] + encodings[b] + encodings[c] + encodings[d] + } + + // Deal with the remaining bytes and padding + if (byteRemainder == 1) { + chunk = bytes[mainLength] + + a = (chunk & 252) >> 2 // 252 = (2^6 - 1) << 2 + + // Set the 4 least significant bits to zero + b = (chunk & 3) << 4 // 3 = 2^2 - 1 + + base64 += encodings[a] + encodings[b] + '==' + } else if (byteRemainder == 2) { + chunk = (bytes[mainLength] << 8) | bytes[mainLength + 1] + + a = (chunk & 64512) >> 10 // 64512 = (2^6 - 1) << 10 + b = (chunk & 1008) >> 4 // 1008 = (2^6 - 1) << 4 + + // Set the 2 least significant bits to zero + c = (chunk & 15) << 2 // 15 = 2^4 - 1 + + base64 += encodings[a] + encodings[b] + encodings[c] + '=' + } + + return base64 + } + + // Turn a base64 string into a blob. + // From https://gist.github.com/gauravmehla/7a7dfd87dd7d1b13697b6e894426615f + function b64toBlob(b64Data, contentType, sliceSize) { + var contentType = contentType || ''; + var sliceSize = sliceSize || 512; + var byteCharacters = atob(b64Data); + var byteArrays = []; + for (var offset = 0; offset < byteCharacters.length; offset += sliceSize) { + var slice = byteCharacters.slice(offset, offset + sliceSize); + var byteNumbers = new Array(slice.length); + for (var i = 0; i < slice.length; i++) { + byteNumbers[i] = slice.charCodeAt(i); + } + var byteArray = new Uint8Array(byteNumbers); + byteArrays.push(byteArray); + } + return new Blob(byteArrays, { type: contentType }); + } + + function createBlackImageBase64(width, height) { + // Create a canvas element + var canvas = document.createElement('canvas'); + canvas.width = width; + canvas.height = height; + + // Get the context of the canvas + var ctx = canvas.getContext('2d'); + + // Fill the canvas with black color + ctx.fillStyle = 'black'; + ctx.fillRect(0, 0, width, height); + + // Get the base64 encoded string + var base64Image = canvas.toDataURL('image/png'); + + return base64Image; + } + + // Functions to be called within photopea context. + // Start of photopea functions + function pasteImage(base64image) { + app.open(base64image, null, /* asSmart */ true); + app.echoToOE("success"); + } + + function setLayerNames(names) { + const layers = app.activeDocument.layers; + if (layers.length !== names.length) { + console.error("layer length does not match names length"); + echoToOE("error"); + return; + } + + for (let i = 0; i < names.length; i++) { + const layer = layers[i]; + layer.name = names[i]; + } + app.echoToOE("success"); + } + + function removeLayersWithNames(names) { + const layers = app.activeDocument.layers; + for (let i = 0; i < layers.length; i++) { + const layer = layers[i]; + if (names.includes(layer.name)) { + layer.remove(); + } + } + app.echoToOE("success"); + } + + function getAllLayerNames() { + const layers = app.activeDocument.layers; + const names = []; + for (let i = 0; i < layers.length; i++) { + const layer = layers[i]; + names.push(layer.name); + } + app.echoToOE(JSON.stringify(names)); + } + + // Hides all layers except the current one, outputs the whole image, then restores the previous + // layers state. + function exportSelectedLayerOnly(format, layerName) { + // Gets all layers recursively, including the ones inside folders. + function getAllArtLayers(document) { + let allArtLayers = []; + + for (let i = 0; i < document.layers.length; i++) { + const currentLayer = document.layers[i]; + allArtLayers.push(currentLayer); + if (currentLayer.typename === "LayerSet") { + allArtLayers = allArtLayers.concat(getAllArtLayers(currentLayer)); + } + } + return allArtLayers; + } + + function makeLayerVisible(layer) { + let currentLayer = layer; + while (currentLayer != app.activeDocument) { + currentLayer.visible = true; + if (currentLayer.parent.typename != 'Document') { + currentLayer = currentLayer.parent; + } else { + break; + } + } + } + + + const allLayers = getAllArtLayers(app.activeDocument); + // Make all layers except the currently selected one invisible, and store + // their initial state. + const layerStates = []; + for (let i = 0; i < allLayers.length; i++) { + const layer = allLayers[i]; + layerStates.push(layer.visible); + } + // Hide all layers to begin with + for (let i = 0; i < allLayers.length; i++) { + const layer = allLayers[i]; + layer.visible = false; + } + for (let i = 0; i < allLayers.length; i++) { + const layer = allLayers[i]; + const selected = layer.name === layerName; + if (selected) { + makeLayerVisible(layer); + } + } + app.activeDocument.saveToOE(format); + + for (let i = 0; i < allLayers.length; i++) { + const layer = allLayers[i]; + layer.visible = layerStates[i]; + } + } + + function hasActiveDocument() { + app.echoToOE(app.documents.length > 0 ? "true" : "false"); + } + // End of photopea functions + + const MESSAGE_END_ACK = "done"; + const MESSAGE_ERROR = "error"; + const PHOTOPEA_URL = "https://www.photopea.com/"; + class PhotopeaContext { + constructor(photopeaIframe) { + this.photopeaIframe = photopeaIframe; + this.timeout = 1000; + } + + navigateIframe() { + const iframe = this.photopeaIframe; + const editorURL = PHOTOPEA_URL; + + return new Promise(async (resolve) => { + if (iframe.src !== editorURL) { + iframe.src = editorURL; + // Stop waiting after 10s. + setTimeout(resolve, 10000); + + // Testing whether photopea is able to accept message. + while (true) { + try { + await this.invoke(hasActiveDocument); + break; + } catch (e) { + console.log("Keep waiting for photopea to accept message."); + } + } + this.timeout = 5000; // Restore to a longer timeout in normal messaging. + } + resolve(); + }); + } + + // From https://github.com/huchenlei/stable-diffusion-ps-pea/blob/main/src/Photopea.ts + postMessageToPhotopea(message) { + return new Promise((resolve, reject) => { + const responseDataPieces = []; + let hasError = false; + const photopeaMessageHandle = (event) => { + if (event.source !== this.photopeaIframe.contentWindow) { + return; + } + // Filter out the ping messages + if (typeof event.data === 'string' && event.data.includes('MSFAPI#')) { + return; + } + // Ignore "done" when no data has been received. The "done" can come from + // MSFAPI ping. + if (event.data === MESSAGE_END_ACK && responseDataPieces.length === 0) { + return; + } + if (event.data === MESSAGE_END_ACK) { + window.removeEventListener("message", photopeaMessageHandle); + if (hasError) { + reject('Photopea Error.'); + } else { + resolve(responseDataPieces.length === 1 ? responseDataPieces[0] : responseDataPieces); + } + } else if (event.data === MESSAGE_ERROR) { + responseDataPieces.push(event.data); + hasError = true; + } else { + responseDataPieces.push(event.data); + } + }; + + window.addEventListener("message", photopeaMessageHandle); + setTimeout(() => reject("Photopea message timeout"), this.timeout); + this.photopeaIframe.contentWindow.postMessage(message, "*"); + }); + } + + // From https://github.com/huchenlei/stable-diffusion-ps-pea/blob/main/src/Photopea.ts + async invoke(func, ...args) { + await this.navigateIframe(); + const message = `${func.toString()} ${func.name}(${args.map(arg => JSON.stringify(arg)).join(',')});`; + try { + return await this.postMessageToPhotopea(message); + } catch (e) { + throw `Failed to invoke ${func.name}. ${e}.`; + } + } + + /** + * Fetch detected maps from each ControlNet units. + * Create a new photopea document. + * Add those detected maps to the created document. + */ + async fetchFromControlNet(tabs) { + if (tabs.length === 0) return; + const isImg2Img = tabs[0].querySelector('.cnet-unit-enabled').id.includes('img2img'); + const generationType = isImg2Img ? 'img2img' : 'txt2img'; + const width = gradioApp().querySelector(`#${generationType}_width input[type=number]`).value; + const height = gradioApp().querySelector(`#${generationType}_height input[type=number]`).value; + + const layerNames = ["background"]; + await this.invoke(pasteImage, createBlackImageBase64(width, height)); + await new Promise(r => setTimeout(r, 200)); + for (const [i, tab] of tabs.entries()) { + const generatedImage = tab.querySelector('.cnet-generated-image-group .cnet-image img'); + if (!generatedImage) continue; + await this.invoke(pasteImage, generatedImage.src); + // Wait 200ms for pasting to fully complete so that we do not ended up with 2 separate + // documents. + await new Promise(r => setTimeout(r, 200)); + layerNames.push(`unit-${i}`); + } + await this.invoke(removeLayersWithNames, layerNames); + await this.invoke(setLayerNames, layerNames.reverse()); + } + + /** + * Send the images in the active photopea document back to each ControlNet units. + */ + async sendToControlNet(tabs) { + // Gradio's image widgets are inputs. To set the image in one, we set the image on the input and + // force it to refresh. + function setImageOnInput(imageInput, file) { + // Createa a data transfer element to set as the data in the input. + const dt = new DataTransfer(); + dt.items.add(file); + const list = dt.files; + + // Actually set the image in the image widget. + imageInput.files = list; + + // Foce the image widget to update with the new image, after setting its source files. + const event = new Event('change', { + 'bubbles': true, + "composed": true + }); + imageInput.dispatchEvent(event); + } + + function sendToControlNetUnit(b64Image, index) { + const tab = tabs[index]; + // Upload image to output image element. + const outputImage = tab.querySelector('.cnet-photopea-output'); + const outputImageUpload = outputImage.querySelector('input[type="file"]'); + setImageOnInput(outputImageUpload, new File([b64toBlob(b64Image, "image/png")], "photopea_output.png")); + + // Make sure `UsePreviewAsInput` checkbox is checked. + const checkbox = tab.querySelector('.cnet-preview-as-input input[type="checkbox"]'); + if (!checkbox.checked) { + checkbox.click(); + } + } + + const layerNames = + JSON.parse(await this.invoke(getAllLayerNames)) + .filter(name => /unit-\d+/.test(name)); + + for (const layerName of layerNames) { + const arrayBuffer = await this.invoke(exportSelectedLayerOnly, 'PNG', layerName); + const b64Image = base64ArrayBuffer(arrayBuffer); + const layerIndex = Number.parseInt(layerName.split('-')[1]); + sendToControlNetUnit(b64Image, layerIndex); + } + } + } + + let photopeaWarningShown = false; + + function firstTimeUserPrompt() { + if (opts.controlnet_photopea_warning){ + const photopeaPopupMsg = "you are about to connect to https://photopea.com\n" + + "- Click OK: proceed.\n" + + "- Click Cancel: abort.\n" + + "Photopea integration can be disabled in Settings > ControlNet > Disable photopea edit.\n" + + "This popup can be disabled in Settings > ControlNet > Photopea popup warning."; + if (photopeaWarningShown || confirm(photopeaPopupMsg)) photopeaWarningShown = true; + else return false; + } + return true; + } + + const cnetRegisteredAccordions = new Set(); + function loadPhotopea() { + function registerCallbacks(accordion) { + const photopeaMainTrigger = accordion.querySelector('.cnet-photopea-main-trigger'); + // Photopea edit feature disabled. + if (!photopeaMainTrigger) { + console.log("ControlNet photopea edit disabled."); + return; + } + + const closeModalButton = accordion.querySelector('.cnet-photopea-edit .cnet-modal-close'); + const tabs = accordion.querySelectorAll('.cnet-unit-tab'); + const photopeaIframe = accordion.querySelector('.photopea-iframe'); + const photopeaContext = new PhotopeaContext(photopeaIframe, tabs); + + tabs.forEach(tab => { + const photopeaChildTrigger = tab.querySelector('.cnet-photopea-child-trigger'); + photopeaChildTrigger.addEventListener('click', async () => { + if (!firstTimeUserPrompt()) return; + + photopeaMainTrigger.click(); + if (await photopeaContext.invoke(hasActiveDocument) === "false") { + await photopeaContext.fetchFromControlNet(tabs); + } + }); + }); + accordion.querySelector('.photopea-fetch').addEventListener('click', () => photopeaContext.fetchFromControlNet(tabs)); + accordion.querySelector('.photopea-send').addEventListener('click', () => { + photopeaContext.sendToControlNet(tabs) + closeModalButton.click(); + }); + } + + const accordions = gradioApp().querySelectorAll('#controlnet'); + accordions.forEach(accordion => { + if (cnetRegisteredAccordions.has(accordion)) return; + registerCallbacks(accordion); + cnetRegisteredAccordions.add(accordion); + }); + } + + onUiUpdate(loadPhotopea); +})(); \ No newline at end of file diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py new file mode 100644 index 00000000..f36ca34d --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py @@ -0,0 +1,1348 @@ +import json +import gradio as gr +import functools +from copy import copy +from typing import List, Optional, Union, Callable, Dict, Tuple, Literal +from dataclasses import dataclass +import numpy as np + +from lib_controlnet.utils import svg_preprocess, read_image +from lib_controlnet import ( + global_state, + external_code, +) +from lib_controlnet.logging import logger +from lib_controlnet.controlnet_ui.openpose_editor import OpenposeEditor +from lib_controlnet.controlnet_ui.preset import ControlNetPresetUI +from lib_controlnet.controlnet_ui.tool_button import ToolButton +from lib_controlnet.controlnet_ui.photopea import Photopea +from lib_controlnet.enums import InputMode +from modules import shared +from modules.ui_components import FormRow +from modules_forge.forge_util import HWC3 + + +@dataclass +class A1111Context: + """Contains all components from A1111.""" + + img2img_batch_input_dir: Optional[gr.components.IOComponent] = None + img2img_batch_output_dir: Optional[gr.components.IOComponent] = None + txt2img_submit_button: Optional[gr.components.IOComponent] = None + img2img_submit_button: Optional[gr.components.IOComponent] = None + + # Slider controls from A1111 WebUI. + txt2img_w_slider: Optional[gr.components.IOComponent] = None + txt2img_h_slider: Optional[gr.components.IOComponent] = None + img2img_w_slider: Optional[gr.components.IOComponent] = None + img2img_h_slider: Optional[gr.components.IOComponent] = None + + img2img_img2img_tab: Optional[gr.components.IOComponent] = None + img2img_img2img_sketch_tab: Optional[gr.components.IOComponent] = None + img2img_batch_tab: Optional[gr.components.IOComponent] = None + img2img_inpaint_tab: Optional[gr.components.IOComponent] = None + img2img_inpaint_sketch_tab: Optional[gr.components.IOComponent] = None + img2img_inpaint_upload_tab: Optional[gr.components.IOComponent] = None + + img2img_inpaint_area: Optional[gr.components.IOComponent] = None + # txt2img_enable_hr is only available for A1111 > 1.7.0. + txt2img_enable_hr: Optional[gr.components.IOComponent] = None + setting_sd_model_checkpoint: Optional[gr.components.IOComponent] = None + + @property + def img2img_inpaint_tabs(self) -> Tuple[gr.components.IOComponent]: + return ( + self.img2img_inpaint_tab, + self.img2img_inpaint_sketch_tab, + self.img2img_inpaint_upload_tab, + ) + + @property + def img2img_non_inpaint_tabs(self) -> List[gr.components.IOComponent]: + return ( + self.img2img_img2img_tab, + self.img2img_img2img_sketch_tab, + self.img2img_batch_tab, + ) + + @property + def ui_initialized(self) -> bool: + optional_components = { + # Optional components are only available after A1111 v1.7.0. + "img2img_img2img_tab": "img2img_img2img_tab", + "img2img_img2img_sketch_tab": "img2img_img2img_sketch_tab", + "img2img_batch_tab": "img2img_batch_tab", + "img2img_inpaint_tab": "img2img_inpaint_tab", + "img2img_inpaint_sketch_tab": "img2img_inpaint_sketch_tab", + "img2img_inpaint_upload_tab": "img2img_inpaint_upload_tab", + # SDNext does not have this field. Temporarily disable the callback on + # the checkpoint change until we find a way to register an event when + # all A1111 UI components are ready. + "setting_sd_model_checkpoint": "setting_sd_model_checkpoint", + } + return all( + c + for name, c in vars(self).items() + if name not in optional_components.values() + ) + + def set_component(self, component: gr.components.IOComponent): + id_mapping = { + "img2img_batch_input_dir": "img2img_batch_input_dir", + "img2img_batch_output_dir": "img2img_batch_output_dir", + "txt2img_generate": "txt2img_submit_button", + "img2img_generate": "img2img_submit_button", + "txt2img_width": "txt2img_w_slider", + "txt2img_height": "txt2img_h_slider", + "img2img_width": "img2img_w_slider", + "img2img_height": "img2img_h_slider", + "img2img_img2img_tab": "img2img_img2img_tab", + "img2img_img2img_sketch_tab": "img2img_img2img_sketch_tab", + "img2img_batch_tab": "img2img_batch_tab", + "img2img_inpaint_tab": "img2img_inpaint_tab", + "img2img_inpaint_sketch_tab": "img2img_inpaint_sketch_tab", + "img2img_inpaint_upload_tab": "img2img_inpaint_upload_tab", + "img2img_inpaint_full_res": "img2img_inpaint_area", + "txt2img_hr-checkbox": "txt2img_enable_hr", + # setting_sd_model_checkpoint is expected to be initialized last. + # "setting_sd_model_checkpoint": "setting_sd_model_checkpoint", + } + elem_id = getattr(component, "elem_id", None) + # Do not set component if it has already been set. + # https://github.com/Mikubill/sd-webui-controlnet/issues/2587 + if elem_id in id_mapping and getattr(self, id_mapping[elem_id]) is None: + setattr(self, id_mapping[elem_id], component) + logger.debug(f"Setting {elem_id}.") + logger.debug( + f"A1111 initialized {sum(c is not None for c in vars(self).values())}/{len(vars(self).keys())}." + ) + + +class UiControlNetUnit(external_code.ControlNetUnit): + """The data class that stores all states of a ControlNetUnit.""" + + def __init__( + self, + input_mode: InputMode = InputMode.SIMPLE, + batch_images: Optional[Union[str, List[external_code.InputImage]]] = None, + output_dir: str = "", + loopback: bool = False, + merge_gallery_files: List[ + Dict[Union[Literal["name"], Literal["data"]], str] + ] = [], + use_preview_as_input: bool = False, + generated_image: Optional[np.ndarray] = None, + mask_image: Optional[np.ndarray] = None, + enabled: bool = True, + module: Optional[str] = None, + model: Optional[str] = None, + weight: float = 1.0, + image: Optional[Dict[str, np.ndarray]] = None, + *args, + **kwargs, + ): + if use_preview_as_input and generated_image is not None: + input_image = generated_image + module = "none" + else: + input_image = image + + # Prefer uploaded mask_image over hand-drawn mask. + if input_image is not None and mask_image is not None: + assert isinstance(input_image, dict) + input_image["mask"] = mask_image + + if merge_gallery_files and input_mode == InputMode.MERGE: + input_image = [ + {"image": read_image(file["name"])} for file in merge_gallery_files + ] + + super().__init__(enabled, module, model, weight, input_image, *args, **kwargs) + self.is_ui = True + self.input_mode = input_mode + self.batch_images = batch_images + self.output_dir = output_dir + self.loopback = loopback + + def unfold_merged(self) -> List[external_code.ControlNetUnit]: + """Unfolds a merged unit to multiple units. Keeps the unit merged for + preprocessors that can accept multiple input images. + """ + if self.input_mode != InputMode.MERGE: + return [copy(self)] + + if self.accepts_multiple_inputs(): + self.input_mode = InputMode.SIMPLE + return [copy(self)] + + assert isinstance(self.image, list) + result = [] + for image in self.image: + unit = copy(self) + unit.image = image["image"] + unit.input_mode = InputMode.SIMPLE + unit.weight = self.weight / len(self.image) + result.append(unit) + return result + + +class ControlNetUiGroup(object): + refresh_symbol = "\U0001f504" # 🔄 + switch_values_symbol = "\U000021C5" # ⇅ + camera_symbol = "\U0001F4F7" # 📷 + reverse_symbol = "\U000021C4" # ⇄ + tossup_symbol = "\u2934" + trigger_symbol = "\U0001F4A5" # 💥 + open_symbol = "\U0001F4DD" # 📝 + + tooltips = { + "🔄": "Refresh", + "\u2934": "Send dimensions to stable diffusion", + "💥": "Run preprocessor", + "📝": "Open new canvas", + "📷": "Enable webcam", + "⇄": "Mirror webcam", + } + + global_batch_input_dir = gr.Textbox( + label="Controlnet input directory", + placeholder="Leave empty to use input directory", + **shared.hide_dirs, + elem_id="controlnet_batch_input_dir", + ) + a1111_context = A1111Context() + # All ControlNetUiGroup instances created. + all_ui_groups: List["ControlNetUiGroup"] = [] + + def __init__( + self, + is_img2img: bool, + default_unit: external_code.ControlNetUnit, + photopea: Optional[Photopea], + ): + # Whether callbacks have been registered. + self.callbacks_registered: bool = False + # Whether the render method on this object has been called. + self.ui_initialized: bool = False + + self.is_img2img = is_img2img + self.default_unit = default_unit + self.photopea = photopea + self.webcam_enabled = False + self.webcam_mirrored = False + + # Note: All gradio elements declared in `render` will be defined as member variable. + # Update counter to trigger a force update of UiControlNetUnit. + # This is useful when a field with no event subscriber available changes. + # e.g. gr.Gallery, gr.State, etc. + self.update_unit_counter = None + self.upload_tab = None + self.image = None + self.generated_image_group = None + self.generated_image = None + self.mask_image_group = None + self.mask_image = None + self.batch_tab = None + self.batch_image_dir = None + self.merge_tab = None + self.merge_gallery = None + self.merge_upload_button = None + self.merge_clear_button = None + self.create_canvas = None + self.canvas_width = None + self.canvas_height = None + self.canvas_create_button = None + self.canvas_cancel_button = None + self.open_new_canvas_button = None + self.webcam_enable = None + self.webcam_mirror = None + self.send_dimen_button = None + self.enabled = None + self.low_vram = None + self.pixel_perfect = None + self.preprocessor_preview = None + self.mask_upload = None + self.type_filter = None + self.module = None + self.trigger_preprocessor = None + self.model = None + self.refresh_models = None + self.weight = None + self.guidance_start = None + self.guidance_end = None + self.advanced = None + self.processor_res = None + self.threshold_a = None + self.threshold_b = None + self.control_mode = None + self.resize_mode = None + self.loopback = None + self.use_preview_as_input = None + self.openpose_editor = None + self.preset_panel = None + self.upload_independent_img_in_img2img = None + self.image_upload_panel = None + self.save_detected_map = None + self.input_mode = gr.State(InputMode.SIMPLE) + self.inpaint_crop_input_image = None + self.hr_option = None + self.batch_image_dir_state = None + self.output_dir_state = None + + # Internal states for UI state pasting. + self.prevent_next_n_module_update = 0 + self.prevent_next_n_slider_value_update = 0 + + # API-only fields + self.advanced_weighting = gr.State(None) + + ControlNetUiGroup.all_ui_groups.append(self) + + def render(self, tabname: str, elem_id_tabname: str) -> None: + """The pure HTML structure of a single ControlNetUnit. Calling this + function will populate `self` with all gradio element declared + in local scope. + + Args: + tabname: + elem_id_tabname: + + Returns: + None + """ + self.update_unit_counter = gr.Number(value=0, visible=False) + self.openpose_editor = OpenposeEditor() + + with gr.Group(visible=not self.is_img2img) as self.image_upload_panel: + self.save_detected_map = gr.Checkbox(value=True, visible=False) + with gr.Tabs(): + with gr.Tab(label="Single Image") as self.upload_tab: + with gr.Row(elem_classes=["cnet-image-row"], equal_height=True): + with gr.Group(elem_classes=["cnet-input-image-group"]): + self.image = gr.Image( + source="upload", + brush_radius=20, + mirror_webcam=False, + type="numpy", + tool="sketch", + elem_id=f"{elem_id_tabname}_{tabname}_input_image", + elem_classes=["cnet-image"], + brush_color=shared.opts.img2img_inpaint_mask_brush_color + if hasattr( + shared.opts, "img2img_inpaint_mask_brush_color" + ) + else None, + ) + self.image.preprocess = functools.partial( + svg_preprocess, preprocess=self.image.preprocess + ) + self.openpose_editor.render_upload() + + with gr.Group( + visible=False, elem_classes=["cnet-generated-image-group"] + ) as self.generated_image_group: + self.generated_image = gr.Image( + value=None, + label="Preprocessor Preview", + elem_id=f"{elem_id_tabname}_{tabname}_generated_image", + elem_classes=["cnet-image"], + interactive=True, + height=242, + ) # Gradio's magic number. Only 242 works. + + with gr.Group( + elem_classes=["cnet-generated-image-control-group"] + ): + if self.photopea: + self.photopea.render_child_trigger() + self.openpose_editor.render_edit() + preview_check_elem_id = f"{elem_id_tabname}_{tabname}_controlnet_preprocessor_preview_checkbox" + preview_close_button_js = f"document.querySelector('#{preview_check_elem_id} input[type=\\'checkbox\\']').click();" + gr.HTML( + value=f"""Close""", + visible=True, + elem_classes=["cnet-close-preview"], + ) + + with gr.Group( + visible=False, elem_classes=["cnet-mask-image-group"] + ) as self.mask_image_group: + self.mask_image = gr.Image( + value=None, + label="Upload Mask", + elem_id=f"{elem_id_tabname}_{tabname}_mask_image", + elem_classes=["cnet-mask-image"], + interactive=True, + ) + + with gr.Tab(label="Batch") as self.batch_tab: + self.batch_image_dir = gr.Textbox( + label="Input Directory", + placeholder="Leave empty to use img2img batch controlnet input directory", + elem_id=f"{elem_id_tabname}_{tabname}_batch_image_dir", + ) + + with gr.Tab(label="Multi-Inputs") as self.merge_tab: + self.merge_gallery = gr.Gallery( + columns=[4], rows=[2], object_fit="contain", height="auto" + ) + with gr.Row(): + self.merge_upload_button = gr.UploadButton( + "Upload Images", + file_types=["image"], + file_count="multiple", + ) + self.merge_clear_button = gr.Button("Clear Images") + + if self.photopea: + self.photopea.attach_photopea_output(self.generated_image) + + with gr.Accordion( + label="Open New Canvas", visible=False + ) as self.create_canvas: + self.canvas_width = gr.Slider( + label="New Canvas Width", + minimum=256, + maximum=1024, + value=512, + step=64, + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_canvas_width", + ) + self.canvas_height = gr.Slider( + label="New Canvas Height", + minimum=256, + maximum=1024, + value=512, + step=64, + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_canvas_height", + ) + with gr.Row(): + self.canvas_create_button = gr.Button( + value="Create New Canvas", + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_canvas_create_button", + ) + self.canvas_cancel_button = gr.Button( + value="Cancel", + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_canvas_cancel_button", + ) + + with gr.Row(elem_classes="controlnet_image_controls"): + gr.HTML( + value="

Set the preprocessor to [invert] If your image has white background and black lines.

", + elem_classes="controlnet_invert_warning", + ) + self.open_new_canvas_button = ToolButton( + value=ControlNetUiGroup.open_symbol, + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_open_new_canvas_button", + tooltip=ControlNetUiGroup.tooltips[ControlNetUiGroup.open_symbol], + ) + self.webcam_enable = ToolButton( + value=ControlNetUiGroup.camera_symbol, + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_webcam_enable", + tooltip=ControlNetUiGroup.tooltips[ControlNetUiGroup.camera_symbol], + ) + self.webcam_mirror = ToolButton( + value=ControlNetUiGroup.reverse_symbol, + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_webcam_mirror", + tooltip=ControlNetUiGroup.tooltips[ + ControlNetUiGroup.reverse_symbol + ], + ) + self.send_dimen_button = ToolButton( + value=ControlNetUiGroup.tossup_symbol, + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_send_dimen_button", + tooltip=ControlNetUiGroup.tooltips[ControlNetUiGroup.tossup_symbol], + ) + + with FormRow(elem_classes=["controlnet_main_options"]): + self.enabled = gr.Checkbox( + label="Enable", + value=self.default_unit.enabled, + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_enable_checkbox", + elem_classes=["cnet-unit-enabled"], + ) + self.low_vram = gr.Checkbox( + label="Low VRAM", + value=self.default_unit.low_vram, + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_low_vram_checkbox", + visible=False, # Not needed now + ) + self.pixel_perfect = gr.Checkbox( + label="Pixel Perfect", + value=self.default_unit.pixel_perfect, + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_pixel_perfect_checkbox", + ) + self.preprocessor_preview = gr.Checkbox( + label="Allow Preview", + value=False, + elem_classes=["cnet-allow-preview"], + elem_id=preview_check_elem_id, + visible=not self.is_img2img, + ) + self.mask_upload = gr.Checkbox( + label="Mask Upload", + value=False, + elem_classes=["cnet-mask-upload"], + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_mask_upload_checkbox", + visible=not self.is_img2img, + ) + self.use_preview_as_input = gr.Checkbox( + label="Preview as Input", + value=False, + elem_classes=["cnet-preview-as-input"], + visible=False, + ) + + with gr.Row(elem_classes="controlnet_img2img_options"): + if self.is_img2img: + self.upload_independent_img_in_img2img = gr.Checkbox( + label="Upload independent control image", + value=False, + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_same_img2img_checkbox", + elem_classes=["cnet-unit-same_img2img"], + ) + else: + self.upload_independent_img_in_img2img = None + + # Note: The checkbox needs to exist for both img2img and txt2img as infotext + # needs the checkbox value. + self.inpaint_crop_input_image = gr.Checkbox( + label="Crop input image based on A1111 mask", + value=False, + elem_classes=["cnet-crop-input-image"], + visible=False, + ) + + with gr.Row(elem_classes=["controlnet_control_type", "controlnet_row"]): + self.type_filter = gr.Radio( + global_state.get_all_preprocessor_tags(), + label=f"Control Type", + value="All", + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_type_filter_radio", + elem_classes="controlnet_control_type_filter_group", + ) + + with gr.Row(elem_classes=["controlnet_preprocessor_model", "controlnet_row"]): + self.module = gr.Dropdown( + global_state.get_all_preprocessor_names(), + label=f"Preprocessor", + value=self.default_unit.module, + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_preprocessor_dropdown", + ) + self.trigger_preprocessor = ToolButton( + value=ControlNetUiGroup.trigger_symbol, + visible=not self.is_img2img, + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_trigger_preprocessor", + elem_classes=["cnet-run-preprocessor"], + tooltip=ControlNetUiGroup.tooltips[ControlNetUiGroup.trigger_symbol], + ) + self.model = gr.Dropdown( + global_state.get_all_controlnet_names(), + label=f"Model", + value=self.default_unit.model, + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_model_dropdown", + ) + self.refresh_models = ToolButton( + value=ControlNetUiGroup.refresh_symbol, + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_refresh_models", + tooltip=ControlNetUiGroup.tooltips[ControlNetUiGroup.refresh_symbol], + ) + + with gr.Row(elem_classes=["controlnet_weight_steps", "controlnet_row"]): + self.weight = gr.Slider( + label=f"Control Weight", + value=self.default_unit.weight, + minimum=0.0, + maximum=2.0, + step=0.05, + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_control_weight_slider", + elem_classes="controlnet_control_weight_slider", + ) + self.guidance_start = gr.Slider( + label="Starting Control Step", + value=self.default_unit.guidance_start, + minimum=0.0, + maximum=1.0, + interactive=True, + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_start_control_step_slider", + elem_classes="controlnet_start_control_step_slider", + ) + self.guidance_end = gr.Slider( + label="Ending Control Step", + value=self.default_unit.guidance_end, + minimum=0.0, + maximum=1.0, + interactive=True, + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_ending_control_step_slider", + elem_classes="controlnet_ending_control_step_slider", + ) + + # advanced options + with gr.Column(visible=False) as self.advanced: + self.processor_res = gr.Slider( + label="Preprocessor resolution", + value=self.default_unit.processor_res, + minimum=64, + maximum=2048, + visible=False, + interactive=True, + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_preprocessor_resolution_slider", + ) + self.threshold_a = gr.Slider( + label="Threshold A", + value=self.default_unit.threshold_a, + minimum=64, + maximum=1024, + visible=False, + interactive=True, + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_threshold_A_slider", + ) + self.threshold_b = gr.Slider( + label="Threshold B", + value=self.default_unit.threshold_b, + minimum=64, + maximum=1024, + visible=False, + interactive=True, + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_threshold_B_slider", + ) + + self.control_mode = gr.Radio( + choices=[e.value for e in external_code.ControlMode], + value=self.default_unit.control_mode.value, + label="Control Mode", + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_control_mode_radio", + elem_classes="controlnet_control_mode_radio", + ) + + self.resize_mode = gr.Radio( + choices=[e.value for e in external_code.ResizeMode], + value=self.default_unit.resize_mode.value, + label="Resize Mode", + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_resize_mode_radio", + elem_classes="controlnet_resize_mode_radio", + visible=not self.is_img2img, + ) + + self.hr_option = gr.Radio( + choices=[e.value for e in external_code.HiResFixOption], + value=self.default_unit.hr_option.value, + label="Hires-Fix Option", + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_hr_option_radio", + elem_classes="controlnet_hr_option_radio", + visible=False, + ) + + self.loopback = gr.Checkbox( + label="[Batch Loopback] Automatically send generated images to this ControlNet unit in batch generation", + value=self.default_unit.loopback, + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_automatically_send_generated_images_checkbox", + elem_classes="controlnet_loopback_checkbox", + visible=False, + ) + + self.preset_panel = ControlNetPresetUI( + id_prefix=f"{elem_id_tabname}_{tabname}_" + ) + + self.batch_image_dir_state = gr.State("") + self.output_dir_state = gr.State("") + unit_args = ( + self.input_mode, + self.batch_image_dir_state, + self.output_dir_state, + self.loopback, + # Non-persistent fields. + # Following inputs will not be persistent on `ControlNetUnit`. + # They are only used during object construction. + self.merge_gallery, + self.use_preview_as_input, + self.generated_image, + self.mask_image, + # End of Non-persistent fields. + self.enabled, + self.module, + self.model, + self.weight, + self.image, + self.resize_mode, + self.low_vram, + self.processor_res, + self.threshold_a, + self.threshold_b, + self.guidance_start, + self.guidance_end, + self.pixel_perfect, + self.control_mode, + self.inpaint_crop_input_image, + self.hr_option, + ) + + unit = gr.State(self.default_unit) + for comp in unit_args + (self.update_unit_counter,): + event_subscribers = [] + if hasattr(comp, "edit"): + event_subscribers.append(comp.edit) + elif hasattr(comp, "click"): + event_subscribers.append(comp.click) + elif isinstance(comp, gr.Slider) and hasattr(comp, "release"): + event_subscribers.append(comp.release) + elif hasattr(comp, "change"): + event_subscribers.append(comp.change) + + if hasattr(comp, "clear"): + event_subscribers.append(comp.clear) + + for event_subscriber in event_subscribers: + event_subscriber( + fn=UiControlNetUnit, inputs=list(unit_args), outputs=unit + ) + + ( + ControlNetUiGroup.a1111_context.img2img_submit_button + if self.is_img2img + else ControlNetUiGroup.a1111_context.txt2img_submit_button + ).click( + fn=UiControlNetUnit, + inputs=list(unit_args), + outputs=unit, + queue=False, + ) + self.register_core_callbacks() + self.ui_initialized = True + return unit + + def register_send_dimensions(self): + """Register event handler for send dimension button.""" + + def send_dimensions(image): + def closesteight(num): + rem = num % 8 + if rem <= 4: + return round(num - rem) + else: + return round(num + (8 - rem)) + + if image: + interm = np.asarray(image.get("image")) + return closesteight(interm.shape[1]), closesteight(interm.shape[0]) + else: + return gr.Slider.update(), gr.Slider.update() + + outputs = ( + [ + ControlNetUiGroup.a1111_context.img2img_w_slider, + ControlNetUiGroup.a1111_context.img2img_h_slider, + ] + if self.is_img2img + else [ + ControlNetUiGroup.a1111_context.txt2img_w_slider, + ControlNetUiGroup.a1111_context.txt2img_h_slider, + ] + ) + self.send_dimen_button.click( + fn=send_dimensions, + inputs=[self.image], + outputs=outputs, + show_progress=False, + ) + + def register_webcam_toggle(self): + def webcam_toggle(): + self.webcam_enabled = not self.webcam_enabled + return { + "value": None, + "source": "webcam" if self.webcam_enabled else "upload", + "__type__": "update", + } + + self.webcam_enable.click( + webcam_toggle, inputs=None, outputs=self.image, show_progress=False + ) + + def register_webcam_mirror_toggle(self): + def webcam_mirror_toggle(): + self.webcam_mirrored = not self.webcam_mirrored + return {"mirror_webcam": self.webcam_mirrored, "__type__": "update"} + + self.webcam_mirror.click( + webcam_mirror_toggle, inputs=None, outputs=self.image, show_progress=False + ) + + def register_refresh_all_models(self): + def refresh_all_models(): + global_state.update_controlnet_filenames() + return gr.Dropdown.update( + choices=global_state.get_all_controlnet_names(), + ) + + self.refresh_models.click( + refresh_all_models, + outputs=[self.model], + show_progress=False, + ) + + def register_build_sliders(self): + def build_sliders(module: str, pp: bool): + + logger.debug( + f"Prevent update slider value: {self.prevent_next_n_slider_value_update}" + ) + logger.debug(f"Build slider for module: {module} - {pp}") + + preprocessor = global_state.get_preprocessor(module) + + slider_resolution_kwargs = preprocessor.slider_resolution.gradio_update_kwargs.copy() + + if pp: + slider_resolution_kwargs['visible'] = False + + grs = [ + gr.update(**slider_resolution_kwargs), + gr.update(**preprocessor.slider_1.gradio_update_kwargs.copy()), + gr.update(**preprocessor.slider_2.gradio_update_kwargs.copy()), + gr.update(visible=True), + gr.update(visible=not preprocessor.do_not_need_model), + gr.update(visible=not preprocessor.do_not_need_model), + gr.update(visible=preprocessor.show_control_mode), + ] + + return grs + + inputs = [ + self.module, + self.pixel_perfect, + ] + outputs = [ + self.processor_res, + self.threshold_a, + self.threshold_b, + self.advanced, + self.model, + self.refresh_models, + self.control_mode, + ] + self.module.change( + build_sliders, inputs=inputs, outputs=outputs, show_progress=False + ) + self.pixel_perfect.change( + build_sliders, inputs=inputs, outputs=outputs, show_progress=False + ) + + def filter_selected(k: str): + logger.debug(f"Prevent update {self.prevent_next_n_module_update}") + logger.debug(f"Switch to control type {k}") + + filtered_preprocessor_list = global_state.get_filtered_preprocessor_names(k) + filtered_controlnet_names = global_state.get_filtered_controlnet_names(k) + default_preprocessor = filtered_preprocessor_list[0] + default_controlnet_name = filtered_controlnet_names[0] + + if k != 'All': + if len(filtered_preprocessor_list) > 1: + default_preprocessor = filtered_preprocessor_list[1] + if len(filtered_controlnet_names) > 1: + default_controlnet_name = filtered_controlnet_names[1] + + if self.prevent_next_n_module_update > 0: + self.prevent_next_n_module_update -= 1 + return [ + gr.Dropdown.update(choices=filtered_preprocessor_list), + gr.Dropdown.update(choices=filtered_controlnet_names), + ] + else: + return [ + gr.Dropdown.update( + value=default_preprocessor, choices=filtered_preprocessor_list + ), + gr.Dropdown.update( + value=default_controlnet_name, choices=filtered_controlnet_names + ), + ] + + self.type_filter.change( + fn=filter_selected, + inputs=[self.type_filter], + outputs=[self.module, self.model], + show_progress=False, + ) + + def register_run_annotator(self): + def run_annotator(image, module, pres, pthr_a, pthr_b, t2i_w, t2i_h, pp, rm): + if image is None: + return ( + gr.update(value=None, visible=True), + gr.update(), + *self.openpose_editor.update(""), + ) + + img = HWC3(image["image"]) + has_mask = not ( + (image["mask"][:, :, 0] <= 5).all() + or (image["mask"][:, :, 0] >= 250).all() + ) + if "inpaint" in module: + color = HWC3(image["image"]) + alpha = image["mask"][:, :, 0:1] + img = np.concatenate([color, alpha], axis=2) + elif has_mask and not shared.opts.data.get( + "controlnet_ignore_noninpaint_mask", False + ): + img = HWC3(image["mask"][:, :, 0]) + + preprocessor = global_state.get_preprocessor(module) + + if pp: + pres = external_code.pixel_perfect_resolution( + img, + target_H=t2i_h, + target_W=t2i_w, + resize_mode=external_code.resize_mode_from_value(rm), + ) + + class JsonAcceptor: + def __init__(self) -> None: + self.value = "" + + def accept(self, json_dict: dict) -> None: + self.value = json.dumps(json_dict) + + json_acceptor = JsonAcceptor() + + logger.info(f"Preview Resolution = {pres}") + + def is_openpose(module: str): + return "openpose" in module + + # Only openpose preprocessor returns a JSON output, pass json_acceptor + # only when a JSON output is expected. This will make preprocessor cache + # work for all other preprocessors other than openpose ones. JSON acceptor + # instance are different every call, which means cache will never take + # effect. + # TODO: Maybe we should let `preprocessor` return a Dict to alleviate this issue? + # This requires changing all callsites though. + result = preprocessor( + input_image=img, + resolution=pres, + slider_1=pthr_a, + slider_2=pthr_b, + low_vram=( + ("clip" in module or module == "ip-adapter_face_id_plus") + and shared.opts.data.get("controlnet_clip_detector_on_cpu", False) + ), + json_pose_callback=json_acceptor.accept + if is_openpose(module) + else None, + ) + + if not isinstance(result, np.ndarray) and result.nidm == 3 and result.shape[2] < 5: + result = img + + result = external_code.visualize_inpaint_mask(result) + return ( + # Update to `generated_image` + gr.update(value=result, visible=True, interactive=False), + # preprocessor_preview + gr.update(value=True), + # openpose editor + *self.openpose_editor.update(json_acceptor.value), + ) + + self.trigger_preprocessor.click( + fn=run_annotator, + inputs=[ + self.image, + self.module, + self.processor_res, + self.threshold_a, + self.threshold_b, + ControlNetUiGroup.a1111_context.img2img_w_slider + if self.is_img2img + else ControlNetUiGroup.a1111_context.txt2img_w_slider, + ControlNetUiGroup.a1111_context.img2img_h_slider + if self.is_img2img + else ControlNetUiGroup.a1111_context.txt2img_h_slider, + self.pixel_perfect, + self.resize_mode, + ], + outputs=[ + self.generated_image, + self.preprocessor_preview, + *self.openpose_editor.outputs(), + ], + ) + + def register_shift_preview(self): + def shift_preview(is_on): + return ( + # generated_image + gr.update() if is_on else gr.update(value=None), + # generated_image_group + gr.update(visible=is_on), + # use_preview_as_input, + gr.update(visible=False), # Now this is automatically managed + # download_pose_link + gr.update() if is_on else gr.update(value=None), + # modal edit button + gr.update() if is_on else gr.update(visible=False), + ) + + self.preprocessor_preview.change( + fn=shift_preview, + inputs=[self.preprocessor_preview], + outputs=[ + self.generated_image, + self.generated_image_group, + self.use_preview_as_input, + self.openpose_editor.download_link, + self.openpose_editor.modal, + ], + show_progress=False, + ) + + def register_create_canvas(self): + self.open_new_canvas_button.click( + lambda: gr.Accordion.update(visible=True), + inputs=None, + outputs=self.create_canvas, + show_progress=False, + ) + self.canvas_cancel_button.click( + lambda: gr.Accordion.update(visible=False), + inputs=None, + outputs=self.create_canvas, + show_progress=False, + ) + + def fn_canvas(h, w): + return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255, gr.Accordion.update( + visible=False + ) + + self.canvas_create_button.click( + fn=fn_canvas, + inputs=[self.canvas_height, self.canvas_width], + outputs=[self.image, self.create_canvas], + show_progress=False, + ) + + def register_img2img_same_input(self): + def fn_same_checked(x): + return [ + gr.update(value=None), + gr.update(value=None), + gr.update(value=False, visible=x), + ] + [gr.update(visible=x)] * 4 + + self.upload_independent_img_in_img2img.change( + fn_same_checked, + inputs=self.upload_independent_img_in_img2img, + outputs=[ + self.image, + self.batch_image_dir, + self.preprocessor_preview, + self.image_upload_panel, + self.trigger_preprocessor, + self.loopback, + self.resize_mode, + ], + show_progress=False, + ) + + def register_shift_crop_input_image(self): + # A1111 < 1.7.0 compatibility. + if any(c is None for c in ControlNetUiGroup.a1111_context.img2img_inpaint_tabs): + self.inpaint_crop_input_image.visible = True + self.inpaint_crop_input_image.value = True + return + + is_inpaint_tab = gr.State(False) + + def shift_crop_input_image(is_inpaint: bool, inpaint_area: int): + # Note: inpaint_area (0: Whole picture, 1: Only masked) + # By default set value to True, as most preprocessors need cropped result. + return gr.update(value=True, visible=is_inpaint and inpaint_area == 1) + + gradio_kwargs = dict( + fn=shift_crop_input_image, + inputs=[ + is_inpaint_tab, + ControlNetUiGroup.a1111_context.img2img_inpaint_area, + ], + outputs=[self.inpaint_crop_input_image], + show_progress=False, + ) + + for elem in ControlNetUiGroup.a1111_context.img2img_inpaint_tabs: + elem.select(fn=lambda: True, inputs=[], outputs=[is_inpaint_tab]).then( + **gradio_kwargs + ) + + for elem in ControlNetUiGroup.a1111_context.img2img_non_inpaint_tabs: + elem.select(fn=lambda: False, inputs=[], outputs=[is_inpaint_tab]).then( + **gradio_kwargs + ) + + ControlNetUiGroup.a1111_context.img2img_inpaint_area.change(**gradio_kwargs) + + def register_shift_hr_options(self): + # A1111 version < 1.6.0. + if not ControlNetUiGroup.a1111_context.txt2img_enable_hr: + return + + ControlNetUiGroup.a1111_context.txt2img_enable_hr.change( + fn=lambda checked: gr.update(visible=checked), + inputs=[ControlNetUiGroup.a1111_context.txt2img_enable_hr], + outputs=[self.hr_option], + show_progress=False, + ) + + def register_shift_upload_mask(self): + """Controls whether the upload mask input should be visible.""" + self.mask_upload.change( + fn=lambda checked: ( + # Clear mask_image if unchecked. + (gr.update(visible=False), gr.update(value=None)) + if not checked + else (gr.update(visible=True), gr.update()) + ), + inputs=[self.mask_upload], + outputs=[self.mask_image_group, self.mask_image], + show_progress=False, + ) + + if self.upload_independent_img_in_img2img is not None: + self.upload_independent_img_in_img2img.change( + fn=lambda checked: ( + # Uncheck `upload_mask` when not using independent input. + gr.update(visible=False, value=False) + if not checked + else gr.update(visible=True) + ), + inputs=[self.upload_independent_img_in_img2img], + outputs=[self.mask_upload], + show_progress=False, + ) + + def register_sync_batch_dir(self): + def determine_batch_dir(batch_dir, fallback_dir, fallback_fallback_dir): + if batch_dir: + return batch_dir + elif fallback_dir: + return fallback_dir + else: + return fallback_fallback_dir + + batch_dirs = [ + self.batch_image_dir, + ControlNetUiGroup.global_batch_input_dir, + ControlNetUiGroup.a1111_context.img2img_batch_input_dir, + ] + for batch_dir_comp in batch_dirs: + subscriber = getattr(batch_dir_comp, "blur", None) + if subscriber is None: + continue + subscriber( + fn=determine_batch_dir, + inputs=batch_dirs, + outputs=[self.batch_image_dir_state], + queue=False, + ) + + ControlNetUiGroup.a1111_context.img2img_batch_output_dir.blur( + fn=lambda a: a, + inputs=[ControlNetUiGroup.a1111_context.img2img_batch_output_dir], + outputs=[self.output_dir_state], + queue=False, + ) + + def register_clear_preview(self): + def clear_preview(x): + if x: + logger.info("Preview as input is cancelled.") + return gr.update(value=False), gr.update(value=None) + + for comp in ( + self.pixel_perfect, + self.module, + self.image, + self.processor_res, + self.threshold_a, + self.threshold_b, + self.upload_independent_img_in_img2img, + ): + event_subscribers = [] + if hasattr(comp, "edit"): + event_subscribers.append(comp.edit) + elif hasattr(comp, "click"): + event_subscribers.append(comp.click) + elif isinstance(comp, gr.Slider) and hasattr(comp, "release"): + event_subscribers.append(comp.release) + elif hasattr(comp, "change"): + event_subscribers.append(comp.change) + if hasattr(comp, "clear"): + event_subscribers.append(comp.clear) + for event_subscriber in event_subscribers: + event_subscriber( + fn=clear_preview, + inputs=self.use_preview_as_input, + outputs=[self.use_preview_as_input, self.generated_image], + show_progress=False + ) + + def register_multi_images_upload(self): + """Register callbacks on merge tab multiple images upload.""" + self.merge_clear_button.click( + fn=lambda: [], + inputs=[], + outputs=[self.merge_gallery], + ).then( + fn=lambda x: gr.update(value=x + 1), + inputs=[self.update_unit_counter], + outputs=[self.update_unit_counter], + ) + + def upload_file(files, current_files): + return {file_d["name"] for file_d in current_files} | { + file.name for file in files + } + + self.merge_upload_button.upload( + upload_file, + inputs=[self.merge_upload_button, self.merge_gallery], + outputs=[self.merge_gallery], + queue=False, + ).then( + fn=lambda x: gr.update(value=x + 1), + inputs=[self.update_unit_counter], + outputs=[self.update_unit_counter], + ) + + def register_core_callbacks(self): + """Register core callbacks that only involves gradio components defined + within this ui group.""" + self.register_webcam_toggle() + self.register_webcam_mirror_toggle() + self.register_refresh_all_models() + self.register_build_sliders() + self.register_shift_preview() + self.register_shift_upload_mask() + self.register_create_canvas() + self.register_clear_preview() + self.register_multi_images_upload() + self.openpose_editor.register_callbacks( + self.generated_image, + self.use_preview_as_input, + self.model, + ) + assert self.type_filter is not None + self.preset_panel.register_callbacks( + self, + self.type_filter, + *[ + getattr(self, key) + for key in vars(external_code.ControlNetUnit()).keys() + ], + ) + if self.is_img2img: + self.register_img2img_same_input() + + def register_callbacks(self): + """Register callbacks that involves A1111 context gradio components.""" + # Prevent infinite recursion. + if self.callbacks_registered: + return + + self.callbacks_registered = True + self.register_send_dimensions() + self.register_run_annotator() + self.register_sync_batch_dir() + if self.is_img2img: + self.register_shift_crop_input_image() + else: + self.register_shift_hr_options() + + @staticmethod + def register_input_mode_sync(ui_groups: List["ControlNetUiGroup"]): + """ + - ui_group.input_mode should be updated when user switch tabs. + - Loopback checkbox should only be visible if at least one ControlNet unit + is set to batch mode. + + Argument: + ui_groups: All ControlNetUiGroup instances defined in current Script context. + + Returns: + None + """ + if not ui_groups: + return + + for ui_group in ui_groups: + batch_fn = lambda: InputMode.BATCH + simple_fn = lambda: InputMode.SIMPLE + merge_fn = lambda: InputMode.MERGE + for input_tab, fn in ( + (ui_group.upload_tab, simple_fn), + (ui_group.batch_tab, batch_fn), + (ui_group.merge_tab, merge_fn), + ): + # Sync input_mode. + input_tab.select( + fn=fn, + inputs=[], + outputs=[ui_group.input_mode], + show_progress=False, + ).then( + # Update visibility of loopback checkbox. + fn=lambda *mode_values: ( + ( + gr.update( + visible=any(m == InputMode.BATCH for m in mode_values) + ), + ) + * len(ui_groups) + ), + inputs=[g.input_mode for g in ui_groups], + outputs=[g.loopback for g in ui_groups], + show_progress=False, + ) + + @staticmethod + def reset(): + ControlNetUiGroup.a1111_context = A1111Context() + ControlNetUiGroup.all_ui_groups = [] + + @staticmethod + def try_register_all_callbacks(): + unit_count = shared.opts.data.get("control_net_unit_count", 3) + all_unit_count = unit_count * 2 # txt2img + img2img. + if ( + # All A1111 components ControlNet units care about are all registered. + ControlNetUiGroup.a1111_context.ui_initialized + and all_unit_count == len(ControlNetUiGroup.all_ui_groups) + and all( + g.ui_initialized and (not g.callbacks_registered) + for g in ControlNetUiGroup.all_ui_groups + ) + ): + for ui_group in ControlNetUiGroup.all_ui_groups: + ui_group.register_callbacks() + + ControlNetUiGroup.register_input_mode_sync( + [g for g in ControlNetUiGroup.all_ui_groups if g.is_img2img] + ) + ControlNetUiGroup.register_input_mode_sync( + [g for g in ControlNetUiGroup.all_ui_groups if not g.is_img2img] + ) + logger.info("ControlNet UI callback registered.") + + @staticmethod + def on_after_component(component, **_kwargs): + """Register the A1111 component.""" + if getattr(component, "elem_id", None) == "img2img_batch_inpaint_mask_dir": + ControlNetUiGroup.global_batch_input_dir.render() + return + + ControlNetUiGroup.a1111_context.set_component(component) + ControlNetUiGroup.try_register_all_callbacks() diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/modal.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/modal.py new file mode 100644 index 00000000..17ea4d67 --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/modal.py @@ -0,0 +1,38 @@ +import gradio as gr +from typing import List + + +class ModalInterface(gr.Interface): + modal_id_counter = 0 + + def __init__( + self, + html_content: str, + open_button_text: str, + open_button_classes: List[str] = [], + open_button_extra_attrs: str = '' + ): + self.html_content = html_content + self.open_button_text = open_button_text + self.open_button_classes = open_button_classes + self.open_button_extra_attrs = open_button_extra_attrs + self.modal_id = ModalInterface.modal_id_counter + ModalInterface.modal_id_counter += 1 + + def __call__(self): + return self.create_modal() + + def create_modal(self, visible=True): + html_code = f""" +
+ × +
+ {self.html_content} +
+
+
{self.open_button_text}
+ """ + return gr.HTML(value=html_code, visible=visible) diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/openpose_editor.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/openpose_editor.py new file mode 100644 index 00000000..4146018a --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/openpose_editor.py @@ -0,0 +1,154 @@ +import base64 +import gradio as gr +import json +from typing import List, Dict, Any, Tuple + +from annotator.openpose import decode_json_as_poses, draw_poses +from annotator.openpose.animalpose import draw_animalposes +from lib_controlnet.controlnet_ui.modal import ModalInterface +from modules import shared +from lib_controlnet.logging import logger + + +def parse_data_url(data_url: str): + # Split the URL at the comma + media_type, data = data_url.split(",", 1) + + # Check if the data is base64-encoded + assert ";base64" in media_type + + # Decode the base64 data + return base64.b64decode(data) + + +def encode_data_url(json_string: str) -> str: + base64_encoded_json = base64.b64encode(json_string.encode("utf-8")).decode("utf-8") + return f"data:application/json;base64,{base64_encoded_json}" + + +class OpenposeEditor(object): + # Filename used when user click the download link. + download_file = "pose.json" + # URL the openpose editor is mounted on. + editor_url = "/openpose_editor_index" + + def __init__(self) -> None: + self.render_button = None + self.pose_input = None + self.download_link = None + self.upload_link = None + self.modal = None + + def render_edit(self): + """Renders the buttons in preview image control button group.""" + # The hidden button to trigger a re-render of generated image. + self.render_button = gr.Button(visible=False, elem_classes=["cnet-render-pose"]) + # The hidden element that stores the pose json for backend retrieval. + # The front-end javascript will write the edited JSON data to the element. + self.pose_input = gr.Textbox(visible=False, elem_classes=["cnet-pose-json"]) + + self.modal = ModalInterface( + # Use about:blank here as placeholder so that the iframe does not + # immediately navigate. Most of controlnet units do not need + # openpose editor active. Only navigate when the user first click + # 'Edit'. The navigation logic is in `openpose_editor.js`. + f'', + open_button_text="Edit", + open_button_classes=["cnet-edit-pose"], + open_button_extra_attrs=f'title="Send pose to {OpenposeEditor.editor_url} for edit."', + ).create_modal(visible=False) + self.download_link = gr.HTML( + value=f"""JSON""", + visible=False, + elem_classes=["cnet-download-pose"], + ) + + def render_upload(self): + """Renders the button in input image control button group.""" + self.upload_link = gr.HTML( + value=""" + + + """, + visible=False, + elem_classes=["cnet-upload-pose"], + ) + + def register_callbacks( + self, + generated_image: gr.Image, + use_preview_as_input: gr.Checkbox, + model: gr.Dropdown, + ): + def render_pose(pose_url: str) -> Tuple[Dict, Dict]: + json_string = parse_data_url(pose_url).decode("utf-8") + poses, animals, height, width = decode_json_as_poses( + json.loads(json_string) + ) + logger.info("Preview as input is enabled.") + return ( + # Generated image. + gr.update( + value=( + draw_poses( + poses, + height, + width, + draw_body=True, + draw_hand=True, + draw_face=True, + ) + if poses + else draw_animalposes(animals, height, width) + ), + visible=True, + ), + # Use preview as input. + gr.update(value=True), + # Self content. + *self.update(json_string), + ) + + self.render_button.click( + fn=render_pose, + inputs=[self.pose_input], + outputs=[generated_image, use_preview_as_input, *self.outputs()], + ) + + def update_upload_link(model: str) -> Dict: + return gr.update(visible="openpose" in model.lower()) + + model.change(fn=update_upload_link, inputs=[model], outputs=[self.upload_link]) + + def outputs(self) -> List[Any]: + return [ + self.download_link, + self.modal, + ] + + def update(self, json_string: str) -> List[Dict]: + """ + Called when there is a new JSON pose value generated by running + preprocessor. + + Args: + json_string: The new JSON string generated by preprocessor. + + Returns: + An gr.update event. + """ + hint = "Download the pose as .json file" + html = f""" + JSON""" + + visible = json_string != "" + return [ + # Download link update. + gr.update(value=html, visible=visible), + # Modal update. + gr.update( + visible=visible + and not shared.opts.data.get("controlnet_disable_openpose_edit", False) + ), + ] diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/photopea.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/photopea.py new file mode 100644 index 00000000..5bea02e8 --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/photopea.py @@ -0,0 +1,182 @@ +import gradio as gr + +from lib_controlnet.controlnet_ui.modal import ModalInterface + +PHOTOPEA_LOGO = """ + + + + + + + +""" + + +class Photopea(object): + def __init__(self) -> None: + self.modal = None + self.triggers = [] + self.render_editor() + + def render_editor(self): + """Render the editor modal.""" + with gr.Group(elem_classes=["cnet-photopea-edit"]): + self.modal = ModalInterface( + # Use about:blank here as placeholder so that the iframe does not + # immediately navigate. Only navigate when the user first click + # 'Edit'. The navigation logic is in `photopea.js`. + f""" +
+ + +
+ + """, + open_button_text="Edit", + open_button_classes=["cnet-photopea-main-trigger"], + open_button_extra_attrs="hidden", + ).create_modal(visible=True) + + def render_child_trigger(self): + self.triggers.append( + gr.HTML( + f"""
+ Edit {PHOTOPEA_LOGO} +
""" + ) + ) + + def attach_photopea_output(self, generated_image: gr.Image): + """Called in ControlNetUiGroup to attach preprocessor preview image Gradio element + as the photopea output. If the front-end directly change the img HTML element's src + to reflect the edited image result from photopea, the backend won't be notified. + + In this method we let the front-end upload the result image an invisible gr.Image + instance and mirrors the value to preprocessor preview gr.Image. This is because + the generated image gr.Image instance is inferred to be an output image by Gradio + and has no ability to accept image upload directly. + + Arguments: + generated_image: preprocessor result Gradio Image output element. + + Returns: + None + """ + output = gr.Image( + visible=False, + source="upload", + type="numpy", + elem_classes=[f"cnet-photopea-output"], + ) + + output.upload( + fn=lambda img: img, + inputs=[output], + outputs=[generated_image], + ) diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/preset.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/preset.py new file mode 100644 index 00000000..831bc93e --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/preset.py @@ -0,0 +1,318 @@ +import os +import gradio as gr + +from typing import Dict, List + +from modules import scripts +from lib_controlnet.infotext import parse_unit, serialize_unit +from lib_controlnet.controlnet_ui.tool_button import ToolButton +from lib_controlnet.logging import logger +from lib_controlnet import external_code + +save_symbol = "\U0001f4be" # 💾 +delete_symbol = "\U0001f5d1\ufe0f" # 🗑️ +refresh_symbol = "\U0001f504" # 🔄 +reset_symbol = "\U000021A9" # ↩ + +NEW_PRESET = "New Preset" + + +def load_presets(preset_dir: str) -> Dict[str, str]: + if not os.path.exists(preset_dir): + os.makedirs(preset_dir) + return {} + + presets = {} + for filename in os.listdir(preset_dir): + if filename.endswith(".txt"): + with open(os.path.join(preset_dir, filename), "r") as f: + name = filename.replace(".txt", "") + if name == NEW_PRESET: + continue + presets[name] = f.read() + return presets + + +def infer_control_type(module: str, model: str) -> str: + def matches_control_type(input_string: str, control_type: str) -> bool: + return any(t.lower() in input_string for t in control_type.split("/")) + + control_types = preprocessor_filters.keys() + control_type_candidates = [ + control_type + for control_type in control_types + if ( + matches_control_type(module, control_type) + or matches_control_type(model, control_type) + ) + ] + if len(control_type_candidates) != 1: + raise ValueError( + f"Unable to infer control type from module {module} and model {model}" + ) + return control_type_candidates[0] + + +class ControlNetPresetUI(object): + preset_directory = os.path.join(scripts.basedir(), "presets") + presets = load_presets(preset_directory) + + def __init__(self, id_prefix: str): + with gr.Row(): + self.dropdown = gr.Dropdown( + label="Presets", + show_label=True, + elem_classes=["cnet-preset-dropdown"], + choices=ControlNetPresetUI.dropdown_choices(), + value=NEW_PRESET, + ) + self.reset_button = ToolButton( + value=reset_symbol, + elem_classes=["cnet-preset-reset"], + tooltip="Reset preset", + visible=False, + ) + self.save_button = ToolButton( + value=save_symbol, + elem_classes=["cnet-preset-save"], + tooltip="Save preset", + ) + self.delete_button = ToolButton( + value=delete_symbol, + elem_classes=["cnet-preset-delete"], + tooltip="Delete preset", + ) + self.refresh_button = ToolButton( + value=refresh_symbol, + elem_classes=["cnet-preset-refresh"], + tooltip="Refresh preset", + ) + + with gr.Box( + elem_classes=["popup-dialog", "cnet-preset-enter-name"], + elem_id=f"{id_prefix}_cnet_preset_enter_name", + ) as self.name_dialog: + with gr.Row(): + self.preset_name = gr.Textbox( + label="Preset name", + show_label=True, + lines=1, + elem_classes=["cnet-preset-name"], + ) + self.confirm_preset_name = ToolButton( + value=save_symbol, + elem_classes=["cnet-preset-confirm-name"], + tooltip="Save preset", + ) + + def register_callbacks( + self, + uigroup, + control_type: gr.Radio, + *ui_states, + ): + def apply_preset(name: str, control_type: str, *ui_states): + if name == NEW_PRESET: + return ( + gr.update(visible=False), + *( + (gr.skip(),) + * (len(vars(external_code.ControlNetUnit()).keys()) + 1) + ), + ) + + assert name in ControlNetPresetUI.presets + + infotext = ControlNetPresetUI.presets[name] + preset_unit = parse_unit(infotext) + current_unit = external_code.ControlNetUnit(*ui_states) + preset_unit.image = None + current_unit.image = None + + # Do not compare module param that are not used in preset. + for module_param in ("processor_res", "threshold_a", "threshold_b"): + if getattr(preset_unit, module_param) == -1: + setattr(current_unit, module_param, -1) + + # No update necessary. + if vars(current_unit) == vars(preset_unit): + return ( + gr.update(visible=False), + *( + (gr.skip(),) + * (len(vars(external_code.ControlNetUnit()).keys()) + 1) + ), + ) + + unit = preset_unit + + try: + new_control_type = infer_control_type(unit.module, unit.model) + except ValueError as e: + logger.error(e) + new_control_type = control_type + + if new_control_type != control_type: + uigroup.prevent_next_n_module_update += 1 + + if preset_unit.module != current_unit.module: + uigroup.prevent_next_n_slider_value_update += 1 + + if preset_unit.pixel_perfect != current_unit.pixel_perfect: + uigroup.prevent_next_n_slider_value_update += 1 + + return ( + gr.update(visible=True), + gr.update(value=new_control_type), + *[ + gr.update(value=value) if value is not None else gr.update() + for value in vars(unit).values() + ], + ) + + for element, action in ( + (self.dropdown, "change"), + (self.reset_button, "click"), + ): + getattr(element, action)( + fn=apply_preset, + inputs=[self.dropdown, control_type, *ui_states], + outputs=[self.delete_button, control_type, *ui_states], + show_progress="hidden", + ).then( + fn=lambda: gr.update(visible=False), + inputs=None, + outputs=[self.reset_button], + ) + + def save_preset(name: str, *ui_states): + if name == NEW_PRESET: + return gr.update(visible=True), gr.update(), gr.update() + + ControlNetPresetUI.save_preset( + name, external_code.ControlNetUnit(*ui_states) + ) + return ( + gr.update(), # name dialog + gr.update(choices=ControlNetPresetUI.dropdown_choices(), value=name), + gr.update(visible=False), # Reset button + ) + + self.save_button.click( + fn=save_preset, + inputs=[self.dropdown, *ui_states], + outputs=[self.name_dialog, self.dropdown, self.reset_button], + show_progress="hidden", + ).then( + fn=None, + _js=f""" + (name) => {{ + if (name === "{NEW_PRESET}") + popup(gradioApp().getElementById('{self.name_dialog.elem_id}')); + }}""", + inputs=[self.dropdown], + ) + + def delete_preset(name: str): + ControlNetPresetUI.delete_preset(name) + return gr.Dropdown.update( + choices=ControlNetPresetUI.dropdown_choices(), + value=NEW_PRESET, + ), gr.update(visible=False) + + self.delete_button.click( + fn=delete_preset, + inputs=[self.dropdown], + outputs=[self.dropdown, self.reset_button], + show_progress="hidden", + ) + + self.name_dialog.visible = False + + def save_new_preset(new_name: str, *ui_states): + if new_name == NEW_PRESET: + logger.warn(f"Cannot save preset with reserved name '{NEW_PRESET}'") + return gr.update(visible=False), gr.update() + + ControlNetPresetUI.save_preset( + new_name, external_code.ControlNetUnit(*ui_states) + ) + return gr.update(visible=False), gr.update( + choices=ControlNetPresetUI.dropdown_choices(), value=new_name + ) + + self.confirm_preset_name.click( + fn=save_new_preset, + inputs=[self.preset_name, *ui_states], + outputs=[self.name_dialog, self.dropdown], + show_progress="hidden", + ).then(fn=None, _js="closePopup") + + self.refresh_button.click( + fn=ControlNetPresetUI.refresh_preset, + inputs=None, + outputs=[self.dropdown], + show_progress="hidden", + ) + + def update_reset_button(preset_name: str, *ui_states): + if preset_name == NEW_PRESET: + return gr.update(visible=False) + + infotext = ControlNetPresetUI.presets[preset_name] + preset_unit = parse_unit(infotext) + current_unit = external_code.ControlNetUnit(*ui_states) + preset_unit.image = None + current_unit.image = None + + # Do not compare module param that are not used in preset. + for module_param in ("processor_res", "threshold_a", "threshold_b"): + if getattr(preset_unit, module_param) == -1: + setattr(current_unit, module_param, -1) + + return gr.update(visible=vars(current_unit) != vars(preset_unit)) + + for ui_state in ui_states: + if isinstance(ui_state, gr.Image): + continue + + for action in ("edit", "click", "change", "clear", "release"): + if action == "release" and not isinstance(ui_state, gr.Slider): + continue + + if hasattr(ui_state, action): + getattr(ui_state, action)( + fn=update_reset_button, + inputs=[self.dropdown, *ui_states], + outputs=[self.reset_button], + ) + + @staticmethod + def dropdown_choices() -> List[str]: + return list(ControlNetPresetUI.presets.keys()) + [NEW_PRESET] + + @staticmethod + def save_preset(name: str, unit: external_code.ControlNetUnit): + infotext = serialize_unit(unit) + with open( + os.path.join(ControlNetPresetUI.preset_directory, f"{name}.txt"), "w" + ) as f: + f.write(infotext) + + ControlNetPresetUI.presets[name] = infotext + + @staticmethod + def delete_preset(name: str): + if name not in ControlNetPresetUI.presets: + return + + del ControlNetPresetUI.presets[name] + + file = os.path.join(ControlNetPresetUI.preset_directory, f"{name}.txt") + if os.path.exists(file): + os.unlink(file) + + @staticmethod + def refresh_preset(): + ControlNetPresetUI.presets = load_presets(ControlNetPresetUI.preset_directory) + return gr.update(choices=ControlNetPresetUI.dropdown_choices()) diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/tool_button.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/tool_button.py new file mode 100644 index 00000000..8a38df8f --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/tool_button.py @@ -0,0 +1,12 @@ +import gradio as gr + +class ToolButton(gr.Button, gr.components.FormComponent): + """Small button with single emoji as text, fits inside gradio forms""" + + def __init__(self, **kwargs): + super().__init__(variant="tool", + elem_classes=kwargs.pop('elem_classes', []) + ["cnet-toolbutton"], + **kwargs) + + def get_block_name(self): + return "button" diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/enums.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/enums.py new file mode 100644 index 00000000..05dc8a63 --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/enums.py @@ -0,0 +1,73 @@ +from enum import Enum +from typing import Any + + +class StableDiffusionVersion(Enum): + """The version family of stable diffusion model.""" + + UNKNOWN = 0 + SD1x = 1 + SD2x = 2 + SDXL = 3 + + @staticmethod + def detect_from_model_name(model_name: str) -> "StableDiffusionVersion": + """Based on the model name provided, guess what stable diffusion version it is. + This might not be accurate without actually inspect the file content. + """ + if any(f"sd{v}" in model_name.lower() for v in ("14", "15", "16")): + return StableDiffusionVersion.SD1x + + if "sd21" in model_name or "2.1" in model_name: + return StableDiffusionVersion.SD2x + + if "xl" in model_name.lower(): + return StableDiffusionVersion.SDXL + + return StableDiffusionVersion.UNKNOWN + + def encoder_block_num(self) -> int: + if self in (StableDiffusionVersion.SD1x, StableDiffusionVersion.SD2x, StableDiffusionVersion.UNKNOWN): + return 12 + else: + return 9 # SDXL + + def controlnet_layer_num(self) -> int: + return self.encoder_block_num() + 1 + + def is_compatible_with(self, other: "StableDiffusionVersion") -> bool: + """ Incompatible only when one of version is SDXL and other is not. """ + return ( + any(v == StableDiffusionVersion.UNKNOWN for v in [self, other]) or + sum(v == StableDiffusionVersion.SDXL for v in [self, other]) != 1 + ) + + +class HiResFixOption(Enum): + BOTH = "Both" + LOW_RES_ONLY = "Low res only" + HIGH_RES_ONLY = "High res only" + + @staticmethod + def from_value(value: Any) -> "HiResFixOption": + if isinstance(value, str) and value.startswith("HiResFixOption."): + _, field = value.split(".") + return getattr(HiResFixOption, field) + if isinstance(value, str): + return HiResFixOption(value) + elif isinstance(value, int): + return [x for x in HiResFixOption][value] + else: + assert isinstance(value, HiResFixOption) + return value + + +class InputMode(Enum): + # Single image to a single ControlNet unit. + SIMPLE = "simple" + # Input is a directory. N generations. Each generation takes 1 input image + # from the directory. + BATCH = "batch" + # Input is a directory. 1 generation. Each generation takes N input image + # from the directory. + MERGE = "merge" diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py new file mode 100644 index 00000000..92f3904b --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py @@ -0,0 +1,460 @@ +from dataclasses import dataclass +from enum import Enum +from copy import copy +from typing import List, Any, Optional, Union, Tuple, Dict +import numpy as np +from modules import scripts, processing, shared +from lib_controlnet import global_state +from lib_controlnet.logging import logger +from lib_controlnet.enums import HiResFixOption + +from modules.api import api + + +def get_api_version() -> int: + return 2 + + +class ControlMode(Enum): + """ + The improved guess mode. + """ + + BALANCED = "Balanced" + PROMPT = "My prompt is more important" + CONTROL = "ControlNet is more important" + + +class BatchOption(Enum): + DEFAULT = "All ControlNet units for all images in a batch" + SEPARATE = "Each ControlNet unit for each image in a batch" + + +class ResizeMode(Enum): + """ + Resize modes for ControlNet input images. + """ + + RESIZE = "Just Resize" + INNER_FIT = "Crop and Resize" + OUTER_FIT = "Resize and Fill" + + def int_value(self): + if self == ResizeMode.RESIZE: + return 0 + elif self == ResizeMode.INNER_FIT: + return 1 + elif self == ResizeMode.OUTER_FIT: + return 2 + assert False, "NOTREACHED" + + +resize_mode_aliases = { + 'Inner Fit (Scale to Fit)': 'Crop and Resize', + 'Outer Fit (Shrink to Fit)': 'Resize and Fill', + 'Scale to Fit (Inner Fit)': 'Crop and Resize', + 'Envelope (Outer Fit)': 'Resize and Fill', +} + + +def resize_mode_from_value(value: Union[str, int, ResizeMode]) -> ResizeMode: + if isinstance(value, str): + return ResizeMode(resize_mode_aliases.get(value, value)) + elif isinstance(value, int): + assert value >= 0 + if value == 3: # 'Just Resize (Latent upscale)' + return ResizeMode.RESIZE + + if value >= len(ResizeMode): + logger.warning(f'Unrecognized ResizeMode int value {value}. Fall back to RESIZE.') + return ResizeMode.RESIZE + + return [e for e in ResizeMode][value] + else: + return value + + +def control_mode_from_value(value: Union[str, int, ControlMode]) -> ControlMode: + if isinstance(value, str): + return ControlMode(value) + elif isinstance(value, int): + return [e for e in ControlMode][value] + else: + return value + + +def visualize_inpaint_mask(img): + if img.ndim == 3 and img.shape[2] == 4: + result = img.copy() + mask = result[:, :, 3] + mask = 255 - mask // 2 + result[:, :, 3] = mask + return np.ascontiguousarray(result.copy()) + return img + + +def pixel_perfect_resolution( + image: np.ndarray, + target_H: int, + target_W: int, + resize_mode: ResizeMode, +) -> int: + """ + Calculate the estimated resolution for resizing an image while preserving aspect ratio. + + The function first calculates scaling factors for height and width of the image based on the target + height and width. Then, based on the chosen resize mode, it either takes the smaller or the larger + scaling factor to estimate the new resolution. + + If the resize mode is OUTER_FIT, the function uses the smaller scaling factor, ensuring the whole image + fits within the target dimensions, potentially leaving some empty space. + + If the resize mode is not OUTER_FIT, the function uses the larger scaling factor, ensuring the target + dimensions are fully filled, potentially cropping the image. + + After calculating the estimated resolution, the function prints some debugging information. + + Args: + image (np.ndarray): A 3D numpy array representing an image. The dimensions represent [height, width, channels]. + target_H (int): The target height for the image. + target_W (int): The target width for the image. + resize_mode (ResizeMode): The mode for resizing. + + Returns: + int: The estimated resolution after resizing. + """ + raw_H, raw_W, _ = image.shape + + k0 = float(target_H) / float(raw_H) + k1 = float(target_W) / float(raw_W) + + if resize_mode == ResizeMode.OUTER_FIT: + estimation = min(k0, k1) * float(min(raw_H, raw_W)) + else: + estimation = max(k0, k1) * float(min(raw_H, raw_W)) + + logger.debug(f"Pixel Perfect Computation:") + logger.debug(f"resize_mode = {resize_mode}") + logger.debug(f"raw_H = {raw_H}") + logger.debug(f"raw_W = {raw_W}") + logger.debug(f"target_H = {target_H}") + logger.debug(f"target_W = {target_W}") + logger.debug(f"estimation = {estimation}") + + return int(np.round(estimation)) + + +InputImage = Union[np.ndarray, str] +InputImage = Union[Dict[str, InputImage], Tuple[InputImage, InputImage], InputImage] + + +@dataclass +class ControlNetUnit: + """ + Represents an entire ControlNet processing unit. + """ + enabled: bool = True + module: str = "none" + model: str = "None" + weight: float = 1.0 + image: Optional[Union[InputImage, List[InputImage]]] = None + resize_mode: Union[ResizeMode, int, str] = ResizeMode.INNER_FIT + low_vram: bool = False + processor_res: int = -1 + threshold_a: float = -1 + threshold_b: float = -1 + guidance_start: float = 0.0 + guidance_end: float = 1.0 + pixel_perfect: bool = False + control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED + # Whether to crop input image based on A1111 img2img mask. This flag is only used when `inpaint area` + # in A1111 is set to `Only masked`. In API, this correspond to `inpaint_full_res = True`. + inpaint_crop_input_image: bool = True + # If hires fix is enabled in A1111, how should this ControlNet unit be applied. + # The value is ignored if the generation is not using hires fix. + hr_option: Union[HiResFixOption, int, str] = HiResFixOption.BOTH + + # Whether save the detected map of this unit. Setting this option to False prevents saving the + # detected map or sending detected map along with generated images via API. + # Currently the option is only accessible in API calls. + save_detected_map: bool = True + + # Weight for each layer of ControlNet params. + # For ControlNet: + # - SD1.5: 13 weights (4 encoder block * 3 + 1 middle block) + # - SDXL: 10 weights (3 encoder block * 3 + 1 middle block) + # For T2IAdapter + # - SD1.5: 5 weights (4 encoder block + 1 middle block) + # - SDXL: 4 weights (3 encoder block + 1 middle block) + # Note1: Setting advanced weighting will disable `soft_injection`, i.e. + # It is recommended to set ControlMode = BALANCED when using `advanced_weighting`. + # Note2: The field `weight` is still used in some places, e.g. reference_only, + # even advanced_weighting is set. + advanced_weighting: Optional[List[float]] = None + + def __eq__(self, other): + if not isinstance(other, ControlNetUnit): + return False + + return vars(self) == vars(other) + + def accepts_multiple_inputs(self) -> bool: + """This unit can accept multiple input images.""" + return self.module in ( + "ip-adapter_clip_sdxl", + "ip-adapter_clip_sdxl_plus_vith", + "ip-adapter_clip_sd15", + "ip-adapter_face_id", + "ip-adapter_face_id_plus", + "instant_id_face_embedding", + ) + + +def to_base64_nparray(encoding: str): + """ + Convert a base64 image into the image type the extension uses + """ + + return np.array(api.decode_base64_to_image(encoding)).astype('uint8') + + +def get_all_units_in_processing(p: processing.StableDiffusionProcessing) -> List[ControlNetUnit]: + """ + Fetch ControlNet processing units from a StableDiffusionProcessing. + """ + + return get_all_units(p.scripts, p.script_args) + + +def get_all_units(script_runner: scripts.ScriptRunner, script_args: List[Any]) -> List[ControlNetUnit]: + """ + Fetch ControlNet processing units from an existing script runner. + Use this function to fetch units from the list of all scripts arguments. + """ + + cn_script = find_cn_script(script_runner) + if cn_script: + return get_all_units_from(script_args[cn_script.args_from:cn_script.args_to]) + + return [] + + +def get_all_units_from(script_args: List[Any]) -> List[ControlNetUnit]: + """ + Fetch ControlNet processing units from ControlNet script arguments. + Use `external_code.get_all_units` to fetch units from the list of all scripts arguments. + """ + + def is_stale_unit(script_arg: Any) -> bool: + """ Returns whether the script_arg is potentially an stale version of + ControlNetUnit created before module reload.""" + return ( + 'ControlNetUnit' in type(script_arg).__name__ and + not isinstance(script_arg, ControlNetUnit) + ) + + def is_controlnet_unit(script_arg: Any) -> bool: + """ Returns whether the script_arg is ControlNetUnit or anything that + can be treated like ControlNetUnit. """ + return ( + isinstance(script_arg, (ControlNetUnit, dict)) or + ( + hasattr(script_arg, '__dict__') and + set(vars(ControlNetUnit()).keys()).issubset( + set(vars(script_arg).keys())) + ) + ) + + all_units = [ + to_processing_unit(script_arg) + for script_arg in script_args + if is_controlnet_unit(script_arg) + ] + if not all_units: + logger.warning( + "No ControlNetUnit detected in args. It is very likely that you are having an extension conflict." + f"Here are args received by ControlNet: {script_args}.") + if any(is_stale_unit(script_arg) for script_arg in script_args): + logger.debug( + "Stale version of ControlNetUnit detected. The ControlNetUnit received" + "by ControlNet is created before the newest load of ControlNet extension." + "They will still be used by ControlNet as long as they provide same fields" + "defined in the newest version of ControlNetUnit." + ) + + return all_units + + +def get_single_unit_from(script_args: List[Any], index: int = 0) -> Optional[ControlNetUnit]: + """ + Fetch a single ControlNet processing unit from ControlNet script arguments. + The list must not contain script positional arguments. It must only contain processing units. + """ + + i = 0 + while i < len(script_args) and index >= 0: + if index == 0 and script_args[i] is not None: + return to_processing_unit(script_args[i]) + i += 1 + + index -= 1 + + return None + + +def get_max_models_num(): + """ + Fetch the maximum number of allowed ControlNet models. + """ + + max_models_num = shared.opts.data.get("control_net_unit_count", 3) + return max_models_num + + +def to_processing_unit(unit: Union[Dict[str, Any], ControlNetUnit]) -> ControlNetUnit: + """ + Convert different types to processing unit. + If `unit` is a dict, alternative keys are supported. See `ext_compat_keys` in implementation for details. + """ + + ext_compat_keys = { + 'guessmode': 'guess_mode', + 'guidance': 'guidance_end', + 'lowvram': 'low_vram', + 'input_image': 'image' + } + + if isinstance(unit, dict): + unit = {ext_compat_keys.get(k, k): v for k, v in unit.items()} + + mask = None + if 'mask' in unit: + mask = unit['mask'] + del unit['mask'] + + if 'image' in unit and not isinstance(unit['image'], dict): + unit['image'] = {'image': unit['image'], 'mask': mask} if mask is not None else unit['image'] if unit[ + 'image'] else None + + if 'guess_mode' in unit: + logger.warning('Guess Mode is removed since 1.1.136. Please use Control Mode instead.') + + unit = ControlNetUnit(**{k: v for k, v in unit.items() if k in vars(ControlNetUnit).keys()}) + + # temporary, check #602 + # assert isinstance(unit, ControlNetUnit), f'bad argument to controlnet extension: {unit}\nexpected Union[dict[str, Any], ControlNetUnit]' + return unit + + +def update_cn_script_in_processing( + p: processing.StableDiffusionProcessing, + cn_units: List[ControlNetUnit], + **_kwargs, # for backwards compatibility +): + """ + Update the arguments of the ControlNet script in `p.script_args` in place, reading from `cn_units`. + `cn_units` and its elements are not modified. You can call this function repeatedly, as many times as you want. + + Does not update `p.script_args` if any of the folling is true: + - ControlNet is not present in `p.scripts` + - `p.script_args` is not filled with script arguments for scripts that are processed before ControlNet + """ + p.script_args = update_cn_script(p.scripts, p.script_args_value, cn_units) + + +def update_cn_script( + script_runner: scripts.ScriptRunner, + script_args: Union[Tuple[Any], List[Any]], + cn_units: List[ControlNetUnit], +) -> Union[Tuple[Any], List[Any]]: + """ + Returns: The updated `script_args` with given `cn_units` used as ControlNet + script args. + + Does not update `script_args` if any of the folling is true: + - ControlNet is not present in `script_runner` + - `script_args` is not filled with script arguments for scripts that are + processed before ControlNet + """ + script_args_type = type(script_args) + assert script_args_type in (tuple, list), script_args_type + updated_script_args = list(copy(script_args)) + + cn_script = find_cn_script(script_runner) + + if cn_script is None or len(script_args) < cn_script.args_from: + return script_args + + # fill in remaining parameters to satisfy max models, just in case script needs it. + max_models = shared.opts.data.get("control_net_unit_count", 3) + cn_units = cn_units + [ControlNetUnit(enabled=False)] * max(max_models - len(cn_units), 0) + + cn_script_args_diff = 0 + for script in script_runner.alwayson_scripts: + if script is cn_script: + cn_script_args_diff = len(cn_units) - (cn_script.args_to - cn_script.args_from) + updated_script_args[script.args_from:script.args_to] = cn_units + script.args_to = script.args_from + len(cn_units) + else: + script.args_from += cn_script_args_diff + script.args_to += cn_script_args_diff + + return script_args_type(updated_script_args) + + +def update_cn_script_in_place( + script_runner: scripts.ScriptRunner, + script_args: List[Any], + cn_units: List[ControlNetUnit], + **_kwargs, # for backwards compatibility +): + """ + @Deprecated(Raises assertion error if script_args passed in is Tuple) + + Update the arguments of the ControlNet script in `script_args` in place, reading from `cn_units`. + `cn_units` and its elements are not modified. You can call this function repeatedly, as many times as you want. + + Does not update `script_args` if any of the folling is true: + - ControlNet is not present in `script_runner` + - `script_args` is not filled with script arguments for scripts that are processed before ControlNet + """ + assert isinstance(script_args, list), type(script_args) + + cn_script = find_cn_script(script_runner) + if cn_script is None or len(script_args) < cn_script.args_from: + return + + # fill in remaining parameters to satisfy max models, just in case script needs it. + max_models = shared.opts.data.get("control_net_unit_count", 3) + cn_units = cn_units + [ControlNetUnit(enabled=False)] * max(max_models - len(cn_units), 0) + + cn_script_args_diff = 0 + for script in script_runner.alwayson_scripts: + if script is cn_script: + cn_script_args_diff = len(cn_units) - (cn_script.args_to - cn_script.args_from) + script_args[script.args_from:script.args_to] = cn_units + script.args_to = script.args_from + len(cn_units) + else: + script.args_from += cn_script_args_diff + script.args_to += cn_script_args_diff + + +def find_cn_script(script_runner: scripts.ScriptRunner) -> Optional[scripts.Script]: + """ + Find the ControlNet script in `script_runner`. Returns `None` if `script_runner` does not contain a ControlNet script. + """ + + if script_runner is None: + return None + + for script in script_runner.alwayson_scripts: + if is_cn_script(script): + return script + + +def is_cn_script(script: scripts.Script) -> bool: + """ + Determine whether `script` is a ControlNet script. + """ + + return script.title().lower() == 'controlnet' diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/global_state.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/global_state.py new file mode 100644 index 00000000..58ceff21 --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/global_state.py @@ -0,0 +1,138 @@ +import os.path +import stat +from collections import OrderedDict + +from modules import shared, sd_models +from lib_controlnet.enums import StableDiffusionVersion +from modules_forge.shared import controlnet_dir, supported_preprocessors + + +CN_MODEL_EXTS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin"] + + +def traverse_all_files(curr_path, model_list): + f_list = [ + (os.path.join(curr_path, entry.name), entry.stat()) + for entry in os.scandir(curr_path) + if os.path.isdir(curr_path) + ] + for f_info in f_list: + fname, fstat = f_info + if os.path.splitext(fname)[1] in CN_MODEL_EXTS: + model_list.append(f_info) + elif stat.S_ISDIR(fstat.st_mode): + model_list = traverse_all_files(fname, model_list) + return model_list + + +def get_all_models(sort_by, filter_by, path): + res = OrderedDict() + fileinfos = traverse_all_files(path, []) + filter_by = filter_by.strip(" ") + if len(filter_by) != 0: + fileinfos = [x for x in fileinfos if filter_by.lower() + in os.path.basename(x[0]).lower()] + if sort_by == "name": + fileinfos = sorted(fileinfos, key=lambda x: os.path.basename(x[0])) + elif sort_by == "date": + fileinfos = sorted(fileinfos, key=lambda x: -x[1].st_mtime) + elif sort_by == "path name": + fileinfos = sorted(fileinfos) + + for finfo in fileinfos: + filename = finfo[0] + name = os.path.splitext(os.path.basename(filename))[0] + # Prevent a hypothetical "None.pt" from being listed. + if name != "None": + res[name + f" [{sd_models.model_hash(filename)}]"] = filename + + return res + + +controlnet_filename_dict = {'None': 'model.safetensors'} +controlnet_names = ['None'] + + +def get_preprocessor(name): + return supported_preprocessors.get(name, None) + + +def get_sorted_preprocessors(): + preprocessors = [p for k, p in supported_preprocessors.items() if k != 'None'] + preprocessors = sorted(preprocessors, key=lambda x: str(x.sorting_priority).zfill(8) + x.name)[::-1] + results = OrderedDict() + results['None'] = supported_preprocessors['None'] + for p in preprocessors: + results[p.name] = p + return results + + +def get_all_controlnet_names(): + return controlnet_names + + +def get_controlnet_filename(controlnet_name): + return controlnet_filename_dict[controlnet_name] + + +def get_all_preprocessor_names(): + return list(get_sorted_preprocessors().keys()) + + +def get_all_preprocessor_tags(): + tags = [] + for k, p in supported_preprocessors.items(): + tags += p.tags + tags = list(set(tags)) + tags = sorted(tags) + return ['All'] + tags + + +def get_filtered_preprocessors(tag): + if tag == 'All': + return supported_preprocessors + return {k: v for k, v in get_sorted_preprocessors().items() if tag in v.tags or k == 'None'} + + +def get_filtered_preprocessor_names(tag): + return list(get_filtered_preprocessors(tag).keys()) + + +def get_filtered_controlnet_names(tag): + filtered_preprocessors = get_filtered_preprocessors(tag) + model_filename_filers = [] + for p in filtered_preprocessors.values(): + model_filename_filers += p.model_filename_filers + return [x for x in controlnet_names if any(f.lower() in x.lower() for f in model_filename_filers) or x == 'None'] + + +def update_controlnet_filenames(): + global controlnet_filename_dict, controlnet_names + + controlnet_filename_dict = {'None': 'model.safetensors'} + controlnet_names = ['None'] + + ext_dirs = (shared.opts.data.get("control_net_models_path", None), getattr(shared.cmd_opts, 'controlnet_dir', None)) + extra_lora_paths = (extra_lora_path for extra_lora_path in ext_dirs + if extra_lora_path is not None and os.path.exists(extra_lora_path)) + paths = [controlnet_dir, *extra_lora_paths] + + for path in paths: + sort_by = shared.opts.data.get("control_net_models_sort_models_by", "name") + filter_by = shared.opts.data.get("control_net_models_name_filter", "") + found = get_all_models(sort_by, filter_by, path) + controlnet_filename_dict.update(found) + + controlnet_names = list(controlnet_filename_dict.keys()) + return + + +def get_sd_version() -> StableDiffusionVersion: + if shared.sd_model.is_sdxl: + return StableDiffusionVersion.SDXL + elif shared.sd_model.is_sd2: + return StableDiffusionVersion.SD2x + elif shared.sd_model.is_sd1: + return StableDiffusionVersion.SD1x + else: + return StableDiffusionVersion.UNKNOWN diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/infotext.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/infotext.py new file mode 100644 index 00000000..8a3a063b --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/infotext.py @@ -0,0 +1,135 @@ +from typing import List, Tuple, Union + +import gradio as gr + +from modules.processing import StableDiffusionProcessing + +from lib_controlnet import external_code +from lib_controlnet.logging import logger + + +def field_to_displaytext(fieldname: str) -> str: + return " ".join([word.capitalize() for word in fieldname.split("_")]) + + +def displaytext_to_field(text: str) -> str: + return "_".join([word.lower() for word in text.split(" ")]) + + +def parse_value(value: str) -> Union[str, float, int, bool]: + if value in ("True", "False"): + return value == "True" + try: + return int(value) + except ValueError: + try: + return float(value) + except ValueError: + return value # Plain string. + + +def serialize_unit(unit: external_code.ControlNetUnit) -> str: + excluded_fields = ( + "image", + "enabled", + # Note: "advanced_weighting" is excluded as it is an API-only field. + "advanced_weighting", + # Note: "inpaint_crop_image" is img2img inpaint only flag, which does not + # provide much information when restoring the unit. + "inpaint_crop_input_image", + ) + + log_value = { + field_to_displaytext(field): getattr(unit, field) + for field in vars(external_code.ControlNetUnit()).keys() + if field not in excluded_fields and getattr(unit, field) != -1 + # Note: exclude hidden slider values. + } + if not all("," not in str(v) and ":" not in str(v) for v in log_value.values()): + logger.error(f"Unexpected tokens encountered:\n{log_value}") + return "" + + return ", ".join(f"{field}: {value}" for field, value in log_value.items()) + + +def parse_unit(text: str) -> external_code.ControlNetUnit: + return external_code.ControlNetUnit( + enabled=True, + **{ + displaytext_to_field(key): parse_value(value) + for item in text.split(",") + for (key, value) in (item.strip().split(": "),) + }, + ) + + +class Infotext(object): + def __init__(self) -> None: + self.infotext_fields: List[Tuple[gr.components.IOComponent, str]] = [] + self.paste_field_names: List[str] = [] + + @staticmethod + def unit_prefix(unit_index: int) -> str: + return f"ControlNet {unit_index}" + + def register_unit(self, unit_index: int, uigroup) -> None: + """Register the unit's UI group. By regsitering the unit, A1111 will be + able to paste values from infotext to IOComponents. + + Args: + unit_index: The index of the ControlNet unit + uigroup: The ControlNetUiGroup instance that contains all gradio + iocomponents. + """ + unit_prefix = Infotext.unit_prefix(unit_index) + for field in vars(external_code.ControlNetUnit()).keys(): + # Exclude image for infotext. + if field == "image": + continue + + # Every field in ControlNetUnit should have a cooresponding + # IOComponent in ControlNetUiGroup. + io_component = getattr(uigroup, field) + component_locator = f"{unit_prefix} {field}" + self.infotext_fields.append((io_component, component_locator)) + self.paste_field_names.append(component_locator) + + @staticmethod + def write_infotext( + units: List[external_code.ControlNetUnit], p: StableDiffusionProcessing + ): + """Write infotext to `p`.""" + p.extra_generation_params.update( + { + Infotext.unit_prefix(i): serialize_unit(unit) + for i, unit in enumerate(units) + if unit.enabled + } + ) + + @staticmethod + def on_infotext_pasted(infotext: str, results: dict) -> None: + """Parse ControlNet infotext string and write result to `results` dict.""" + updates = {} + for k, v in results.items(): + if not k.startswith("ControlNet"): + continue + + assert isinstance(v, str), f"Expect string but got {v}." + try: + for field, value in vars(parse_unit(v)).items(): + if field == "image": + continue + if value is None: + logger.debug(f"InfoText: Skipping {field} because value is None.") + continue + + component_locator = f"{k} {field}" + updates[component_locator] = value + logger.debug(f"InfoText: Setting {component_locator} = {value}") + except Exception as e: + logger.warn( + f"Failed to parse infotext, legacy format infotext is no longer supported:\n{v}\n{e}" + ) + + results.update(updates) diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/logging.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/logging.py new file mode 100644 index 00000000..f30d5eec --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/logging.py @@ -0,0 +1,41 @@ +import logging +import copy +import sys + +from modules import shared + + +class ColoredFormatter(logging.Formatter): + COLORS = { + "DEBUG": "\033[0;36m", # CYAN + "INFO": "\033[0;32m", # GREEN + "WARNING": "\033[0;33m", # YELLOW + "ERROR": "\033[0;31m", # RED + "CRITICAL": "\033[0;37;41m", # WHITE ON RED + "RESET": "\033[0m", # RESET COLOR + } + + def format(self, record): + colored_record = copy.copy(record) + levelname = colored_record.levelname + seq = self.COLORS.get(levelname, self.COLORS["RESET"]) + colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}" + return super().format(colored_record) + + +# Create a new logger +logger = logging.getLogger("ControlNet") +logger.propagate = False + +# Add handler if we don't have one. +if not logger.handlers: + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter( + ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + ) + logger.addHandler(handler) + +# Configure logger +loglevel_string = getattr(shared.cmd_opts, "controlnet_loglevel", "INFO") +loglevel = getattr(logging, loglevel_string.upper(), None) +logger.setLevel(loglevel) diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/lvminthin.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/lvminthin.py new file mode 100644 index 00000000..641227aa --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/lvminthin.py @@ -0,0 +1,88 @@ +# High Quality Edge Thinning using Pure Python +# Written by Lvmin Zhang +# 2023 April +# Stanford University +# If you use this, please Cite "High Quality Edge Thinning using Pure Python", Lvmin Zhang, In Mikubill/sd-webui-controlnet. + + +import cv2 +import numpy as np + + +lvmin_kernels_raw = [ + np.array([ + [-1, -1, -1], + [0, 1, 0], + [1, 1, 1] + ], dtype=np.int32), + np.array([ + [0, -1, -1], + [1, 1, -1], + [0, 1, 0] + ], dtype=np.int32) +] + +lvmin_kernels = [] +lvmin_kernels += [np.rot90(x, k=0, axes=(0, 1)) for x in lvmin_kernels_raw] +lvmin_kernels += [np.rot90(x, k=1, axes=(0, 1)) for x in lvmin_kernels_raw] +lvmin_kernels += [np.rot90(x, k=2, axes=(0, 1)) for x in lvmin_kernels_raw] +lvmin_kernels += [np.rot90(x, k=3, axes=(0, 1)) for x in lvmin_kernels_raw] + +lvmin_prunings_raw = [ + np.array([ + [-1, -1, -1], + [-1, 1, -1], + [0, 0, -1] + ], dtype=np.int32), + np.array([ + [-1, -1, -1], + [-1, 1, -1], + [-1, 0, 0] + ], dtype=np.int32) +] + +lvmin_prunings = [] +lvmin_prunings += [np.rot90(x, k=0, axes=(0, 1)) for x in lvmin_prunings_raw] +lvmin_prunings += [np.rot90(x, k=1, axes=(0, 1)) for x in lvmin_prunings_raw] +lvmin_prunings += [np.rot90(x, k=2, axes=(0, 1)) for x in lvmin_prunings_raw] +lvmin_prunings += [np.rot90(x, k=3, axes=(0, 1)) for x in lvmin_prunings_raw] + + +def remove_pattern(x, kernel): + objects = cv2.morphologyEx(x, cv2.MORPH_HITMISS, kernel) + objects = np.where(objects > 127) + x[objects] = 0 + return x, objects[0].shape[0] > 0 + + +def thin_one_time(x, kernels): + y = x + is_done = True + for k in kernels: + y, has_update = remove_pattern(y, k) + if has_update: + is_done = False + return y, is_done + + +def lvmin_thin(x, prunings=True): + y = x + for i in range(32): + y, is_done = thin_one_time(y, lvmin_kernels) + if is_done: + break + if prunings: + y, _ = thin_one_time(y, lvmin_prunings) + return y + + +def nake_nms(x): + f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8) + f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8) + f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8) + f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8) + y = np.zeros_like(x) + for f in [f1, f2, f3, f4]: + np.putmask(y, cv2.dilate(x, kernel=f) == x, x) + return y + diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/utils.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/utils.py new file mode 100644 index 00000000..6712f71d --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/utils.py @@ -0,0 +1,180 @@ +import torch +import os +import functools +import time +import base64 +import numpy as np +import safetensors.torch +import cv2 +import logging + +from typing import Any, Callable, Dict, List +from modules.safe import unsafe_torch_load +from lib_controlnet.logging import logger + + +def load_state_dict(ckpt_path, location="cpu"): + _, extension = os.path.splitext(ckpt_path) + if extension.lower() == ".safetensors": + state_dict = safetensors.torch.load_file(ckpt_path, device=location) + else: + state_dict = unsafe_torch_load(ckpt_path, map_location=torch.device(location)) + state_dict = get_state_dict(state_dict) + logger.info(f"Loaded state_dict from [{ckpt_path}]") + return state_dict + + +def get_state_dict(d): + return d.get("state_dict", d) + + +def ndarray_lru_cache(max_size: int = 128, typed: bool = False): + """ + Decorator to enable caching for functions with numpy array arguments. + Numpy arrays are mutable, and thus not directly usable as hash keys. + + The idea here is to wrap the incoming arguments with type `np.ndarray` + as `HashableNpArray` so that `lru_cache` can correctly handles `np.ndarray` + arguments. + + `HashableNpArray` functions exactly the same way as `np.ndarray` except + having `__hash__` and `__eq__` overriden. + """ + + def decorator(func: Callable): + """The actual decorator that accept function as input.""" + + class HashableNpArray(np.ndarray): + def __new__(cls, input_array): + # Input array is an instance of ndarray. + # The view makes the input array and returned array share the same data. + obj = np.asarray(input_array).view(cls) + return obj + + def __eq__(self, other) -> bool: + return np.array_equal(self, other) + + def __hash__(self): + # Hash the bytes representing the data of the array. + return hash(self.tobytes()) + + @functools.lru_cache(maxsize=max_size, typed=typed) + def cached_func(*args, **kwargs): + """This function only accepts `HashableNpArray` as input params.""" + return func(*args, **kwargs) + + # Preserves original function.__name__ and __doc__. + @functools.wraps(func) + def decorated_func(*args, **kwargs): + """The decorated function that delegates the original function.""" + + def convert_item(item: Any): + if isinstance(item, np.ndarray): + return HashableNpArray(item) + if isinstance(item, tuple): + return tuple(convert_item(i) for i in item) + return item + + args = [convert_item(arg) for arg in args] + kwargs = {k: convert_item(arg) for k, arg in kwargs.items()} + return cached_func(*args, **kwargs) + + return decorated_func + + return decorator + + +def timer_decorator(func): + """Time the decorated function and output the result to debug logger.""" + if logger.level != logging.DEBUG: + return func + + @functools.wraps(func) + def wrapper(*args, **kwargs): + start_time = time.time() + result = func(*args, **kwargs) + end_time = time.time() + duration = end_time - start_time + # Only report function that are significant enough. + if duration > 1e-3: + logger.debug(f"{func.__name__} ran in: {duration:.3f} sec") + return result + + return wrapper + + +class TimeMeta(type): + """ Metaclass to record execution time on all methods of the + child class. """ + def __new__(cls, name, bases, attrs): + for attr_name, attr_value in attrs.items(): + if callable(attr_value): + attrs[attr_name] = timer_decorator(attr_value) + return super().__new__(cls, name, bases, attrs) + + +# svgsupports +svgsupport = False +try: + import io + from svglib.svglib import svg2rlg + from reportlab.graphics import renderPM + + svgsupport = True +except ImportError: + pass + + +def svg_preprocess(inputs: Dict, preprocess: Callable): + if not inputs: + return None + + if inputs["image"].startswith("data:image/svg+xml;base64,") and svgsupport: + svg_data = base64.b64decode( + inputs["image"].replace("data:image/svg+xml;base64,", "") + ) + drawing = svg2rlg(io.BytesIO(svg_data)) + png_data = renderPM.drawToString(drawing, fmt="PNG") + encoded_string = base64.b64encode(png_data) + base64_str = str(encoded_string, "utf-8") + base64_str = "data:image/png;base64," + base64_str + inputs["image"] = base64_str + return preprocess(inputs) + + +def get_unique_axis0(data): + arr = np.asanyarray(data) + idxs = np.lexsort(arr.T) + arr = arr[idxs] + unique_idxs = np.empty(len(arr), dtype=np.bool_) + unique_idxs[:1] = True + unique_idxs[1:] = np.any(arr[:-1, :] != arr[1:, :], axis=-1) + return arr[unique_idxs] + + +def read_image(img_path: str) -> str: + """Read image from specified path and return a base64 string.""" + img = cv2.imread(img_path) + _, bytes = cv2.imencode(".png", img) + encoded_image = base64.b64encode(bytes).decode("utf-8") + return encoded_image + + +def read_image_dir(img_dir: str, suffixes=('.png', '.jpg', '.jpeg', '.webp')) -> List[str]: + """Try read all images in given img_dir.""" + images = [] + for filename in os.listdir(img_dir): + if filename.endswith(suffixes): + img_path = os.path.join(img_dir, filename) + try: + images.append(read_image(img_path)) + except IOError: + logger.error(f"Error opening {img_path}") + return images + + +def align_dim_latent(x: int) -> int: + """ Align the pixel dimension (w/h) to latent dimension. + Stable diffusion 1:8 ratio for latent/pixel, i.e., + 1 latent unit == 8 pixel unit.""" + return (x // 8) * 8 \ No newline at end of file diff --git a/extensions-builtin/sd_forge_controlnet/preload.py b/extensions-builtin/sd_forge_controlnet/preload.py new file mode 100644 index 00000000..9bc15b70 --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/preload.py @@ -0,0 +1,39 @@ +def preload(parser): + parser.add_argument( + "--controlnet-dir", + type=str, + help="Path to directory with ControlNet models", + default=None, + ) + parser.add_argument( + "--controlnet-annotator-models-path", + type=str, + help="Path to directory with annotator model directories", + default=None, + ) + parser.add_argument( + "--no-half-controlnet", + action="store_true", + help="do not switch the ControlNet models to 16-bit floats (only needed without --no-half)", + default=None, + ) + # Setting default max_size=16 as each cache entry contains image as both key + # and value (Very costly). + parser.add_argument( + "--controlnet-preprocessor-cache-size", + type=int, + help="Cache size for controlnet preprocessor results", + default=16, + ) + parser.add_argument( + "--controlnet-loglevel", + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Set the log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)", + ) + parser.add_argument( + "--controlnet-tracemalloc", + action="store_true", + help="Enable memory tracing.", + default=None, + ) diff --git a/extensions-builtin/sd_forge_controlnet/requirements.txt b/extensions-builtin/sd_forge_controlnet/requirements.txt new file mode 100644 index 00000000..d12e85b0 --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/requirements.txt @@ -0,0 +1,5 @@ +fvcore +mediapipe +onnxruntime +opencv-python>=4.8.0 +svglib diff --git a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py new file mode 100644 index 00000000..0c11624b --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py @@ -0,0 +1,1181 @@ +import gc +import tracemalloc +import os +import logging +from collections import OrderedDict +from copy import copy +from typing import Dict, Optional, Tuple, List, Union +import modules.scripts as scripts +from modules import shared, devices, script_callbacks, processing, masking, images +from modules.api.api import decode_base64_to_image +import gradio as gr +import time + +from einops import rearrange +from lib_controlnet import global_state, external_code, utils +from lib_controlnet.utils import get_unique_axis0, align_dim_latent +from lib_controlnet.enums import StableDiffusionVersion, HiResFixOption +from lib_controlnet.controlnet_ui.controlnet_ui_group import ControlNetUiGroup, UiControlNetUnit +from lib_controlnet.controlnet_ui.photopea import Photopea +from lib_controlnet.logging import logger +from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img, StableDiffusionProcessing +from modules.images import save_image +from lib_controlnet.infotext import Infotext +from modules_forge.forge_util import HWC3 + +import cv2 +import numpy as np +import torch + +from PIL import Image, ImageFilter, ImageOps +from lib_controlnet.lvminthin import lvmin_thin, nake_nms + + +# Gradio 3.32 bug fix +import tempfile +gradio_tempfile_path = os.path.join(tempfile.gettempdir(), 'gradio') +os.makedirs(gradio_tempfile_path, exist_ok=True) + + +global_state.update_controlnet_filenames() + + +def image_dict_from_any(image) -> Optional[Dict[str, np.ndarray]]: + if image is None: + return None + + if isinstance(image, (tuple, list)): + image = {'image': image[0], 'mask': image[1]} + elif not isinstance(image, dict): + image = {'image': image, 'mask': None} + else: # type(image) is dict + # copy to enable modifying the dict and prevent response serialization error + image = dict(image) + + if isinstance(image['image'], str): + if os.path.exists(image['image']): + image['image'] = np.array(Image.open(image['image'])).astype('uint8') + elif image['image']: + image['image'] = external_code.to_base64_nparray(image['image']) + else: + image['image'] = None + + # If there is no image, return image with None image and None mask + if image['image'] is None: + image['mask'] = None + return image + + if 'mask' not in image or image['mask'] is None: + image['mask'] = np.zeros_like(image['image'], dtype=np.uint8) + elif isinstance(image['mask'], str): + if os.path.exists(image['mask']): + image['mask'] = np.array(Image.open(image['mask'])).astype('uint8') + elif image['mask']: + image['mask'] = external_code.to_base64_nparray(image['mask']) + else: + image['mask'] = np.zeros_like(image['image'], dtype=np.uint8) + + return image + + +def prepare_mask( + mask: Image.Image, p: processing.StableDiffusionProcessing +) -> Image.Image: + """ + Prepare an image mask for the inpainting process. + + This function takes as input a PIL Image object and an instance of the + StableDiffusionProcessing class, and performs the following steps to prepare the mask: + + 1. Convert the mask to grayscale (mode "L"). + 2. If the 'inpainting_mask_invert' attribute of the processing instance is True, + invert the mask colors. + 3. If the 'mask_blur' attribute of the processing instance is greater than 0, + apply a Gaussian blur to the mask with a radius equal to 'mask_blur'. + + Args: + mask (Image.Image): The input mask as a PIL Image object. + p (processing.StableDiffusionProcessing): An instance of the StableDiffusionProcessing class + containing the processing parameters. + + Returns: + mask (Image.Image): The prepared mask as a PIL Image object. + """ + mask = mask.convert("L") + if getattr(p, "inpainting_mask_invert", False): + mask = ImageOps.invert(mask) + + if hasattr(p, 'mask_blur_x'): + if getattr(p, "mask_blur_x", 0) > 0: + np_mask = np.array(mask) + kernel_size = 2 * int(2.5 * p.mask_blur_x + 0.5) + 1 + np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), p.mask_blur_x) + mask = Image.fromarray(np_mask) + if getattr(p, "mask_blur_y", 0) > 0: + np_mask = np.array(mask) + kernel_size = 2 * int(2.5 * p.mask_blur_y + 0.5) + 1 + np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), p.mask_blur_y) + mask = Image.fromarray(np_mask) + else: + if getattr(p, "mask_blur", 0) > 0: + mask = mask.filter(ImageFilter.GaussianBlur(p.mask_blur)) + + return mask + + +def set_numpy_seed(p: processing.StableDiffusionProcessing) -> Optional[int]: + """ + Set the random seed for NumPy based on the provided parameters. + + Args: + p (processing.StableDiffusionProcessing): The instance of the StableDiffusionProcessing class. + + Returns: + Optional[int]: The computed random seed if successful, or None if an exception occurs. + + This function sets the random seed for NumPy using the seed and subseed values from the given instance of + StableDiffusionProcessing. If either seed or subseed is -1, it uses the first value from `all_seeds`. + Otherwise, it takes the maximum of the provided seed value and 0. + + The final random seed is computed by adding the seed and subseed values, applying a bitwise AND operation + with 0xFFFFFFFF to ensure it fits within a 32-bit integer. + """ + try: + tmp_seed = int(p.all_seeds[0] if p.seed == -1 else max(int(p.seed), 0)) + tmp_subseed = int(p.all_seeds[0] if p.subseed == -1 else max(int(p.subseed), 0)) + seed = (tmp_seed + tmp_subseed) & 0xFFFFFFFF + np.random.seed(seed) + return seed + except Exception as e: + logger.warning(e) + logger.warning('Warning: Failed to use consistent random seed.') + return None + + +def get_pytorch_control(x: np.ndarray) -> torch.Tensor: + # A very safe method to make sure that Apple/Mac works + y = x + + # below is very boring but do not change these. If you change these Apple or Mac may fail. + y = torch.from_numpy(y) + y = y.float() / 255.0 + y = rearrange(y, 'h w c -> 1 c h w') + y = y.clone() + y = y.to(devices.get_device_for("controlnet")) + y = y.clone() + return y + + +class Script(scripts.Script, metaclass=( + utils.TimeMeta if logger.level == logging.DEBUG else type)): + + model_cache = OrderedDict() + + def __init__(self) -> None: + super().__init__() + self.latest_network = None + self.input_image = None + self.latest_model_hash = "" + self.enabled_units = [] + self.detected_map = [] + self.post_processors = [] + self.noise_modifier = None + self.ui_batch_option_state = [external_code.BatchOption.DEFAULT.value, False] + + def title(self): + return "ControlNet" + + def show(self, is_img2img): + return scripts.AlwaysVisible + + @staticmethod + def get_default_ui_unit(is_ui=True): + cls = UiControlNetUnit if is_ui else external_code.ControlNetUnit + return cls( + enabled=False, + module="none", + model="None" + ) + + def uigroup(self, tabname: str, is_img2img: bool, elem_id_tabname: str, photopea: Optional[Photopea]) -> Tuple[ControlNetUiGroup, gr.State]: + group = ControlNetUiGroup( + is_img2img, + Script.get_default_ui_unit(), + photopea, + ) + return group, group.render(tabname, elem_id_tabname) + + def ui_batch_options(self, is_img2img: bool, elem_id_tabname: str): + batch_option = gr.Radio( + choices=[e.value for e in external_code.BatchOption], + value=external_code.BatchOption.DEFAULT.value, + label="Batch Option", + elem_id=f"{elem_id_tabname}_controlnet_batch_option_radio", + elem_classes="controlnet_batch_option_radio", + ) + use_batch_style_align = gr.Checkbox( + label='[StyleAlign] Align image style in the batch.' + ) + + unit_args = [batch_option, use_batch_style_align] + + def update_ui_batch_options(*args): + self.ui_batch_option_state = args + return + + for comp in unit_args: + event_subscribers = [] + if hasattr(comp, "edit"): + event_subscribers.append(comp.edit) + elif hasattr(comp, "click"): + event_subscribers.append(comp.click) + elif isinstance(comp, gr.Slider) and hasattr(comp, "release"): + event_subscribers.append(comp.release) + elif hasattr(comp, "change"): + event_subscribers.append(comp.change) + + if hasattr(comp, "clear"): + event_subscribers.append(comp.clear) + + for event_subscriber in event_subscribers: + event_subscriber( + fn=update_ui_batch_options, inputs=unit_args + ) + + return + + def ui(self, is_img2img): + """this function should create gradio UI elements. See https://gradio.app/docs/#components + The return value should be an array of all components that are used in processing. + Values of those returned components will be passed to run() and process() functions. + """ + infotext = Infotext() + ui_groups = [] + controls = [] + max_models = shared.opts.data.get("control_net_unit_count", 3) + elem_id_tabname = ("img2img" if is_img2img else "txt2img") + "_controlnet" + with gr.Group(elem_id=elem_id_tabname): + with gr.Accordion(f"ControlNet Integrated", open=False, elem_id="controlnet"): + photopea = Photopea() if not shared.opts.data.get("controlnet_disable_photopea_edit", False) else None + if max_models > 1: + with gr.Tabs(elem_id=f"{elem_id_tabname}_tabs"): + for i in range(max_models): + with gr.Tab(f"ControlNet Unit {i}", + elem_classes=['cnet-unit-tab']): + group, state = self.uigroup(f"ControlNet-{i}", is_img2img, elem_id_tabname, photopea) + ui_groups.append(group) + controls.append(state) + else: + with gr.Column(): + group, state = self.uigroup(f"ControlNet", is_img2img, elem_id_tabname, photopea) + ui_groups.append(group) + controls.append(state) + with gr.Accordion(f"Batch Options", open=False, elem_id="controlnet_batch_options"): + self.ui_batch_options(is_img2img, elem_id_tabname) + + for i, ui_group in enumerate(ui_groups): + infotext.register_unit(i, ui_group) + if shared.opts.data.get("control_net_sync_field_args", True): + self.infotext_fields = infotext.infotext_fields + self.paste_field_names = infotext.paste_field_names + + return tuple(controls) + + @staticmethod + def clear_control_model_cache(): + Script.model_cache.clear() + gc.collect() + devices.torch_gc() + + @staticmethod + def get_remote_call(p, attribute, default=None, idx=0, strict=False, force=False): + if not force and not shared.opts.data.get("control_net_allow_script_control", False): + return default + + def get_element(obj, strict=False): + if not isinstance(obj, list): + return obj if not strict or idx == 0 else None + elif idx < len(obj): + return obj[idx] + else: + return None + + attribute_value = get_element(getattr(p, attribute, None), strict) + return attribute_value if attribute_value is not None else default + + @staticmethod + def parse_remote_call(p, unit: external_code.ControlNetUnit, idx): + selector = Script.get_remote_call + + unit.enabled = selector(p, "control_net_enabled", unit.enabled, idx, strict=True) + unit.module = selector(p, "control_net_module", unit.module, idx) + unit.model = selector(p, "control_net_model", unit.model, idx) + unit.weight = selector(p, "control_net_weight", unit.weight, idx) + unit.image = selector(p, "control_net_image", unit.image, idx) + unit.resize_mode = selector(p, "control_net_resize_mode", unit.resize_mode, idx) + unit.low_vram = selector(p, "control_net_lowvram", unit.low_vram, idx) + unit.processor_res = selector(p, "control_net_pres", unit.processor_res, idx) + unit.threshold_a = selector(p, "control_net_pthr_a", unit.threshold_a, idx) + unit.threshold_b = selector(p, "control_net_pthr_b", unit.threshold_b, idx) + unit.guidance_start = selector(p, "control_net_guidance_start", unit.guidance_start, idx) + unit.guidance_end = selector(p, "control_net_guidance_end", unit.guidance_end, idx) + # Backward compatibility. See https://github.com/Mikubill/sd-webui-controlnet/issues/1740 + # for more details. + unit.guidance_end = selector(p, "control_net_guidance_strength", unit.guidance_end, idx) + unit.control_mode = selector(p, "control_net_control_mode", unit.control_mode, idx) + unit.pixel_perfect = selector(p, "control_net_pixel_perfect", unit.pixel_perfect, idx) + + return unit + + @staticmethod + def detectmap_proc(detected_map, module, resize_mode, h, w): + + if 'inpaint' in module: + detected_map = detected_map.astype(np.float32) + else: + detected_map = HWC3(detected_map) + + def safe_numpy(x): + # A very safe method to make sure that Apple/Mac works + y = x + + # below is very boring but do not change these. If you change these Apple or Mac may fail. + y = y.copy() + y = np.ascontiguousarray(y) + y = y.copy() + return y + + def high_quality_resize(x, size): + # Written by lvmin + # Super high-quality control map up-scaling, considering binary, seg, and one-pixel edges + + inpaint_mask = None + if x.ndim == 3 and x.shape[2] == 4: + inpaint_mask = x[:, :, 3] + x = x[:, :, 0:3] + + if x.shape[0] != size[1] or x.shape[1] != size[0]: + new_size_is_smaller = (size[0] * size[1]) < (x.shape[0] * x.shape[1]) + new_size_is_bigger = (size[0] * size[1]) > (x.shape[0] * x.shape[1]) + unique_color_count = len(get_unique_axis0(x.reshape(-1, x.shape[2]))) + is_one_pixel_edge = False + is_binary = False + if unique_color_count == 2: + is_binary = np.min(x) < 16 and np.max(x) > 240 + if is_binary: + xc = x + xc = cv2.erode(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1) + xc = cv2.dilate(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1) + one_pixel_edge_count = np.where(xc < x)[0].shape[0] + all_edge_count = np.where(x > 127)[0].shape[0] + is_one_pixel_edge = one_pixel_edge_count * 2 > all_edge_count + + if 2 < unique_color_count < 200: + interpolation = cv2.INTER_NEAREST + elif new_size_is_smaller: + interpolation = cv2.INTER_AREA + else: + interpolation = cv2.INTER_CUBIC # Must be CUBIC because we now use nms. NEVER CHANGE THIS + + y = cv2.resize(x, size, interpolation=interpolation) + if inpaint_mask is not None: + inpaint_mask = cv2.resize(inpaint_mask, size, interpolation=interpolation) + + if is_binary: + y = np.mean(y.astype(np.float32), axis=2).clip(0, 255).astype(np.uint8) + if is_one_pixel_edge: + y = nake_nms(y) + _, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) + y = lvmin_thin(y, prunings=new_size_is_bigger) + else: + _, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) + y = np.stack([y] * 3, axis=2) + else: + y = x + + if inpaint_mask is not None: + inpaint_mask = (inpaint_mask > 127).astype(np.float32) * 255.0 + inpaint_mask = inpaint_mask[:, :, None].clip(0, 255).astype(np.uint8) + y = np.concatenate([y, inpaint_mask], axis=2) + + return y + + if resize_mode == external_code.ResizeMode.RESIZE: + detected_map = high_quality_resize(detected_map, (w, h)) + detected_map = safe_numpy(detected_map) + return get_pytorch_control(detected_map), detected_map + + old_h, old_w, _ = detected_map.shape + old_w = float(old_w) + old_h = float(old_h) + k0 = float(h) / old_h + k1 = float(w) / old_w + + safeint = lambda x: int(np.round(x)) + + if resize_mode == external_code.ResizeMode.OUTER_FIT: + k = min(k0, k1) + borders = np.concatenate([detected_map[0, :, :], detected_map[-1, :, :], detected_map[:, 0, :], detected_map[:, -1, :]], axis=0) + high_quality_border_color = np.median(borders, axis=0).astype(detected_map.dtype) + if len(high_quality_border_color) == 4: + # Inpaint hijack + high_quality_border_color[3] = 255 + high_quality_background = np.tile(high_quality_border_color[None, None], [h, w, 1]) + detected_map = high_quality_resize(detected_map, (safeint(old_w * k), safeint(old_h * k))) + new_h, new_w, _ = detected_map.shape + pad_h = max(0, (h - new_h) // 2) + pad_w = max(0, (w - new_w) // 2) + high_quality_background[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = detected_map + detected_map = high_quality_background + detected_map = safe_numpy(detected_map) + return get_pytorch_control(detected_map), detected_map + else: + k = max(k0, k1) + detected_map = high_quality_resize(detected_map, (safeint(old_w * k), safeint(old_h * k))) + new_h, new_w, _ = detected_map.shape + pad_h = max(0, (new_h - h) // 2) + pad_w = max(0, (new_w - w) // 2) + detected_map = detected_map[pad_h:pad_h+h, pad_w:pad_w+w] + detected_map = safe_numpy(detected_map) + return get_pytorch_control(detected_map), detected_map + + @staticmethod + def get_enabled_units(p): + units = external_code.get_all_units_in_processing(p) + if len(units) == 0: + # fill a null group + remote_unit = Script.parse_remote_call(p, Script.get_default_ui_unit(), 0) + if remote_unit.enabled: + units.append(remote_unit) + + enabled_units = [] + for idx, unit in enumerate(units): + local_unit = Script.parse_remote_call(p, unit, idx) + if not local_unit.enabled: + continue + if hasattr(local_unit, "unfold_merged"): + enabled_units.extend(local_unit.unfold_merged()) + else: + enabled_units.append(copy(local_unit)) + + Infotext.write_infotext(enabled_units, p) + return enabled_units + + @staticmethod + def choose_input_image( + p: processing.StableDiffusionProcessing, + unit: external_code.ControlNetUnit, + idx: int + ) -> Tuple[np.ndarray, external_code.ResizeMode]: + """ Choose input image from following sources with descending priority: + - p.image_control: [Deprecated] Lagacy way to pass image to controlnet. + - p.control_net_input_image: [Deprecated] Lagacy way to pass image to controlnet. + - unit.image: ControlNet tab input image. + - p.init_images: A1111 img2img tab input image. + + Returns: + - The input image in ndarray form. + - The resize mode. + """ + def parse_unit_image(unit: external_code.ControlNetUnit) -> Union[List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]: + unit_has_multiple_images = ( + isinstance(unit.image, list) and + len(unit.image) > 0 and + "image" in unit.image[0] + ) + if unit_has_multiple_images: + return [ + d + for img in unit.image + for d in (image_dict_from_any(img),) + if d is not None + ] + return image_dict_from_any(unit.image) + + def decode_image(img) -> np.ndarray: + """Need to check the image for API compatibility.""" + if isinstance(img, str): + return np.asarray(decode_base64_to_image(image['image'])) + else: + assert isinstance(img, np.ndarray) + return img + + # 4 input image sources. + p_image_control = getattr(p, "image_control", None) + p_input_image = Script.get_remote_call(p, "control_net_input_image", None, idx) + image = parse_unit_image(unit) + a1111_image = getattr(p, "init_images", [None])[0] + + resize_mode = external_code.resize_mode_from_value(unit.resize_mode) + + if p_image_control is not None: + logger.warning("Warn: Using legacy field 'p.image_control'.") + input_image = HWC3(np.asarray(p_image_control)) + elif p_input_image is not None: + logger.warning("Warn: Using legacy field 'p.controlnet_input_image'") + if isinstance(p_input_image, dict) and "mask" in p_input_image and "image" in p_input_image: + color = HWC3(np.asarray(p_input_image['image'])) + alpha = np.asarray(p_input_image['mask'])[..., None] + input_image = np.concatenate([color, alpha], axis=2) + else: + input_image = HWC3(np.asarray(p_input_image)) + elif image: + if isinstance(image, list): + # Add mask logic if later there is a processor that accepts mask + # on multiple inputs. + input_image = [HWC3(decode_image(img['image'])) for img in image] + else: + input_image = HWC3(decode_image(image['image'])) + if 'mask' in image and image['mask'] is not None: + while len(image['mask'].shape) < 3: + image['mask'] = image['mask'][..., np.newaxis] + if 'inpaint' in unit.module: + logger.info("using inpaint as input") + color = HWC3(image['image']) + alpha = image['mask'][:, :, 0:1] + input_image = np.concatenate([color, alpha], axis=2) + elif ( + not shared.opts.data.get("controlnet_ignore_noninpaint_mask", False) and + # There is wield gradio issue that would produce mask that is + # not pure color when no scribble is made on canvas. + # See https://github.com/Mikubill/sd-webui-controlnet/issues/1638. + not ( + (image['mask'][:, :, 0] <= 5).all() or + (image['mask'][:, :, 0] >= 250).all() + ) + ): + logger.info("using mask as input") + input_image = HWC3(image['mask'][:, :, 0]) + unit.module = 'none' # Always use black bg and white line + elif a1111_image is not None: + input_image = HWC3(np.asarray(a1111_image)) + a1111_i2i_resize_mode = getattr(p, "resize_mode", None) + assert a1111_i2i_resize_mode is not None + resize_mode = external_code.resize_mode_from_value(a1111_i2i_resize_mode) + + a1111_mask_image : Optional[Image.Image] = getattr(p, "image_mask", None) + if 'inpaint' in unit.module: + if a1111_mask_image is not None: + a1111_mask = np.array(prepare_mask(a1111_mask_image, p)) + assert a1111_mask.ndim == 2 + assert a1111_mask.shape[0] == input_image.shape[0] + assert a1111_mask.shape[1] == input_image.shape[1] + input_image = np.concatenate([input_image[:, :, 0:3], a1111_mask[:, :, None]], axis=2) + else: + input_image = np.concatenate([ + input_image[:, :, 0:3], + np.zeros_like(input_image, dtype=np.uint8)[:, :, 0:1], + ], axis=2) + else: + raise ValueError("controlnet is enabled but no input image is given") + + assert isinstance(input_image, (np.ndarray, list)) + return input_image, resize_mode + + @staticmethod + def try_crop_image_with_a1111_mask( + p: StableDiffusionProcessing, + unit: external_code.ControlNetUnit, + input_image: np.ndarray, + resize_mode: external_code.ResizeMode, + ) -> np.ndarray: + """ + Crop ControlNet input image based on A1111 inpaint mask given. + This logic is crutial in upscale scripts, as they use A1111 mask + inpaint_full_res + to crop tiles. + """ + # Note: The method determining whether the active script is an upscale script is purely + # based on `extra_generation_params` these scripts attach on `p`, and subject to change + # in the future. + # TODO: Change this to a more robust condition once A1111 offers a way to verify script name. + is_upscale_script = any("upscale" in k.lower() for k in getattr(p, "extra_generation_params", {}).keys()) + logger.debug(f"is_upscale_script={is_upscale_script}") + # Note: `inpaint_full_res` is "inpaint area" on UI. The flag is `True` when "Only masked" + # option is selected. + a1111_mask_image : Optional[Image.Image] = getattr(p, "image_mask", None) + is_only_masked_inpaint = ( + issubclass(type(p), StableDiffusionProcessingImg2Img) and + p.inpaint_full_res and + a1111_mask_image is not None + ) + if ( + 'reference' not in unit.module + and is_only_masked_inpaint + and (is_upscale_script or unit.inpaint_crop_input_image) + ): + logger.debug("Crop input image based on A1111 mask.") + input_image = [input_image[:, :, i] for i in range(input_image.shape[2])] + input_image = [Image.fromarray(x) for x in input_image] + + mask = prepare_mask(a1111_mask_image, p) + + crop_region = masking.get_crop_region(np.array(mask), p.inpaint_full_res_padding) + crop_region = masking.expand_crop_region(crop_region, p.width, p.height, mask.width, mask.height) + + input_image = [ + images.resize_image(resize_mode.int_value(), i, mask.width, mask.height) + for i in input_image + ] + + input_image = [x.crop(crop_region) for x in input_image] + input_image = [ + images.resize_image(external_code.ResizeMode.OUTER_FIT.int_value(), x, p.width, p.height) + for x in input_image + ] + + input_image = [np.asarray(x)[:, :, 0] for x in input_image] + input_image = np.stack(input_image, axis=2) + return input_image + + @staticmethod + def bound_check_params(unit: external_code.ControlNetUnit) -> None: + """ + Checks and corrects negative parameters in ControlNetUnit 'unit'. + Parameters 'processor_res', 'threshold_a', 'threshold_b' are reset to + their default values if negative. + + Args: + unit (external_code.ControlNetUnit): The ControlNetUnit instance to check. + """ + cfg = preprocessor_sliders_config.get( + global_state.get_module_basename(unit.module), []) + defaults = { + param: cfg_default['value'] + for param, cfg_default in zip( + ("processor_res", 'threshold_a', 'threshold_b'), cfg) + if cfg_default is not None + } + for param, default_value in defaults.items(): + value = getattr(unit, param) + if value < 0: + setattr(unit, param, default_value) + logger.warning(f'[{unit.module}.{param}] Invalid value({value}), using default value {default_value}.') + + @staticmethod + def check_sd_version_compatible(unit: external_code.ControlNetUnit) -> None: + """ + Checks whether the given ControlNet unit has model compatible with the currently + active sd model. An exception is thrown if ControlNet unit is detected to be + incompatible. + """ + sd_version = global_state.get_sd_version() + assert sd_version != StableDiffusionVersion.UNKNOWN + + if "revision" in unit.module.lower() and sd_version != StableDiffusionVersion.SDXL: + raise Exception(f"Preprocessor 'revision' only supports SDXL. Current SD base model is {sd_version}.") + + # No need to check if the ControlModelType does not require model to be present. + if unit.model is None or unit.model.lower() == "none": + return + + cnet_sd_version = StableDiffusionVersion.detect_from_model_name(unit.model) + + if cnet_sd_version == StableDiffusionVersion.UNKNOWN: + logger.warn(f"Unable to determine version for ControlNet model '{unit.model}'.") + return + + if not sd_version.is_compatible_with(cnet_sd_version): + raise Exception(f"ControlNet model {unit.model}({cnet_sd_version}) is not compatible with sd model({sd_version})") + + @staticmethod + def get_target_dimensions(p: StableDiffusionProcessing) -> Tuple[int, int, int, int]: + """Returns (h, w, hr_h, hr_w).""" + h = align_dim_latent(p.height) + w = align_dim_latent(p.width) + + high_res_fix = ( + isinstance(p, StableDiffusionProcessingTxt2Img) + and getattr(p, 'enable_hr', False) + ) + if high_res_fix: + if p.hr_resize_x == 0 and p.hr_resize_y == 0: + hr_y = int(p.height * p.hr_scale) + hr_x = int(p.width * p.hr_scale) + else: + hr_y, hr_x = p.hr_resize_y, p.hr_resize_x + hr_y = align_dim_latent(hr_y) + hr_x = align_dim_latent(hr_x) + else: + hr_y = h + hr_x = w + + return h, w, hr_y, hr_x + + def controlnet_main_entry(self, p): + sd_ldm = p.sd_model + unet = sd_ldm.model.diffusion_model + self.noise_modifier = None + + setattr(p, 'controlnet_control_loras', []) + + if self.latest_network is not None: + # always restore (~0.05s) + self.latest_network.restore() + + # always clear (~0.05s) + clear_all_secondary_control_models(unet) + + if not batch_hijack.instance.is_batch: + self.enabled_units = Script.get_enabled_units(p) + + batch_option_uint_separate = self.ui_batch_option_state[0] == external_code.BatchOption.SEPARATE.value + batch_option_style_align = self.ui_batch_option_state[1] + + if len(self.enabled_units) == 0 and not batch_option_style_align: + self.latest_network = None + return + + logger.info(f"unit_separate = {batch_option_uint_separate}, style_align = {batch_option_style_align}") + + detected_maps = [] + forward_params = [] + post_processors = [] + + # cache stuff + if self.latest_model_hash != p.sd_model.sd_model_hash: + Script.clear_control_model_cache() + + for idx, unit in enumerate(self.enabled_units): + unit.module = global_state.get_module_basename(unit.module) + + # unload unused preproc + module_list = [unit.module for unit in self.enabled_units] + for key in self.unloadable: + if key not in module_list: + self.unloadable.get(key, lambda:None)() + + self.latest_model_hash = p.sd_model.sd_model_hash + high_res_fix = isinstance(p, StableDiffusionProcessingTxt2Img) and getattr(p, 'enable_hr', False) + h, w, hr_y, hr_x = Script.get_target_dimensions(p) + + for idx, unit in enumerate(self.enabled_units): + Script.bound_check_params(unit) + Script.check_sd_version_compatible(unit) + if ( + "ip-adapter" in unit.module and + not global_state.ip_adapter_pairing_model[unit.module](unit.model) + ): + logger.error(f"Invalid pair of IP-Adapter preprocessor({unit.module}) and model({unit.model}).\n" + "Please follow following pairing logic:\n" + + global_state.ip_adapter_pairing_logic_text) + continue + + if ( + 'inpaint_only' == unit.module and + issubclass(type(p), StableDiffusionProcessingImg2Img) and + p.image_mask is not None + ): + logger.warning('A1111 inpaint and ControlNet inpaint duplicated. Falls back to inpaint_global_harmonious.') + unit.module = 'inpaint' + + if unit.module in model_free_preprocessors: + model_net = None + if 'reference' in unit.module: + control_model_type = ControlModelType.AttentionInjection + elif 'revision' in unit.module: + control_model_type = ControlModelType.ReVision + else: + raise Exception("Unable to determine control_model_type.") + else: + model_net, control_model_type = Script.load_control_model(p, unet, unit.model) + model_net.reset() + + if control_model_type == ControlModelType.ControlLoRA: + control_lora = model_net.control_model + bind_control_lora(unet, control_lora) + p.controlnet_control_loras.append(control_lora) + + input_image, resize_mode = Script.choose_input_image(p, unit, idx) + if isinstance(input_image, list): + assert unit.accepts_multiple_inputs() + input_images = input_image + else: # Following operations are only for single input image. + input_image = Script.try_crop_image_with_a1111_mask(p, unit, input_image, resize_mode) + input_image = np.ascontiguousarray(input_image.copy()).copy() # safe numpy + if unit.module == 'inpaint_only+lama' and resize_mode == external_code.ResizeMode.OUTER_FIT: + # inpaint_only+lama is special and required outpaint fix + _, input_image = Script.detectmap_proc(input_image, unit.module, resize_mode, hr_y, hr_x) + if unit.pixel_perfect: + unit.processor_res = external_code.pixel_perfect_resolution( + input_image, + target_H=h, + target_W=w, + resize_mode=resize_mode, + ) + input_images = [input_image] + # Preprocessor result may depend on numpy random operations, use the + # random seed in `StableDiffusionProcessing` to make the + # preprocessor result reproducable. + # Currently following preprocessors use numpy random: + # - shuffle + seed = set_numpy_seed(p) + logger.debug(f"Use numpy seed {seed}.") + logger.info(f"Using preprocessor: {unit.module}") + logger.info(f'preprocessor resolution = {unit.processor_res}') + + def store_detected_map(detected_map, module: str) -> None: + if unit.save_detected_map: + detected_maps.append((detected_map, module)) + + def preprocess_input_image(input_image: np.ndarray): + """ Preprocess single input image. """ + detected_map, is_image = self.preprocessor[unit.module]( + input_image, + res=unit.processor_res, + thr_a=unit.threshold_a, + thr_b=unit.threshold_b, + low_vram=( + ("clip" in unit.module or unit.module == "ip-adapter_face_id_plus") and + shared.opts.data.get("controlnet_clip_detector_on_cpu", False) + ), + ) + if high_res_fix: + if is_image: + hr_control, hr_detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, hr_y, hr_x) + store_detected_map(hr_detected_map, unit.module) + else: + hr_control = detected_map + else: + hr_control = None + + if is_image: + control, detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, h, w) + store_detected_map(detected_map, unit.module) + else: + control = detected_map + store_detected_map(input_image, unit.module) + + if control_model_type == ControlModelType.T2I_StyleAdapter: + control = control['last_hidden_state'] + + if control_model_type == ControlModelType.ReVision: + control = control['image_embeds'] + return control, hr_control + + controls, hr_controls = list(zip(*[preprocess_input_image(img) for img in input_images])) + if len(controls) == len(hr_controls) == 1: + control = controls[0] + hr_control = hr_controls[0] + else: + control = controls + hr_control = hr_controls + + preprocessor_dict = dict( + name=unit.module, + preprocessor_resolution=unit.processor_res, + threshold_a=unit.threshold_a, + threshold_b=unit.threshold_b + ) + + global_average_pooling = ( + control_model_type.is_controlnet() and + model_net.control_model.global_average_pooling + ) + control_mode = external_code.control_mode_from_value(unit.control_mode) + forward_param = ControlParams( + control_model=model_net, + preprocessor=preprocessor_dict, + hint_cond=control, + weight=unit.weight, + guidance_stopped=False, + start_guidance_percent=unit.guidance_start, + stop_guidance_percent=unit.guidance_end, + advanced_weighting=unit.advanced_weighting, + control_model_type=control_model_type, + global_average_pooling=global_average_pooling, + hr_hint_cond=hr_control, + hr_option=HiResFixOption.from_value(unit.hr_option) if high_res_fix else HiResFixOption.BOTH, + soft_injection=control_mode != external_code.ControlMode.BALANCED, + cfg_injection=control_mode == external_code.ControlMode.CONTROL, + ) + forward_params.append(forward_param) + + if 'inpaint_only' in unit.module: + final_inpaint_feed = hr_control if hr_control is not None else control + final_inpaint_feed = final_inpaint_feed.detach().cpu().numpy() + final_inpaint_feed = np.ascontiguousarray(final_inpaint_feed).copy() + final_inpaint_mask = final_inpaint_feed[0, 3, :, :].astype(np.float32) + final_inpaint_raw = final_inpaint_feed[0, :3].astype(np.float32) + sigma = shared.opts.data.get("control_net_inpaint_blur_sigma", 7) + final_inpaint_mask = cv2.dilate(final_inpaint_mask, np.ones((sigma, sigma), dtype=np.uint8)) + final_inpaint_mask = cv2.blur(final_inpaint_mask, (sigma, sigma))[None] + _, Hmask, Wmask = final_inpaint_mask.shape + final_inpaint_raw = torch.from_numpy(np.ascontiguousarray(final_inpaint_raw).copy()) + final_inpaint_mask = torch.from_numpy(np.ascontiguousarray(final_inpaint_mask).copy()) + + def inpaint_only_post_processing(x): + _, H, W = x.shape + if Hmask != H or Wmask != W: + logger.error('Error: ControlNet find post-processing resolution mismatch. This could be related to other extensions hacked processing.') + return x + r = final_inpaint_raw.to(x.dtype).to(x.device) + m = final_inpaint_mask.to(x.dtype).to(x.device) + y = m * x.clip(0, 1) + (1 - m) * r + y = y.clip(0, 1) + return y + + post_processors.append(inpaint_only_post_processing) + + if 'recolor' in unit.module: + final_feed = hr_control if hr_control is not None else control + final_feed = final_feed.detach().cpu().numpy() + final_feed = np.ascontiguousarray(final_feed).copy() + final_feed = final_feed[0, 0, :, :].astype(np.float32) + final_feed = (final_feed * 255).clip(0, 255).astype(np.uint8) + Hfeed, Wfeed = final_feed.shape + + if 'luminance' in unit.module: + + def recolor_luminance_post_processing(x): + C, H, W = x.shape + if Hfeed != H or Wfeed != W or C != 3: + logger.error('Error: ControlNet find post-processing resolution mismatch. This could be related to other extensions hacked processing.') + return x + h = x.detach().cpu().numpy().transpose((1, 2, 0)) + h = (h * 255).clip(0, 255).astype(np.uint8) + h = cv2.cvtColor(h, cv2.COLOR_RGB2LAB) + h[:, :, 0] = final_feed + h = cv2.cvtColor(h, cv2.COLOR_LAB2RGB) + h = (h.astype(np.float32) / 255.0).transpose((2, 0, 1)) + y = torch.from_numpy(h).clip(0, 1).to(x) + return y + + post_processors.append(recolor_luminance_post_processing) + + if 'intensity' in unit.module: + + def recolor_intensity_post_processing(x): + C, H, W = x.shape + if Hfeed != H or Wfeed != W or C != 3: + logger.error('Error: ControlNet find post-processing resolution mismatch. This could be related to other extensions hacked processing.') + return x + h = x.detach().cpu().numpy().transpose((1, 2, 0)) + h = (h * 255).clip(0, 255).astype(np.uint8) + h = cv2.cvtColor(h, cv2.COLOR_RGB2HSV) + h[:, :, 2] = final_feed + h = cv2.cvtColor(h, cv2.COLOR_HSV2RGB) + h = (h.astype(np.float32) / 255.0).transpose((2, 0, 1)) + y = torch.from_numpy(h).clip(0, 1).to(x) + return y + + post_processors.append(recolor_intensity_post_processing) + + if '+lama' in unit.module: + forward_param.used_hint_cond_latent = hook.UnetHook.call_vae_using_process(p, control) + self.noise_modifier = forward_param.used_hint_cond_latent + + del model_net + + is_low_vram = any(unit.low_vram for unit in self.enabled_units) + + for i, param in enumerate(forward_params): + if param.control_model_type == ControlModelType.IPAdapter: + param.control_model.hook( + model=unet, + preprocessor_outputs=param.hint_cond, + weight=param.weight, + dtype=torch.float32, + start=param.start_guidance_percent, + end=param.stop_guidance_percent + ) + if param.control_model_type == ControlModelType.Controlllite: + param.control_model.hook( + model=unet, + cond=param.hint_cond, + weight=param.weight, + start=param.start_guidance_percent, + end=param.stop_guidance_percent + ) + if param.control_model_type == ControlModelType.InstantID: + # For instant_id we always expect ip-adapter model followed + # by ControlNet model. + assert i > 0, "InstantID control model should follow ipadapter model." + ip_adapter_param = forward_params[i - 1] + assert ip_adapter_param.control_model_type == ControlModelType.IPAdapter, \ + "InstantID control model should follow ipadapter model." + control_model = ip_adapter_param.control_model + assert hasattr(control_model, "image_emb") + param.control_context_override = control_model.image_emb + + self.latest_network = UnetHook(lowvram=is_low_vram) + self.latest_network.hook(model=unet, sd_ldm=sd_ldm, control_params=forward_params, process=p, + batch_option_uint_separate=batch_option_uint_separate, + batch_option_style_align=batch_option_style_align) + + self.detected_map = detected_maps + self.post_processors = post_processors + + def controlnet_hack(self, p): + t = time.time() + if getattr(shared.cmd_opts, 'controlnet_tracemalloc', False): + tracemalloc.start() + setattr(self, "malloc_begin", tracemalloc.take_snapshot()) + + self.controlnet_main_entry(p) + if getattr(shared.cmd_opts, 'controlnet_tracemalloc', False): + logger.info("After hook malloc:") + for stat in tracemalloc.take_snapshot().compare_to(self.malloc_begin, "lineno")[:10]: + logger.info(stat) + + if len(self.enabled_units) > 0: + logger.info(f'ControlNet Patched - Time = {time.time() - t}') + + @staticmethod + def process_has_sdxl_refiner(p): + return getattr(p, 'refiner_checkpoint', None) is not None + + def process(self, p, *args, **kwargs): + if not Script.process_has_sdxl_refiner(p): + self.controlnet_hack(p) + return + + def before_process_batch(self, p, *args, **kwargs): + if Script.process_has_sdxl_refiner(p): + self.controlnet_hack(p) + return + + def postprocess_batch(self, p, *args, **kwargs): + images = kwargs.get('images', []) + for post_processor in self.post_processors: + for i in range(len(images)): + images[i] = post_processor(images[i]) + return + + def postprocess(self, p, processed, *args): + sd_ldm = p.sd_model + unet = sd_ldm.model.diffusion_model + + clear_all_secondary_control_models(unet) + + self.noise_modifier = None + + for control_lora in getattr(p, 'controlnet_control_loras', []): + unbind_control_lora(control_lora) + p.controlnet_control_loras = [] + + self.post_processors = [] + setattr(p, 'controlnet_vae_cache', None) + + processor_params_flag = (', '.join(getattr(processed, 'extra_generation_params', []))).lower() + self.post_processors = [] + + if not batch_hijack.instance.is_batch: + self.enabled_units.clear() + + if shared.opts.data.get("control_net_detectmap_autosaving", False) and self.latest_network is not None: + for detect_map, module in self.detected_map: + detectmap_dir = os.path.join(shared.opts.data.get("control_net_detectedmap_dir", ""), module) + if not os.path.isabs(detectmap_dir): + detectmap_dir = os.path.join(p.outpath_samples, detectmap_dir) + if module != "none": + os.makedirs(detectmap_dir, exist_ok=True) + img = Image.fromarray(np.ascontiguousarray(detect_map.clip(0, 255).astype(np.uint8)).copy()) + save_image(img, detectmap_dir, module) + + if self.latest_network is None: + return + + if not batch_hijack.instance.is_batch: + if not shared.opts.data.get("control_net_no_detectmap", False): + if 'sd upscale' not in processor_params_flag: + if self.detected_map is not None: + for detect_map, module in self.detected_map: + if detect_map is None: + continue + detect_map = np.ascontiguousarray(detect_map.copy()).copy() + detect_map = external_code.visualize_inpaint_mask(detect_map) + processed.images.extend([ + Image.fromarray( + detect_map.clip(0, 255).astype(np.uint8) + ) + ]) + + self.input_image = None + self.latest_network.restore() + self.latest_network = None + self.detected_map.clear() + + gc.collect() + devices.torch_gc() + if getattr(shared.cmd_opts, 'controlnet_tracemalloc', False): + logger.info("After generation:") + for stat in tracemalloc.take_snapshot().compare_to(self.malloc_begin, "lineno")[:10]: + logger.info(stat) + tracemalloc.stop() + + def batch_tab_process(self, p, batches, *args, **kwargs): + self.enabled_units = Script.get_enabled_units(p) + for unit_i, unit in enumerate(self.enabled_units): + unit.batch_images = iter([batch[unit_i] for batch in batches]) + + def batch_tab_process_each(self, p, *args, **kwargs): + for unit_i, unit in enumerate(self.enabled_units): + if getattr(unit, 'loopback', False): + continue + + unit.image = next(unit.batch_images) + + def batch_tab_postprocess_each(self, p, processed, *args, **kwargs): + for unit_i, unit in enumerate(self.enabled_units): + if getattr(unit, 'loopback', False): + output_images = getattr(processed, 'images', [])[processed.index_of_first_image:] + if output_images: + unit.image = np.array(output_images[0]) + else: + logger.warning(f'Warning: No loopback image found for controlnet unit {unit_i}. ' + f'Using control map from last batch iteration instead') + + def batch_tab_postprocess(self, p, *args, **kwargs): + self.enabled_units.clear() + self.input_image = None + if self.latest_network is None: return + + self.latest_network.restore() + self.latest_network = None + self.detected_map.clear() + + +def on_ui_settings(): + section = ('control_net', "ControlNet") + shared.opts.add_option("control_net_detectedmap_dir", shared.OptionInfo( + "detected_maps", "Directory for detected maps auto saving", section=section)) + shared.opts.add_option("control_net_models_path", shared.OptionInfo( + "", "Extra path to scan for ControlNet models (e.g. training output directory)", section=section)) + shared.opts.add_option("control_net_modules_path", shared.OptionInfo( + "", "Path to directory containing annotator model directories (requires restart, overrides corresponding command line flag)", section=section)) + shared.opts.add_option("control_net_unit_count", shared.OptionInfo( + 3, "Multi-ControlNet: ControlNet unit number (requires restart)", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}, section=section)) + shared.opts.add_option("control_net_model_cache_size", shared.OptionInfo( + 2, "Model cache size (requires restart)", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}, section=section)) + shared.opts.add_option("control_net_inpaint_blur_sigma", shared.OptionInfo( + 7, "ControlNet inpainting Gaussian blur sigma", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=section)) + shared.opts.add_option("control_net_no_detectmap", shared.OptionInfo( + False, "Do not append detectmap to output", gr.Checkbox, {"interactive": True}, section=section)) + shared.opts.add_option("control_net_detectmap_autosaving", shared.OptionInfo( + False, "Allow detectmap auto saving", gr.Checkbox, {"interactive": True}, section=section)) + shared.opts.add_option("control_net_allow_script_control", shared.OptionInfo( + False, "Allow other script to control this extension", gr.Checkbox, {"interactive": True}, section=section)) + shared.opts.add_option("control_net_sync_field_args", shared.OptionInfo( + True, "Paste ControlNet parameters in infotext", gr.Checkbox, {"interactive": True}, section=section)) + shared.opts.add_option("controlnet_show_batch_images_in_ui", shared.OptionInfo( + False, "Show batch images in gradio gallery output", gr.Checkbox, {"interactive": True}, section=section)) + shared.opts.add_option("controlnet_increment_seed_during_batch", shared.OptionInfo( + False, "Increment seed after each controlnet batch iteration", gr.Checkbox, {"interactive": True}, section=section)) + shared.opts.add_option("controlnet_disable_openpose_edit", shared.OptionInfo( + False, "Disable openpose edit", gr.Checkbox, {"interactive": True}, section=section)) + shared.opts.add_option("controlnet_disable_photopea_edit", shared.OptionInfo( + False, "Disable photopea edit", gr.Checkbox, {"interactive": True}, section=section)) + shared.opts.add_option("controlnet_photopea_warning", shared.OptionInfo( + True, "Photopea popup warning", gr.Checkbox, {"interactive": True}, section=section)) + shared.opts.add_option("controlnet_ignore_noninpaint_mask", shared.OptionInfo( + False, "Ignore mask on ControlNet input image if control type is not inpaint", + gr.Checkbox, {"interactive": True}, section=section)) + shared.opts.add_option("controlnet_clip_detector_on_cpu", shared.OptionInfo( + False, "Load CLIP preprocessor model on CPU", + gr.Checkbox, {"interactive": True}, section=section)) + + +script_callbacks.on_ui_settings(on_ui_settings) +script_callbacks.on_infotext_pasted(Infotext.on_infotext_pasted) +script_callbacks.on_after_component(ControlNetUiGroup.on_after_component) +script_callbacks.on_before_reload(ControlNetUiGroup.reset) \ No newline at end of file diff --git a/extensions-builtin/sd_forge_controlnet/scripts/xyz_grid_support.py b/extensions-builtin/sd_forge_controlnet/scripts/xyz_grid_support.py new file mode 100644 index 00000000..950bfe9b --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/scripts/xyz_grid_support.py @@ -0,0 +1,449 @@ +import re +import numpy as np + +from modules import scripts, shared + +try: + from lib_controlnet.global_state import update_controlnet_filenames, cn_models_names, get_preprocessor_names + from lib_controlnet.external_code import ResizeMode, ControlMode + +except (ImportError, NameError): + import_error = True +else: + import_error = False + +DEBUG_MODE = False + + +def debug_info(func): + def debug_info_(*args, **kwargs): + if DEBUG_MODE: + print(f"Debug info: {func.__name__}, {args}") + return func(*args, **kwargs) + return debug_info_ + + +def find_dict(dict_list, keyword, search_key="name", stop=False): + result = next((d for d in dict_list if d[search_key] == keyword), None) + if result or not stop: + return result + else: + raise ValueError(f"Dictionary with value '{keyword}' in key '{search_key}' not found.") + + +def flatten(lst): + result = [] + for element in lst: + if isinstance(element, list): + result.extend(flatten(element)) + else: + result.append(element) + return result + + +def is_all_included(target_list, check_list, allow_blank=False, stop=False): + for element in flatten(target_list): + if allow_blank and str(element) in ["None", ""]: + continue + elif element not in check_list: + if not stop: + return False + else: + raise ValueError(f"'{element}' is not included in check list.") + return True + + +class ListParser(): + """This class restores a broken list caused by the following process + in the xyz_grid module. + -> valslist = [x.strip() for x in chain.from_iterable( + csv.reader(StringIO(vals)))] + It also performs type conversion, + adjusts the number of elements in the list, and other operations. + + This class directly modifies the received list. + """ + numeric_pattern = { + int: { + "range": r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*", + "count": r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*" + }, + float: { + "range": r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*-\s*([+-]?\s*\d+(?:\.\d*)?)(?:\s*\(([+-]\d+(?:\.\d*)?)\s*\))?\s*", + "count": r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*-\s*([+-]?\s*\d+(?:\.\d*)?)(?:\s*\[(\d+(?:\.\d*)?)\s*\])?\s*" + } + } + + ################################################ + # + # Initialization method from here. + # + ################################################ + + def __init__(self, my_list, converter=None, allow_blank=True, exclude_list=None, run=True): + self.my_list = my_list + self.converter = converter + self.allow_blank = allow_blank + self.exclude_list = exclude_list + self.re_bracket_start = None + self.re_bracket_start_precheck = None + self.re_bracket_end = None + self.re_bracket_end_precheck = None + self.re_range = None + self.re_count = None + self.compile_regex() + if run: + self.auto_normalize() + + def compile_regex(self): + exclude_pattern = "|".join(self.exclude_list) if self.exclude_list else None + if exclude_pattern is None: + self.re_bracket_start = re.compile(r"^\[") + self.re_bracket_end = re.compile(r"\]$") + else: + self.re_bracket_start = re.compile(fr"^\[(?!(?:{exclude_pattern})\])") + self.re_bracket_end = re.compile(fr"(? valslist = [opt.type(x) for x in valslist] + # Perform type conversion using the function + # set to the confirm attribute instead. + # + def identity(x): + return x + + def enable_script_control(): + shared.opts.data["control_net_allow_script_control"] = True + + def apply_field(field): + @debug_info + def apply_field_(p, x, xs): + enable_script_control() + setattr(p, field, x) + + return apply_field_ + + ################################################ + # The confirm function defined in this module + # enables list notation and performs type conversion. + # + # Example: + # any = [any, any, any, ...] + # [any] = [any, None, None, ...] + # [None, None, any] = [None, None, any] + # [,,any] = [None, None, any] + # any, [,any,] = [any, any, any, ...], [None, any, None] + # + # Enabled Only: + # any = [any] = [any, None, None, ...] + # (any and [any] are considered equivalent) + # + def confirm(func_or_str): + @debug_info + def confirm_(p, xs): + if callable(func_or_str): # func_or_str is converter + ListParser(xs, func_or_str, allow_blank=True) + return + + elif isinstance(func_or_str, str): # func_or_str is keyword + valid_data = find_dict(validation_data, func_or_str, stop=True) + converter = valid_data["type"] + exclude_list = valid_data["exclude"]() if valid_data["exclude"] else None + check_list = valid_data["check"]() + + ListParser(xs, converter, allow_blank=True, exclude_list=exclude_list) + is_all_included(xs, check_list, allow_blank=True, stop=True) + return + + else: + raise TypeError(f"Argument must be callable or str, not {type(func_or_str).__name__}.") + + return confirm_ + + def bool_(string): + string = str(string) + if string in ["None", ""]: + return None + elif string.lower() in ["true", "1"]: + return True + elif string.lower() in ["false", "0"]: + return False + else: + raise ValueError(f"Could not convert string to boolean: {string}") + + def choices_bool(): + return ["False", "True"] + + def choices_model(): + update_controlnet_filenames() + return list(cn_models_names.values()) + + def choices_control_mode(): + return [e.value for e in ControlMode] + + def choices_resize_mode(): + return [e.value for e in ResizeMode] + + def choices_preprocessor(): + return list(get_preprocessor_names()) + + def make_excluded_list(): + pattern = re.compile(r"\[(\w+)\]") + return [match.group(1) for s in choices_model() + for match in pattern.finditer(s)] + + validation_data = [ + {"name": "model", "type": str, "check": choices_model, "exclude": make_excluded_list}, + {"name": "control_mode", "type": str, "check": choices_control_mode, "exclude": None}, + {"name": "resize_mode", "type": str, "check": choices_resize_mode, "exclude": None}, + {"name": "preprocessor", "type": str, "check": choices_preprocessor, "exclude": None}, + ] + + extra_axis_options = [ + xyz_grid.AxisOption("[ControlNet] Enabled", identity, apply_field("control_net_enabled"), confirm=confirm(bool_), choices=choices_bool), + xyz_grid.AxisOption("[ControlNet] Model", identity, apply_field("control_net_model"), confirm=confirm("model"), choices=choices_model, cost=0.9), + xyz_grid.AxisOption("[ControlNet] Weight", identity, apply_field("control_net_weight"), confirm=confirm(float)), + xyz_grid.AxisOption("[ControlNet] Guidance Start", identity, apply_field("control_net_guidance_start"), confirm=confirm(float)), + xyz_grid.AxisOption("[ControlNet] Guidance End", identity, apply_field("control_net_guidance_end"), confirm=confirm(float)), + xyz_grid.AxisOption("[ControlNet] Control Mode", identity, apply_field("control_net_control_mode"), confirm=confirm("control_mode"), choices=choices_control_mode), + xyz_grid.AxisOption("[ControlNet] Resize Mode", identity, apply_field("control_net_resize_mode"), confirm=confirm("resize_mode"), choices=choices_resize_mode), + xyz_grid.AxisOption("[ControlNet] Preprocessor", identity, apply_field("control_net_module"), confirm=confirm("preprocessor"), choices=choices_preprocessor), + xyz_grid.AxisOption("[ControlNet] Pre Resolution", identity, apply_field("control_net_pres"), confirm=confirm(int)), + xyz_grid.AxisOption("[ControlNet] Pre Threshold A", identity, apply_field("control_net_pthr_a"), confirm=confirm(float)), + xyz_grid.AxisOption("[ControlNet] Pre Threshold B", identity, apply_field("control_net_pthr_b"), confirm=confirm(float)), + ] + + xyz_grid.axis_options.extend(extra_axis_options) + + +def run(): + xyz_grid = find_module("xyz_grid.py, xy_grid.py") + if xyz_grid: + add_axis_options(xyz_grid) + + +if not import_error: + run() diff --git a/extensions-builtin/sd_forge_controlnet/style.css b/extensions-builtin/sd_forge_controlnet/style.css new file mode 100644 index 00000000..2e15e8d4 --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/style.css @@ -0,0 +1,182 @@ +.cnet-modal { + display: none; + /* Hidden by default */ + position: fixed; + /* Stay in place */ + z-index: 2147483647; + /* Sit on top */ + left: 0; + top: 0; + width: 100%; + /* Full width */ + height: 100%; + /* Full height */ + overflow: auto; + /* Enable scroll if needed */ + background-color: rgba(0, 0, 0, 0.4); + /* Black with opacity */ + max-width: none !important; + /* Fix sizing with SD.Next (vladmandic/automatic#2594) */ +} + +.cnet-modal-content { + position: relative; + background-color: var(--background-fill-primary); + margin: 5vh auto; + /* 15% from the top and centered */ + padding: 20px; + border: 1px solid #888; + width: 95%; + height: 90vh; + /* Could be more or less, depending on screen size */ + box-shadow: 0 4px 8px 0 rgba(0, 0, 0, 0.2), 0 6px 20px 0 rgba(0, 0, 0, 0.19); + animation-name: animatetop; + animation-duration: 0.4s; + max-width: none !important; + /* Fix sizing with SD.Next (vladmandic/automatic#2594) */ +} + +.cnet-modal-content iframe { + position: absolute; + top: 0; + left: 0; + width: 100%; + height: 100%; + border: none; +} + +.cnet-modal-content.alert { + padding: var(--size-5); +} + +.cnet-modal-content.alert ul { + list-style-type: none; +} + +.cnet-modal-close { + color: white !important; + right: 0.25em; + top: 0; + cursor: pointer; + position: absolute; + font-size: 56px; + font-weight: bold; +} + +@keyframes animatetop { + from { + top: -300px; + opacity: 0 + } + + to { + top: 0; + opacity: 1 + } +} + +.cnet-generated-image-control-group, +.cnet-upload-pose { + display: flex; + flex-direction: column; + align-items: flex-end; + + position: absolute; + right: var(--size-2); + bottom: var(--size-2); +} + +/* Gradio button style */ +.cnet-download-pose a, +.cnet-close-preview, +.cnet-edit-pose, +.cnet-upload-pose, +.cnet-photopea-child-trigger { + font-size: x-small !important; + font-weight: bold !important; + padding: 2px !important; + box-shadow: var(--shadow-drop); + border: 1px solid var(--button-secondary-border-color); + border-radius: var(--radius-sm); + background: var(--background-fill-primary); + height: var(--size-5); + color: var(--block-label-text-color) !important; + display: flex; + justify-content: center; + cursor: pointer; +} + +.cnet-download-pose:hover a, +.cnet-close-preview:hover a, +.cnet-edit-pose:hover, +.cnet-upload-pose:hover, +.cnet-photopea-child-trigger:hover { + color: var(--block-label-text-color) !important; +} + +.cnet-unit-active { + color: green !important; + font-weight: bold !important; +} + +.dark .cnet-unit-active { + color: greenyellow !important; +} + +.cnet-badge { + display: inline-block; + padding: 0.25em 0.75em; + font-size: 0.75em; + font-weight: bold; + color: white; + border-radius: 0.5em; + text-align: center; + vertical-align: middle; + margin-left: var(--size-2); +} + +.cnet-badge.primary { + background-color: green; +} + +.cnet-a1111-badge { + position: absolute; + bottom: 0px; + right: 0px; +} + +.cnet-disabled-radio { + opacity: 50%; +} + +.controlnet_row { + margin-top: 10px !important; +} + +/* JSON pose upload button styling */ +.cnet-upload-pose input[type=file] { + position: absolute; + left: 0; + top: 0; + opacity: 0; + width: 100%; + height: 100%; +} + +/* Photopea integration styles */ +.photopea-button-group { + position: absolute; + top: -30px; /* 20px modal padding + 10px margin */ +} + +.photopea-button { + font-size: 3rem; + font-weight: bold; + padding: 2px !important; + margin: 2px !important; + box-shadow: var(--shadow-drop); + border: 1px solid var(--button-secondary-border-color); + border-radius: var(--radius-sm); + background: var(--background-fill-primary); + color: var(--block-label-text-color); +} diff --git a/extensions-builtin/sd_forge_controlnet_example/scripts/sd_forge_controlnet_example.py b/extensions-builtin/sd_forge_controlnet_example/scripts/sd_forge_controlnet_example.py index 9c8833ec..2f9a168c 100644 --- a/extensions-builtin/sd_forge_controlnet_example/scripts/sd_forge_controlnet_example.py +++ b/extensions-builtin/sd_forge_controlnet_example/scripts/sd_forge_controlnet_example.py @@ -5,7 +5,7 @@ import gradio as gr from modules import scripts from modules.shared_cmd_options import cmd_opts -from modules_forge.shared import shared_preprocessors +from modules_forge.shared import supported_preprocessors from modules.modelloader import load_file_from_url from ldm_patched.modules.controlnet import load_controlnet from modules_forge.controlnet import apply_controlnet_advanced @@ -73,7 +73,7 @@ class ControlNetExampleForge(scripts.Script): width = W * 8 batch_size = p.batch_size - preprocessor = shared_preprocessors['canny'] + preprocessor = supported_preprocessors['canny'] # detect control at certain resolution control_image = preprocessor( diff --git a/modules_forge/forge_util.py b/modules_forge/forge_util.py index 75451cdb..c9c6c579 100644 --- a/modules_forge/forge_util.py +++ b/modules_forge/forge_util.py @@ -7,6 +7,25 @@ import string import cv2 +def HWC3(x): + assert x.dtype == np.uint8 + if x.ndim == 2: + x = x[:, :, None] + assert x.ndim == 3 + H, W, C = x.shape + assert C == 1 or C == 3 or C == 4 + if C == 3: + return x + if C == 1: + return np.concatenate([x, x, x], axis=2) + if C == 4: + color = x[:, :, 0:3].astype(np.float32) + alpha = x[:, :, 3:4].astype(np.float32) / 255.0 + y = color * alpha + 255.0 * (1.0 - alpha) + y = y.clip(0, 255).astype(np.uint8) + return y + + def compute_cond_mark(cond_or_uncond, sigmas): cond_or_uncond_size = int(sigmas.shape[0]) diff --git a/modules_forge/shared.py b/modules_forge/shared.py index ed22c640..793a469a 100644 --- a/modules_forge/shared.py +++ b/modules_forge/shared.py @@ -1,11 +1,7 @@ -import cv2 import os -import torch +import ldm_patched.modules.utils from modules.paths import models_path -from ldm_patched.modules import model_management -from ldm_patched.modules.model_patcher import ModelPatcher -from modules_forge.forge_util import resize_image_with_pad controlnet_dir = os.path.join(models_path, 'ControlNet') @@ -14,97 +10,29 @@ os.makedirs(controlnet_dir, exist_ok=True) preprocessor_dir = os.path.join(models_path, 'ControlNetPreprocessor') os.makedirs(preprocessor_dir, exist_ok=True) -shared_preprocessors = {} +supported_preprocessors = {} +supported_control_models = [] -def add_preprocessor(preprocessor): - global shared_preprocessors +def add_supported_preprocessor(preprocessor): + global supported_preprocessors p = preprocessor - shared_preprocessors[p.name] = p + supported_preprocessors[p.name] = p return -class PreprocessorParameter: - def __init__(self, minimum=0.0, maximum=1.0, step=0.01, label='Parameter 1', value=0.5, visible=False, **kwargs): - self.gradio_update_kwargs = dict( - minimum=minimum, maximum=maximum, step=step, label=label, value=value, visible=visible, **kwargs - ) +def add_supported_control_model(control_model): + global supported_control_models + supported_control_models.append(control_model) + return -class Preprocessor: - def __init__(self): - self.name = 'PreprocessorBase' - self.tags = [] - self.slider_resolution = PreprocessorParameter(label='Resolution', minimum=128, maximum=2048, value=512, step=8, visible=True) - self.slider_1 = PreprocessorParameter() - self.slider_2 = PreprocessorParameter() - self.slider_3 = PreprocessorParameter() - self.model_patcher: ModelPatcher = None - self.show_control_mode = True - self.do_not_need_model = False - self.sorting_priority = 0.0 # higher goes to top in the list - - def setup_model_patcher(self, model, load_device=None, offload_device=None, dtype=torch.float32, **kwargs): - if load_device is None: - load_device = model_management.get_torch_device() - - if offload_device is None: - offload_device = torch.device('cpu') - - if not model_management.should_use_fp16(load_device): - dtype = torch.float32 - - model.eval() - model = model.to(device=offload_device, dtype=dtype) - - self.model_patcher = ModelPatcher(model=model, load_device=load_device, offload_device=offload_device, **kwargs) - self.model_patcher.dtype = dtype - return self.model_patcher - - def move_all_model_patchers_to_gpu(self): - model_management.load_models_gpu([self.model_patcher]) - return - - def send_tensor_to_model_device(self, x): - return x.to(device=self.model_patcher.current_device, dtype=self.model_patcher.dtype) - - def lazy_memory_management(self, model): - # This is a lazy method to just free some memory - # so that we can still use old codes to manage memory in a bad way - # Ideally this should all be removed and all memory should be managed by model patcher. - # But the workload is too big, so we just use a quick method to manage in dirty way. - required_memory = model_management.module_size(model) + model_management.minimum_inference_memory() - model_management.free_memory(required_memory, device=model_management.get_torch_device()) - return - - def process_before_every_sampling(self, process, cnet): - return - - def __call__(self, input_image, resolution, slider_1=None, slider_2=None, slider_3=None, **kwargs): - return input_image - - -class PreprocessorNone(Preprocessor): - def __init__(self): - super().__init__() - self.name = 'None' - self.sorting_priority = 10 - - -class PreprocessorCanny(Preprocessor): - def __init__(self): - super().__init__() - self.name = 'canny' - self.tags = ['Canny'] - self.slider_1 = PreprocessorParameter(minimum=0, maximum=256, step=1, value=100, label='Low Threshold', visible=True) - self.slider_2 = PreprocessorParameter(minimum=0, maximum=256, step=1, value=200, label='High Threshold', visible=True) - self.sorting_priority = 100 - - def __call__(self, input_image, resolution, slider_1=None, slider_2=None, slider_3=None, **kwargs): - input_image, remove_pad = resize_image_with_pad(input_image, resolution) - canny_image = cv2.cvtColor(cv2.Canny(input_image, int(slider_1), int(slider_2)), cv2.COLOR_GRAY2RGB) - return remove_pad(canny_image) - - -add_preprocessor(PreprocessorNone()) -add_preprocessor(PreprocessorCanny()) +def try_load_supported_control_model(ckpt_path): + global supported_control_models + state_dict = ldm_patched.modules.utils.load_torch_file(ckpt_path, safe_load=True) + for supported_type in supported_control_models: + state_dict_copy = {k: v for k, v in state_dict.items()} + model = supported_type.try_build_from_state_dict(state_dict_copy, ckpt_path) + if model is not None: + return model + return None diff --git a/modules_forge/supported_controlnet.py b/modules_forge/supported_controlnet.py new file mode 100644 index 00000000..d557e847 --- /dev/null +++ b/modules_forge/supported_controlnet.py @@ -0,0 +1,158 @@ +import os +import torch +import ldm_patched.modules.utils +import ldm_patched.controlnet + +from ldm_patched.modules.controlnet import ControlLora, ControlNet, load_t2i_adapter +from modules_forge.controlnet import apply_controlnet_advanced +from modules_forge.shared import add_supported_control_model + + +class ControlModelPatcher: + @staticmethod + def try_build_from_state_dict(state_dict, ckpt_path): + return None + + def __init__(self, model_patcher): + self.model_patcher = model_patcher + + def patch_to_process(self, p, control_image): + return + + +class ControlNetPatcher(ControlModelPatcher): + @staticmethod + def try_build_from_state_dict(controlnet_data, ckpt_path): + if "lora_controlnet" in controlnet_data: + return ControlNetPatcher(ControlLora(controlnet_data)) + + controlnet_config = None + if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: # diffusers format + unet_dtype = ldm_patched.modules.model_management.unet_dtype() + controlnet_config = ldm_patched.modules.model_detection.unet_config_from_diffusers_unet(controlnet_data, + unet_dtype) + diffusers_keys = ldm_patched.modules.utils.unet_to_diffusers(controlnet_config) + diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight" + diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias" + + count = 0 + loop = True + while loop: + suffix = [".weight", ".bias"] + for s in suffix: + k_in = "controlnet_down_blocks.{}{}".format(count, s) + k_out = "zero_convs.{}.0{}".format(count, s) + if k_in not in controlnet_data: + loop = False + break + diffusers_keys[k_in] = k_out + count += 1 + + count = 0 + loop = True + while loop: + suffix = [".weight", ".bias"] + for s in suffix: + if count == 0: + k_in = "controlnet_cond_embedding.conv_in{}".format(s) + else: + k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s) + k_out = "input_hint_block.{}{}".format(count * 2, s) + if k_in not in controlnet_data: + k_in = "controlnet_cond_embedding.conv_out{}".format(s) + loop = False + diffusers_keys[k_in] = k_out + count += 1 + + new_sd = {} + for k in diffusers_keys: + if k in controlnet_data: + new_sd[diffusers_keys[k]] = controlnet_data.pop(k) + + leftover_keys = controlnet_data.keys() + if len(leftover_keys) > 0: + print("leftover keys:", leftover_keys) + controlnet_data = new_sd + + pth_key = 'control_model.zero_convs.0.0.weight' + pth = False + key = 'zero_convs.0.0.weight' + if pth_key in controlnet_data: + pth = True + key = pth_key + prefix = "control_model." + elif key in controlnet_data: + prefix = "" + else: + net = load_t2i_adapter(controlnet_data) + if net is None: + return None + return ControlNetPatcher(net) + + if controlnet_config is None: + unet_dtype = ldm_patched.modules.model_management.unet_dtype() + controlnet_config = ldm_patched.modules.model_detection.model_config_from_unet(controlnet_data, prefix, + unet_dtype, True).unet_config + load_device = ldm_patched.modules.model_management.get_torch_device() + manual_cast_dtype = ldm_patched.modules.model_management.unet_manual_cast(unet_dtype, load_device) + if manual_cast_dtype is not None: + controlnet_config["operations"] = ldm_patched.modules.ops.manual_cast + controlnet_config.pop("out_channels") + controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] + control_model = ldm_patched.controlnet.cldm.ControlNet(**controlnet_config) + + if pth: + if 'difference' in controlnet_data: + print("WARNING: Your controlnet model is diff version rather than official float16 model. " + "Please use an official float16/float32 model for robust performance.") + + class WeightsLoader(torch.nn.Module): + pass + + w = WeightsLoader() + w.control_model = control_model + missing, unexpected = w.load_state_dict(controlnet_data, strict=False) + else: + missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False) + print(missing, unexpected) + + global_average_pooling = False + filename = os.path.splitext(ckpt_path)[0] + if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): + # TODO: smarter way of enabling global_average_pooling + global_average_pooling = True + + control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, + manual_cast_dtype=manual_cast_dtype) + return ControlNetPatcher(control) + + def __init__(self, model_patcher): + super().__init__(model_patcher) + self.strength = 1.0 + self.start_percent = 0.0 + self.end_percent = 1.0 + self.positive_advanced_weighting = None + self.negative_advanced_weighting = None + self.advanced_frame_weighting = None + self.advanced_sigma_weighting = None + + def patch_to_process(self, p, control_image): + unet = p.sd_model.forge_objects.unet + + unet = apply_controlnet_advanced( + unet=unet, + controlnet=self.model_patcher, + image_bchw=control_image, + strength=self.strength, + start_percent=self.start_percent, + end_percent=self.end_percent, + positive_advanced_weighting=self.positive_advanced_weighting, + negative_advanced_weighting=self.negative_advanced_weighting, + advanced_frame_weighting=self.advanced_frame_weighting, + advanced_sigma_weighting=self.advanced_sigma_weighting) + + p.sd_model.forge_objects.unet = unet + return + + +add_supported_control_model(ControlNetPatcher) diff --git a/modules_forge/supported_preprocessor.py b/modules_forge/supported_preprocessor.py new file mode 100644 index 00000000..a97ba30c --- /dev/null +++ b/modules_forge/supported_preprocessor.py @@ -0,0 +1,95 @@ +import cv2 +import torch + +from modules_forge.shared import add_supported_preprocessor +from ldm_patched.modules import model_management +from ldm_patched.modules.model_patcher import ModelPatcher +from modules_forge.forge_util import resize_image_with_pad + + +class PreprocessorParameter: + def __init__(self, minimum=0.0, maximum=1.0, step=0.01, label='Parameter 1', value=0.5, visible=False, **kwargs): + self.gradio_update_kwargs = dict( + minimum=minimum, maximum=maximum, step=step, label=label, value=value, visible=visible, **kwargs + ) + + +class Preprocessor: + def __init__(self): + self.name = 'PreprocessorBase' + self.tags = [] + self.model_filename_filers = [] + self.slider_resolution = PreprocessorParameter(label='Resolution', minimum=128, maximum=2048, value=512, step=8, visible=True) + self.slider_1 = PreprocessorParameter() + self.slider_2 = PreprocessorParameter() + self.slider_3 = PreprocessorParameter() + self.model_patcher: ModelPatcher = None + self.show_control_mode = True + self.do_not_need_model = False + self.sorting_priority = 0.0 # higher goes to top in the list + + def setup_model_patcher(self, model, load_device=None, offload_device=None, dtype=torch.float32, **kwargs): + if load_device is None: + load_device = model_management.get_torch_device() + + if offload_device is None: + offload_device = torch.device('cpu') + + if not model_management.should_use_fp16(load_device): + dtype = torch.float32 + + model.eval() + model = model.to(device=offload_device, dtype=dtype) + + self.model_patcher = ModelPatcher(model=model, load_device=load_device, offload_device=offload_device, **kwargs) + self.model_patcher.dtype = dtype + return self.model_patcher + + def move_all_model_patchers_to_gpu(self): + model_management.load_models_gpu([self.model_patcher]) + return + + def send_tensor_to_model_device(self, x): + return x.to(device=self.model_patcher.current_device, dtype=self.model_patcher.dtype) + + def lazy_memory_management(self, model): + # This is a lazy method to just free some memory + # so that we can still use old codes to manage memory in a bad way + # Ideally this should all be removed and all memory should be managed by model patcher. + # But the workload is too big, so we just use a quick method to manage in dirty way. + required_memory = model_management.module_size(model) + model_management.minimum_inference_memory() + model_management.free_memory(required_memory, device=model_management.get_torch_device()) + return + + def process_before_every_sampling(self, process, cnet): + return + + def __call__(self, input_image, resolution, slider_1=None, slider_2=None, slider_3=None, **kwargs): + return input_image + + +class PreprocessorNone(Preprocessor): + def __init__(self): + super().__init__() + self.name = 'None' + self.sorting_priority = 10 + + +class PreprocessorCanny(Preprocessor): + def __init__(self): + super().__init__() + self.name = 'canny' + self.tags = ['Canny'] + self.model_filename_filers = ['canny'] + self.slider_1 = PreprocessorParameter(minimum=0, maximum=256, step=1, value=100, label='Low Threshold', visible=True) + self.slider_2 = PreprocessorParameter(minimum=0, maximum=256, step=1, value=200, label='High Threshold', visible=True) + self.sorting_priority = 100 + + def __call__(self, input_image, resolution, slider_1=None, slider_2=None, slider_3=None, **kwargs): + input_image, remove_pad = resize_image_with_pad(input_image, resolution) + canny_image = cv2.cvtColor(cv2.Canny(input_image, int(slider_1), int(slider_2)), cv2.COLOR_GRAY2RGB) + return remove_pad(canny_image) + + +add_supported_preprocessor(PreprocessorNone()) +add_supported_preprocessor(PreprocessorCanny()) \ No newline at end of file