mirror of https://github.com/lllyasviel/huggingface_guess.git
synced 2026-04-30 04:11:16 +00:00
164  .gitignore  vendored  Normal file
@@ -0,0 +1,164 @@
demo.py

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
674  LICENSE  Normal file
@@ -0,0 +1,674 @@
GNU GENERAL PUBLIC LICENSE
Version 3, 29 June 2007

Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.

[Remainder of the file: the unmodified text of the GNU General Public License, version 3 (preamble, sections 0 through 17, and the "How to Apply These Terms to Your New Programs" appendix), as published at <https://www.gnu.org/licenses/gpl-3.0.txt>.]
21  README.md  Normal file
@@ -0,0 +1,21 @@
# HuggingFace Guess

A simple tool to guess a HuggingFace repo URL from a state dict.

This repo does almost the same thing as `from diffusers.loaders.single_file_utils import fetch_diffusers_config`, but it is a bit stronger and more robust.

The main model detection logic is extracted from Diffusers and stolen from ComfyUI.

```python
import safetensors.torch as sf
import huggingface_guess


state_dict = sf.load_file('./realisticVisionV51_v51VAE.safetensors')
repo_name = huggingface_guess.guess_repo_name(state_dict)
print(repo_name)
```

The above code will print `runwayml/stable-diffusion-v1-5`.

Then you can download the configs (or prefetch them) from HuggingFace to instantiate the models and load the weights.
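As a minimal sketch of that last step (assuming `diffusers` is installed, which is not part of this repo), the guessed repo id can be passed straight to `DiffusionPipeline.from_pretrained` to fetch the matching configs and weights:

```python
import safetensors.torch as sf
from diffusers import DiffusionPipeline  # external dependency, assumed installed

import huggingface_guess

state_dict = sf.load_file('./realisticVisionV51_v51VAE.safetensors')
repo_name = huggingface_guess.guess_repo_name(state_dict)

# Downloads the configs and weights of the guessed repo from the HuggingFace Hub
# and builds the matching pipeline.
pipe = DiffusionPipeline.from_pretrained(repo_name)
```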
8  huggingface_guess/__init__.py  Normal file
@@ -0,0 +1,8 @@
from huggingface_guess.detection import model_config_from_unet, unet_prefix_from_state_dict, model_config_from_diffusers_unet


def guess_repo_name(state_dict):
    # Find the UNet key prefix in the checkpoint, match it against the known
    # model configs, and return the corresponding HuggingFace repo id.
    unet_key_prefix = unet_prefix_from_state_dict(state_dict)
    model_config = model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=True)
    repo_id = model_config.huggingface_repo
    return repo_id
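For callers that want more than the repo id, the two helpers imported above can also be used directly; a minimal sketch (function and attribute names are taken from the code in this commit), for example to inspect the detected config before deciding what to download:

```python
import safetensors.torch as sf

from huggingface_guess.detection import model_config_from_unet, unet_prefix_from_state_dict

state_dict = sf.load_file('./realisticVisionV51_v51VAE.safetensors')

prefix = unet_prefix_from_state_dict(state_dict)  # e.g. 'model.diffusion_model.'
config = model_config_from_unet(state_dict, prefix, use_base_if_no_match=True)
print(config.huggingface_repo)  # the guessed repo id, e.g. 'runwayml/stable-diffusion-v1-5'
```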
543  huggingface_guess/detection.py  Normal file
@@ -0,0 +1,543 @@
|
||||
import math
|
||||
import logging
|
||||
import torch
|
||||
|
||||
from huggingface_guess import utils, model_list
|
||||
|
||||
|
||||
def count_blocks(state_dict_keys, prefix_string):
|
||||
count = 0
|
||||
while True:
|
||||
c = False
|
||||
for k in state_dict_keys:
|
||||
if k.startswith(prefix_string.format(count)):
|
||||
c = True
|
||||
break
|
||||
if c == False:
|
||||
break
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
|
||||
context_dim = None
|
||||
use_linear_in_transformer = False
|
||||
|
||||
transformer_prefix = prefix + "1.transformer_blocks."
|
||||
transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys)))
|
||||
if len(transformer_keys) > 0:
|
||||
last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
|
||||
context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
|
||||
use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
|
||||
time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict
|
||||
time_stack_cross = '{}1.time_stack.0.attn2.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn2.to_q.weight'.format(prefix) in state_dict
|
||||
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack, time_stack_cross
|
||||
return None
|
||||
|
||||
|
||||
def detect_unet_config(state_dict, key_prefix):
|
||||
state_dict_keys = list(state_dict.keys())
|
||||
|
||||
if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: # mmdit model
|
||||
unet_config = {}
|
||||
unet_config["in_channels"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[1]
|
||||
patch_size = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[2]
|
||||
unet_config["patch_size"] = patch_size
|
||||
final_layer = '{}final_layer.linear.weight'.format(key_prefix)
|
||||
if final_layer in state_dict:
|
||||
unet_config["out_channels"] = state_dict[final_layer].shape[0] // (patch_size * patch_size)
|
||||
|
||||
unet_config["depth"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0] // 64
|
||||
unet_config["input_size"] = None
|
||||
y_key = '{}y_embedder.mlp.0.weight'.format(key_prefix)
|
||||
if y_key in state_dict_keys:
|
||||
unet_config["adm_in_channels"] = state_dict[y_key].shape[1]
|
||||
|
||||
context_key = '{}context_embedder.weight'.format(key_prefix)
|
||||
if context_key in state_dict_keys:
|
||||
in_features = state_dict[context_key].shape[1]
|
||||
out_features = state_dict[context_key].shape[0]
|
||||
unet_config["context_embedder_config"] = {"target": "torch.nn.Linear", "params": {"in_features": in_features, "out_features": out_features}}
|
||||
num_patches_key = '{}pos_embed'.format(key_prefix)
|
||||
if num_patches_key in state_dict_keys:
|
||||
num_patches = state_dict[num_patches_key].shape[1]
|
||||
unet_config["num_patches"] = num_patches
|
||||
unet_config["pos_embed_max_size"] = round(math.sqrt(num_patches))
|
||||
|
||||
rms_qk = '{}joint_blocks.0.context_block.attn.ln_q.weight'.format(key_prefix)
|
||||
if rms_qk in state_dict_keys:
|
||||
unet_config["qk_norm"] = "rms"
|
||||
|
||||
unet_config["pos_embed_scaling_factor"] = None # unused for inference
|
||||
context_processor = '{}context_processor.layers.0.attn.qkv.weight'.format(key_prefix)
|
||||
if context_processor in state_dict_keys:
|
||||
unet_config["context_processor_layers"] = count_blocks(state_dict_keys, '{}context_processor.layers.'.format(key_prefix) + '{}.')
|
||||
return unet_config
|
||||
|
||||
if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: # stable cascade
|
||||
unet_config = {}
|
||||
text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix)
|
||||
if text_mapper_name in state_dict_keys:
|
||||
unet_config['stable_cascade_stage'] = 'c'
|
||||
w = state_dict[text_mapper_name]
|
||||
if w.shape[0] == 1536: # stage c lite
|
||||
unet_config['c_cond'] = 1536
|
||||
unet_config['c_hidden'] = [1536, 1536]
|
||||
unet_config['nhead'] = [24, 24]
|
||||
unet_config['blocks'] = [[4, 12], [12, 4]]
|
||||
elif w.shape[0] == 2048: # stage c full
|
||||
unet_config['c_cond'] = 2048
|
||||
elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys:
|
||||
unet_config['stable_cascade_stage'] = 'b'
|
||||
w = state_dict['{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix)]
|
||||
if w.shape[-1] == 640:
|
||||
unet_config['c_hidden'] = [320, 640, 1280, 1280]
|
||||
unet_config['nhead'] = [-1, -1, 20, 20]
|
||||
unet_config['blocks'] = [[2, 6, 28, 6], [6, 28, 6, 2]]
|
||||
unet_config['block_repeat'] = [[1, 1, 1, 1], [3, 3, 2, 2]]
|
||||
elif w.shape[-1] == 576: # stage b lite
|
||||
unet_config['c_hidden'] = [320, 576, 1152, 1152]
|
||||
unet_config['nhead'] = [-1, 9, 18, 18]
|
||||
unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]]
|
||||
unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]]
|
||||
return unet_config
|
||||
|
||||
if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: # stable audio dit
|
||||
unet_config = {}
|
||||
unet_config["audio_model"] = "dit1.0"
|
||||
return unet_config
|
||||
|
||||
if '{}double_layers.0.attn.w1q.weight'.format(key_prefix) in state_dict_keys: # aura flow dit
|
||||
unet_config = {}
|
||||
unet_config["max_seq"] = state_dict['{}positional_encoding'.format(key_prefix)].shape[1]
|
||||
unet_config["cond_seq_dim"] = state_dict['{}cond_seq_linear.weight'.format(key_prefix)].shape[1]
|
||||
double_layers = count_blocks(state_dict_keys, '{}double_layers.'.format(key_prefix) + '{}.')
|
||||
single_layers = count_blocks(state_dict_keys, '{}single_layers.'.format(key_prefix) + '{}.')
|
||||
unet_config["n_double_layers"] = double_layers
|
||||
unet_config["n_layers"] = double_layers + single_layers
|
||||
return unet_config
|
||||
|
||||
if '{}mlp_t5.0.weight'.format(key_prefix) in state_dict_keys: # Hunyuan DiT
|
||||
unet_config = {}
|
||||
unet_config["image_model"] = "hydit"
|
||||
unet_config["depth"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
|
||||
unet_config["hidden_size"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0]
|
||||
if unet_config["hidden_size"] == 1408 and unet_config["depth"] == 40: # DiT-g/2
|
||||
unet_config["mlp_ratio"] = 4.3637
|
||||
if state_dict['{}extra_embedder.0.weight'.format(key_prefix)].shape[1] == 3968:
|
||||
unet_config["size_cond"] = True
|
||||
unet_config["use_style_cond"] = True
|
||||
unet_config["image_model"] = "hydit1"
|
||||
return unet_config
|
||||
|
||||
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: # Flux
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "flux"
|
||||
dit_config["in_channels"] = 64
|
||||
dit_config["vec_in_dim"] = 768
|
||||
dit_config["context_in_dim"] = 4096
|
||||
dit_config["hidden_size"] = 3072
|
||||
dit_config["mlp_ratio"] = 4.0
|
||||
dit_config["num_heads"] = 24
|
||||
dit_config["depth"] = 19
|
||||
dit_config["depth_single_blocks"] = 38
|
||||
dit_config["axes_dim"] = [16, 56, 56]
|
||||
dit_config["theta"] = 10000
|
||||
dit_config["qkv_bias"] = True
|
||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||
return dit_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
return None
|
||||
|
||||
unet_config = {
|
||||
"use_checkpoint": False,
|
||||
"image_size": 32,
|
||||
"use_spatial_transformer": True,
|
||||
"legacy": False
|
||||
}
|
||||
|
||||
y_input = '{}label_emb.0.0.weight'.format(key_prefix)
|
||||
if y_input in state_dict_keys:
|
||||
unet_config["num_classes"] = "sequential"
|
||||
unet_config["adm_in_channels"] = state_dict[y_input].shape[1]
|
||||
else:
|
||||
unet_config["adm_in_channels"] = None
|
||||
|
||||
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
|
||||
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
|
||||
|
||||
out_key = '{}out.2.weight'.format(key_prefix)
|
||||
if out_key in state_dict:
|
||||
out_channels = state_dict[out_key].shape[0]
|
||||
else:
|
||||
out_channels = 4
|
||||
|
||||
num_res_blocks = []
|
||||
channel_mult = []
|
||||
attention_resolutions = []
|
||||
transformer_depth = []
|
||||
transformer_depth_output = []
|
||||
context_dim = None
|
||||
use_linear_in_transformer = False
|
||||
|
||||
video_model = False
|
||||
video_model_cross = False
|
||||
|
||||
current_res = 1
|
||||
count = 0
|
||||
|
||||
last_res_blocks = 0
|
||||
last_channel_mult = 0
|
||||
|
||||
input_block_count = count_blocks(state_dict_keys, '{}input_blocks'.format(key_prefix) + '.{}.')
|
||||
for count in range(input_block_count):
|
||||
prefix = '{}input_blocks.{}.'.format(key_prefix, count)
|
||||
prefix_output = '{}output_blocks.{}.'.format(key_prefix, input_block_count - count - 1)
|
||||
|
||||
block_keys = sorted(list(filter(lambda a: a.startswith(prefix), state_dict_keys)))
|
||||
if len(block_keys) == 0:
|
||||
break
|
||||
|
||||
block_keys_output = sorted(list(filter(lambda a: a.startswith(prefix_output), state_dict_keys)))
|
||||
|
||||
if "{}0.op.weight".format(prefix) in block_keys: # new layer
|
||||
num_res_blocks.append(last_res_blocks)
|
||||
channel_mult.append(last_channel_mult)
|
||||
|
||||
current_res *= 2
|
||||
last_res_blocks = 0
|
||||
last_channel_mult = 0
|
||||
out = calculate_transformer_depth(prefix_output, state_dict_keys, state_dict)
|
||||
if out is not None:
|
||||
transformer_depth_output.append(out[0])
|
||||
else:
|
||||
transformer_depth_output.append(0)
|
||||
else:
|
||||
res_block_prefix = "{}0.in_layers.0.weight".format(prefix)
|
||||
if res_block_prefix in block_keys:
|
||||
last_res_blocks += 1
|
||||
last_channel_mult = state_dict["{}0.out_layers.3.weight".format(prefix)].shape[0] // model_channels
|
||||
|
||||
out = calculate_transformer_depth(prefix, state_dict_keys, state_dict)
|
||||
if out is not None:
|
||||
transformer_depth.append(out[0])
|
||||
if context_dim is None:
|
||||
context_dim = out[1]
|
||||
use_linear_in_transformer = out[2]
|
||||
video_model = out[3]
|
||||
video_model_cross = out[4]
|
||||
else:
|
||||
transformer_depth.append(0)
|
||||
|
||||
res_block_prefix = "{}0.in_layers.0.weight".format(prefix_output)
|
||||
if res_block_prefix in block_keys_output:
|
||||
out = calculate_transformer_depth(prefix_output, state_dict_keys, state_dict)
|
||||
if out is not None:
|
||||
transformer_depth_output.append(out[0])
|
||||
else:
|
||||
transformer_depth_output.append(0)
|
||||
|
||||
num_res_blocks.append(last_res_blocks)
|
||||
channel_mult.append(last_channel_mult)
|
||||
if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys:
|
||||
transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}')
|
||||
elif "{}middle_block.0.in_layers.0.weight".format(key_prefix) in state_dict_keys:
|
||||
transformer_depth_middle = -1
|
||||
else:
|
||||
transformer_depth_middle = -2
|
||||
|
||||
unet_config["in_channels"] = in_channels
|
||||
unet_config["out_channels"] = out_channels
|
||||
unet_config["model_channels"] = model_channels
|
||||
unet_config["num_res_blocks"] = num_res_blocks
|
||||
unet_config["transformer_depth"] = transformer_depth
|
||||
unet_config["transformer_depth_output"] = transformer_depth_output
|
||||
unet_config["channel_mult"] = channel_mult
|
||||
unet_config["transformer_depth_middle"] = transformer_depth_middle
|
||||
unet_config['use_linear_in_transformer'] = use_linear_in_transformer
|
||||
unet_config["context_dim"] = context_dim
|
||||
|
||||
if video_model:
|
||||
unet_config["extra_ff_mix_layer"] = True
|
||||
unet_config["use_spatial_context"] = True
|
||||
unet_config["merge_strategy"] = "learned_with_images"
|
||||
unet_config["merge_factor"] = 0.0
|
||||
unet_config["video_kernel_size"] = [3, 1, 1]
|
||||
unet_config["use_temporal_resblock"] = True
|
||||
unet_config["use_temporal_attention"] = True
|
||||
unet_config["disable_temporal_crossattention"] = not video_model_cross
|
||||
else:
|
||||
unet_config["use_temporal_resblock"] = False
|
||||
unet_config["use_temporal_attention"] = False
|
||||
|
||||
return unet_config
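
# Illustrative result of the detection above (hedged; exact values depend on the
# checkpoint): for an SDXL-base style UNet the returned config resembles the
# hard-coded SDXL entry further below, e.g.
#   {"in_channels": 4, "out_channels": 4, "model_channels": 320,
#    "num_res_blocks": [2, 2, 2], "channel_mult": [1, 2, 4],
#    "transformer_depth": [0, 0, 2, 2, 10, 10], "transformer_depth_middle": 10,
#    "context_dim": 2048, "use_linear_in_transformer": True, ...}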


def model_config_from_unet_config(unet_config, state_dict=None):
|
||||
for model_config in model_list.models:
|
||||
if model_config.matches(unet_config, state_dict):
|
||||
return model_config(unet_config)
|
||||
|
||||
logging.error("no match {}".format(unet_config))
|
||||
return None
|
||||
|
||||
|
||||
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
|
||||
unet_config = detect_unet_config(state_dict, unet_key_prefix)
|
||||
if unet_config is None:
|
||||
return None
|
||||
model_config = model_config_from_unet_config(unet_config, state_dict)
|
||||
if model_config is None and use_base_if_no_match:
|
||||
return model_list.BASE(unet_config)
|
||||
else:
|
||||
return model_config
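
# Minimal usage sketch (illustrative; the checkpoint filename is hypothetical and
# `safetensors.torch.load_file` is just one way a caller might obtain the state dict):
#
#     import safetensors.torch
#     sd = safetensors.torch.load_file("some_checkpoint.safetensors")
#     prefix = unet_prefix_from_state_dict(sd)      # e.g. "model.diffusion_model."
#     config = model_config_from_unet(sd, prefix, use_base_if_no_match=True)
#     if config is not None:
#         print(type(config).__name__, config.clip_target(sd))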


def unet_prefix_from_state_dict(state_dict):
|
||||
candidates = [
|
||||
"model.diffusion_model.", # ldm/sgm models
|
||||
"model.model.", # audio models
|
||||
]
|
||||
counts = {k: 0 for k in candidates}
|
||||
for k in state_dict:
|
||||
for c in candidates:
|
||||
if k.startswith(c):
|
||||
counts[c] += 1
|
||||
break
|
||||
|
||||
top = max(counts, key=counts.get)
|
||||
if counts[top] > 5:
|
||||
return top
|
||||
else:
|
||||
return "model." # aura flow and others
|
||||
|
||||
|
||||
def convert_config(unet_config):
|
||||
new_config = unet_config.copy()
|
||||
num_res_blocks = new_config.get("num_res_blocks", None)
|
||||
channel_mult = new_config.get("channel_mult", None)
|
||||
|
||||
if isinstance(num_res_blocks, int):
|
||||
num_res_blocks = len(channel_mult) * [num_res_blocks]
|
||||
|
||||
if "attention_resolutions" in new_config:
|
||||
attention_resolutions = new_config.pop("attention_resolutions")
|
||||
transformer_depth = new_config.get("transformer_depth", None)
|
||||
transformer_depth_middle = new_config.get("transformer_depth_middle", None)
|
||||
|
||||
if isinstance(transformer_depth, int):
|
||||
transformer_depth = len(channel_mult) * [transformer_depth]
|
||||
if transformer_depth_middle is None:
|
||||
transformer_depth_middle = transformer_depth[-1]
|
||||
t_in = []
|
||||
t_out = []
|
||||
s = 1
|
||||
for i in range(len(num_res_blocks)):
|
||||
res = num_res_blocks[i]
|
||||
d = 0
|
||||
if s in attention_resolutions:
|
||||
d = transformer_depth[i]
|
||||
|
||||
t_in += [d] * res
|
||||
t_out += [d] * (res + 1)
|
||||
s *= 2
|
||||
transformer_depth = t_in
|
||||
transformer_depth_output = t_out
|
||||
new_config["transformer_depth"] = t_in
|
||||
new_config["transformer_depth_output"] = t_out
|
||||
new_config["transformer_depth_middle"] = transformer_depth_middle
|
||||
|
||||
new_config["num_res_blocks"] = num_res_blocks
|
||||
return new_config
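
# Worked example (a legacy SD1.x style config, as found in the original LDM yaml):
#   convert_config({"num_res_blocks": 2, "channel_mult": [1, 2, 4, 4],
#                   "transformer_depth": 1, "attention_resolutions": [4, 2, 1]})
# expands the scalar/per-level values into per-block lists:
#   num_res_blocks           -> [2, 2, 2, 2]
#   transformer_depth        -> [1, 1, 1, 1, 1, 1, 0, 0]
#   transformer_depth_output -> [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]
#   transformer_depth_middle -> 1
# which matches the SD15 entry used by unet_config_from_diffusers_unet below.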


def unet_config_from_diffusers_unet(state_dict, dtype=None):
|
||||
match = {}
|
||||
transformer_depth = []
|
||||
|
||||
attn_res = 1
|
||||
down_blocks = count_blocks(state_dict, "down_blocks.{}")
|
||||
for i in range(down_blocks):
|
||||
attn_blocks = count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}')
|
||||
res_blocks = count_blocks(state_dict, "down_blocks.{}.resnets.".format(i) + '{}')
|
||||
for ab in range(attn_blocks):
|
||||
transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
|
||||
transformer_depth.append(transformer_count)
|
||||
if transformer_count > 0:
|
||||
match["context_dim"] = state_dict["down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab)].shape[1]
|
||||
|
||||
attn_res *= 2
|
||||
if attn_blocks == 0:
|
||||
for i in range(res_blocks):
|
||||
transformer_depth.append(0)
|
||||
|
||||
match["transformer_depth"] = transformer_depth
|
||||
|
||||
match["model_channels"] = state_dict["conv_in.weight"].shape[0]
|
||||
match["in_channels"] = state_dict["conv_in.weight"].shape[1]
|
||||
match["adm_in_channels"] = None
|
||||
if "class_embedding.linear_1.weight" in state_dict:
|
||||
match["adm_in_channels"] = state_dict["class_embedding.linear_1.weight"].shape[1]
|
||||
elif "add_embedding.linear_1.weight" in state_dict:
|
||||
match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
|
||||
|
||||
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
|
||||
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2560, 'dtype': dtype, 'in_channels': 4, 'model_channels': 384,
|
||||
'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [0, 0, 4, 4, 4, 4, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 4,
|
||||
'use_linear_in_transformer': True, 'context_dim': 1280, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 4, 4, 4, 4, 4, 4, 0, 0, 0],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2],
|
||||
'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True,
|
||||
'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2048, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1,
|
||||
'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 1536, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1,
|
||||
'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
|
||||
'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
|
||||
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
|
||||
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 1,
|
||||
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 1, 1, 1],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 0, 0], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 0,
|
||||
'use_linear_in_transformer': True, 'num_head_channels': 64, 'context_dim': 1, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320,
|
||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
|
||||
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SDXL_diffusers_ip2p = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 8, 'model_channels': 320,
|
||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
|
||||
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4],
|
||||
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
Segmind_Vega = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 1, 1, 2, 2], 'transformer_depth_output': [0, 0, 0, 1, 1, 1, 2, 2, 2],
|
||||
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
KOALA_700M = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 5], 'transformer_depth_output': [0, 0, 2, 2, 5, 5],
|
||||
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
KOALA_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 6], 'transformer_depth_output': [0, 0, 2, 2, 6, 6],
|
||||
'channel_mult': [1, 2, 4], 'transformer_depth_middle': 6, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SD09_XS = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [1, 1, 1],
|
||||
'transformer_depth': [1, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': True,
|
||||
'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False, 'disable_self_attentions': [True, False, False]}
|
||||
|
||||
SD_XS = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [1, 1, 1],
|
||||
'transformer_depth': [0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': False,
|
||||
'context_dim': 768, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 1, 1, 1, 1],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p]
|
||||
|
||||
for unet_config in supported_models:
|
||||
matches = True
|
||||
for k in match:
|
||||
if match[k] != unet_config[k]:
|
||||
matches = False
|
||||
break
|
||||
if matches:
|
||||
return convert_config(unet_config)
|
||||
return None
|
||||
|
||||
|
||||
def model_config_from_diffusers_unet(state_dict):
|
||||
unet_config = unet_config_from_diffusers_unet(state_dict)
|
||||
if unet_config is not None:
|
||||
return model_config_from_unet_config(unet_config)
|
||||
return None
|
||||
|
||||
|
||||
def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
||||
out_sd = {}
|
||||
|
||||
if 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: # SD3
|
||||
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
|
||||
sd_map = utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
|
||||
elif 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict: # AuraFlow
|
||||
num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
|
||||
num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
|
||||
sd_map = utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
|
||||
else:
|
||||
return None
|
||||
|
||||
for k in sd_map:
|
||||
weight = state_dict.get(k, None)
|
||||
if weight is not None:
|
||||
t = sd_map[k]
|
||||
|
||||
if not isinstance(t, str):
|
||||
if len(t) > 2:
|
||||
fun = t[2]
|
||||
else:
|
||||
fun = lambda a: a
|
||||
offset = t[1]
|
||||
if offset is not None:
|
||||
old_weight = out_sd.get(t[0], None)
|
||||
if old_weight is None:
|
||||
old_weight = torch.empty_like(weight)
|
||||
old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1))
|
||||
|
||||
w = old_weight.narrow(offset[0], offset[1], offset[2])
|
||||
else:
|
||||
old_weight = weight
|
||||
w = weight
|
||||
w[:] = fun(weight)
|
||||
t = t[0]
|
||||
out_sd[t] = old_weight
|
||||
else:
|
||||
out_sd[t] = weight
|
||||
state_dict.pop(k)
|
||||
|
||||
return out_sd
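
# Usage sketch (hedged): given a diffusers-format SD3 or AuraFlow transformer state dict,
#   out_sd = convert_diffusers_mmdit(sd, output_prefix="model.diffusion_model.")
# returns the remapped weights, or None when neither key layout is recognized; entries
# with an `offset` in the map are written into slices of a fused qkv-style tensor.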
281
huggingface_guess/diffusers_convert.py
Normal file
@@ -0,0 +1,281 @@
|
||||
import re
|
||||
import torch
|
||||
import logging
|
||||
|
||||
# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
|
||||
|
||||
# =================#
|
||||
# UNet Conversion #
|
||||
# =================#
|
||||
|
||||
unet_conversion_map = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
||||
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
||||
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
||||
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
||||
("input_blocks.0.0.weight", "conv_in.weight"),
|
||||
("input_blocks.0.0.bias", "conv_in.bias"),
|
||||
("out.0.weight", "conv_norm_out.weight"),
|
||||
("out.0.bias", "conv_norm_out.bias"),
|
||||
("out.2.weight", "conv_out.weight"),
|
||||
("out.2.bias", "conv_out.bias"),
|
||||
]
|
||||
|
||||
unet_conversion_map_resnet = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("in_layers.0", "norm1"),
|
||||
("in_layers.2", "conv1"),
|
||||
("out_layers.0", "norm2"),
|
||||
("out_layers.3", "conv2"),
|
||||
("emb_layers.1", "time_emb_proj"),
|
||||
("skip_connection", "conv_shortcut"),
|
||||
]
|
||||
|
||||
unet_conversion_map_layer = []
|
||||
# hardcoded number of downblocks and resnets/attentions...
|
||||
# would need smarter logic for other networks.
|
||||
for i in range(4):
|
||||
# loop over downblocks/upblocks
|
||||
|
||||
for j in range(2):
|
||||
# loop over resnets/attentions for downblocks
|
||||
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
||||
sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
|
||||
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no attention layers in down_blocks.3
|
||||
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
||||
sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
|
||||
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
||||
|
||||
for j in range(3):
|
||||
# loop over resnets/attentions for upblocks
|
||||
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
||||
sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
|
||||
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
||||
|
||||
if i > 0:
|
||||
# no attention layers in up_blocks.0
|
||||
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
||||
sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
|
||||
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no downsample in down_blocks.3
|
||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
||||
sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
|
||||
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
# no upsample in up_blocks.3
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
|
||||
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
|
||||
hf_mid_atn_prefix = "mid_block.attentions.0."
|
||||
sd_mid_atn_prefix = "middle_block.1."
|
||||
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2 * j}."
|
||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
|
||||
def convert_unet_state_dict(unet_state_dict):
|
||||
# buyer beware: this is a *brittle* function,
|
||||
# and correct output requires that all of these pieces interact in
|
||||
# the exact order in which I have arranged them.
|
||||
mapping = {k: k for k in unet_state_dict.keys()}
|
||||
for sd_name, hf_name in unet_conversion_map:
|
||||
mapping[hf_name] = sd_name
|
||||
for k, v in mapping.items():
|
||||
if "resnets" in k:
|
||||
for sd_part, hf_part in unet_conversion_map_resnet:
|
||||
v = v.replace(hf_part, sd_part)
|
||||
mapping[k] = v
|
||||
for k, v in mapping.items():
|
||||
for sd_part, hf_part in unet_conversion_map_layer:
|
||||
v = v.replace(hf_part, sd_part)
|
||||
mapping[k] = v
|
||||
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
|
||||
return new_state_dict
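
# Worked examples of the resulting key mapping (HF Diffusers -> original SD layout):
#   "down_blocks.0.resnets.0.norm1.weight"      -> "input_blocks.1.0.in_layers.0.weight"
#   "down_blocks.0.attentions.0.proj_in.weight" -> "input_blocks.1.1.proj_in.weight"
#   "mid_block.resnets.1.conv1.weight"          -> "middle_block.2.in_layers.2.weight"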


# ================#
|
||||
# VAE Conversion #
|
||||
# ================#
|
||||
|
||||
vae_conversion_map = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("nin_shortcut", "conv_shortcut"),
|
||||
("norm_out", "conv_norm_out"),
|
||||
("mid.attn_1.", "mid_block.attentions.0."),
|
||||
]
|
||||
|
||||
for i in range(4):
|
||||
# down_blocks have two resnets
|
||||
for j in range(2):
|
||||
hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
|
||||
sd_down_prefix = f"encoder.down.{i}.block.{j}."
|
||||
vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
|
||||
|
||||
if i < 3:
|
||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
|
||||
sd_downsample_prefix = f"down.{i}.downsample."
|
||||
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"up.{3 - i}.upsample."
|
||||
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
|
||||
# up_blocks have three resnets
|
||||
# also, up blocks in hf are numbered in reverse from sd
|
||||
for j in range(3):
|
||||
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
|
||||
sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
|
||||
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
||||
|
||||
# this part accounts for mid blocks in both the encoder and the decoder
|
||||
for i in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{i}."
|
||||
sd_mid_res_prefix = f"mid.block_{i + 1}."
|
||||
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
vae_conversion_map_attn = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("norm.", "group_norm."),
|
||||
("q.", "query."),
|
||||
("k.", "key."),
|
||||
("v.", "value."),
|
||||
("q.", "to_q."),
|
||||
("k.", "to_k."),
|
||||
("v.", "to_v."),
|
||||
("proj_out.", "to_out.0."),
|
||||
("proj_out.", "proj_attn."),
|
||||
]
|
||||
|
||||
|
||||
def reshape_weight_for_sd(w):
|
||||
# convert HF linear weights to SD conv2d weights
|
||||
return w.reshape(*w.shape, 1, 1)
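
# e.g. reshape_weight_for_sd on a [512, 512] attention projection weight from the
# diffusers VAE yields a 1x1 conv weight of shape [512, 512, 1, 1] in the SD layout.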


def convert_vae_state_dict(vae_state_dict):
|
||||
mapping = {k: k for k in vae_state_dict.keys()}
|
||||
for k, v in mapping.items():
|
||||
for sd_part, hf_part in vae_conversion_map:
|
||||
v = v.replace(hf_part, sd_part)
|
||||
mapping[k] = v
|
||||
for k, v in mapping.items():
|
||||
if "attentions" in k:
|
||||
for sd_part, hf_part in vae_conversion_map_attn:
|
||||
v = v.replace(hf_part, sd_part)
|
||||
mapping[k] = v
|
||||
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
|
||||
weights_to_convert = ["q", "k", "v", "proj_out"]
|
||||
for k, v in new_state_dict.items():
|
||||
for weight_name in weights_to_convert:
|
||||
if f"mid.attn_1.{weight_name}.weight" in k:
|
||||
logging.debug(f"Reshaping {k} for SD format")
|
||||
new_state_dict[k] = reshape_weight_for_sd(v)
|
||||
return new_state_dict
|
||||
|
||||
|
||||
# =========================#
|
||||
# Text Encoder Conversion #
|
||||
# =========================#
|
||||
|
||||
|
||||
textenc_conversion_lst = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("resblocks.", "text_model.encoder.layers."),
|
||||
("ln_1", "layer_norm1"),
|
||||
("ln_2", "layer_norm2"),
|
||||
(".c_fc.", ".fc1."),
|
||||
(".c_proj.", ".fc2."),
|
||||
(".attn", ".self_attn"),
|
||||
("ln_final.", "transformer.text_model.final_layer_norm."),
|
||||
("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
|
||||
("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
|
||||
]
|
||||
protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
|
||||
textenc_pattern = re.compile("|".join(protected.keys()))
|
||||
|
||||
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
|
||||
code2idx = {"q": 0, "k": 1, "v": 2}
|
||||
|
||||
# This function exists because at the time of writing torch.cat can't do fp8 with cuda
|
||||
def cat_tensors(tensors):
|
||||
x = 0
|
||||
for t in tensors:
|
||||
x += t.shape[0]
|
||||
|
||||
shape = [x] + list(tensors[0].shape)[1:]
|
||||
out = torch.empty(shape, device=tensors[0].device, dtype=tensors[0].dtype)
|
||||
|
||||
x = 0
|
||||
for t in tensors:
|
||||
out[x:x + t.shape[0]] = t
|
||||
x += t.shape[0]
|
||||
|
||||
return out
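
# Equivalent, for dtypes that torch.cat does support (cat_tensors only exists to
# sidestep the fp8-on-cuda limitation noted above):
#   cat_tensors([q, k, v]) == torch.cat([q, k, v], dim=0)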
def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
|
||||
new_state_dict = {}
|
||||
capture_qkv_weight = {}
|
||||
capture_qkv_bias = {}
|
||||
for k, v in text_enc_dict.items():
|
||||
if not k.startswith(prefix):
|
||||
continue
|
||||
if (
|
||||
k.endswith(".self_attn.q_proj.weight")
|
||||
or k.endswith(".self_attn.k_proj.weight")
|
||||
or k.endswith(".self_attn.v_proj.weight")
|
||||
):
|
||||
k_pre = k[: -len(".q_proj.weight")]
|
||||
k_code = k[-len("q_proj.weight")]
|
||||
if k_pre not in capture_qkv_weight:
|
||||
capture_qkv_weight[k_pre] = [None, None, None]
|
||||
capture_qkv_weight[k_pre][code2idx[k_code]] = v
|
||||
continue
|
||||
|
||||
if (
|
||||
k.endswith(".self_attn.q_proj.bias")
|
||||
or k.endswith(".self_attn.k_proj.bias")
|
||||
or k.endswith(".self_attn.v_proj.bias")
|
||||
):
|
||||
k_pre = k[: -len(".q_proj.bias")]
|
||||
k_code = k[-len("q_proj.bias")]
|
||||
if k_pre not in capture_qkv_bias:
|
||||
capture_qkv_bias[k_pre] = [None, None, None]
|
||||
capture_qkv_bias[k_pre][code2idx[k_code]] = v
|
||||
continue
|
||||
|
||||
text_proj = "transformer.text_projection.weight"
|
||||
if k.endswith(text_proj):
|
||||
new_state_dict[k.replace(text_proj, "text_projection")] = v.transpose(0, 1).contiguous()
|
||||
else:
|
||||
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
|
||||
new_state_dict[relabelled_key] = v
|
||||
|
||||
for k_pre, tensors in capture_qkv_weight.items():
|
||||
if None in tensors:
|
||||
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
|
||||
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
|
||||
new_state_dict[relabelled_key + ".in_proj_weight"] = cat_tensors(tensors)
|
||||
|
||||
for k_pre, tensors in capture_qkv_bias.items():
|
||||
if None in tensors:
|
||||
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
|
||||
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
|
||||
new_state_dict[relabelled_key + ".in_proj_bias"] = cat_tensors(tensors)
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def convert_text_enc_state_dict(text_enc_dict):
|
||||
return text_enc_dict
|
||||
|
||||
|
||||
181
huggingface_guess/latent.py
Normal file
@@ -0,0 +1,181 @@
|
||||
import torch
|
||||
|
||||
|
||||
class LatentFormat:
|
||||
scale_factor = 1.0
|
||||
latent_channels = 4
|
||||
latent_rgb_factors = None
|
||||
taesd_decoder_name = None
|
||||
|
||||
def process_in(self, latent):
|
||||
return latent * self.scale_factor
|
||||
|
||||
def process_out(self, latent):
|
||||
return latent / self.scale_factor
|
||||
|
||||
|
||||
class SD15(LatentFormat):
|
||||
def __init__(self, scale_factor=0.18215):
|
||||
self.scale_factor = scale_factor
|
||||
self.latent_rgb_factors = [
|
||||
# R G B
|
||||
[0.3512, 0.2297, 0.3227],
|
||||
[0.3250, 0.4974, 0.2350],
|
||||
[-0.2829, 0.1762, 0.2721],
|
||||
[-0.2120, -0.2616, -0.7177]
|
||||
]
|
||||
self.taesd_decoder_name = "taesd_decoder"
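
# A hedged sketch of what latent_rgb_factors is typically used for: a cheap RGB preview
# of a latent without running the full VAE decoder. The helper below is illustrative
# only and not part of this file's API:
#
#     def latent_preview(latent, fmt):  # latent: [B, 4, H, W]
#         factors = torch.tensor(fmt.latent_rgb_factors,
#                                dtype=latent.dtype, device=latent.device)  # [4, 3]
#         return torch.einsum("bchw,cr->brhw", latent, factors)             # [B, 3, H, W]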


class SDXL(LatentFormat):
|
||||
scale_factor = 0.13025
|
||||
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = [
|
||||
# R G B
|
||||
[0.3920, 0.4054, 0.4549],
|
||||
[-0.2634, -0.0196, 0.0653],
|
||||
[0.0568, 0.1687, -0.0755],
|
||||
[-0.3112, -0.2359, -0.2076]
|
||||
]
|
||||
self.taesd_decoder_name = "taesdxl_decoder"
|
||||
|
||||
|
||||
class SDXL_Playground_2_5(LatentFormat):
|
||||
def __init__(self):
|
||||
self.scale_factor = 0.5
|
||||
self.latents_mean = torch.tensor([-1.6574, 1.886, -1.383, 2.5155]).view(1, 4, 1, 1)
|
||||
self.latents_std = torch.tensor([8.4927, 5.9022, 6.5498, 5.2299]).view(1, 4, 1, 1)
|
||||
|
||||
self.latent_rgb_factors = [
|
||||
# R G B
|
||||
[0.3920, 0.4054, 0.4549],
|
||||
[-0.2634, -0.0196, 0.0653],
|
||||
[0.0568, 0.1687, -0.0755],
|
||||
[-0.3112, -0.2359, -0.2076]
|
||||
]
|
||||
self.taesd_decoder_name = "taesdxl_decoder"
|
||||
|
||||
def process_in(self, latent):
|
||||
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
|
||||
latents_std = self.latents_std.to(latent.device, latent.dtype)
|
||||
return (latent - latents_mean) * self.scale_factor / latents_std
|
||||
|
||||
def process_out(self, latent):
|
||||
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
|
||||
latents_std = self.latents_std.to(latent.device, latent.dtype)
|
||||
return latent * latents_std / self.scale_factor + latents_mean
|
||||
|
||||
|
||||
class SD_X4(LatentFormat):
|
||||
def __init__(self):
|
||||
self.scale_factor = 0.08333
|
||||
self.latent_rgb_factors = [
|
||||
[-0.2340, -0.3863, -0.3257],
|
||||
[0.0994, 0.0885, -0.0908],
|
||||
[-0.2833, -0.2349, -0.3741],
|
||||
[0.2523, -0.0055, -0.1651]
|
||||
]
|
||||
|
||||
|
||||
class SC_Prior(LatentFormat):
|
||||
latent_channels = 16
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0
|
||||
self.latent_rgb_factors = [
|
||||
[-0.0326, -0.0204, -0.0127],
|
||||
[-0.1592, -0.0427, 0.0216],
|
||||
[0.0873, 0.0638, -0.0020],
|
||||
[-0.0602, 0.0442, 0.1304],
|
||||
[0.0800, -0.0313, -0.1796],
|
||||
[-0.0810, -0.0638, -0.1581],
|
||||
[0.1791, 0.1180, 0.0967],
|
||||
[0.0740, 0.1416, 0.0432],
|
||||
[-0.1745, -0.1888, -0.1373],
|
||||
[0.2412, 0.1577, 0.0928],
|
||||
[0.1908, 0.0998, 0.0682],
|
||||
[0.0209, 0.0365, -0.0092],
|
||||
[0.0448, -0.0650, -0.1728],
|
||||
[-0.1658, -0.1045, -0.1308],
|
||||
[0.0542, 0.1545, 0.1325],
|
||||
[-0.0352, -0.1672, -0.2541]
|
||||
]
|
||||
|
||||
|
||||
class SC_B(LatentFormat):
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0 / 0.43
|
||||
self.latent_rgb_factors = [
|
||||
[0.1121, 0.2006, 0.1023],
|
||||
[-0.2093, -0.0222, -0.0195],
|
||||
[-0.3087, -0.1535, 0.0366],
|
||||
[0.0290, -0.1574, -0.4078]
|
||||
]
|
||||
|
||||
|
||||
class SD3(LatentFormat):
|
||||
latent_channels = 16
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.5305
|
||||
self.shift_factor = 0.0609
|
||||
self.latent_rgb_factors = [
|
||||
[-0.0645, 0.0177, 0.1052],
|
||||
[0.0028, 0.0312, 0.0650],
|
||||
[0.1848, 0.0762, 0.0360],
|
||||
[0.0944, 0.0360, 0.0889],
|
||||
[0.0897, 0.0506, -0.0364],
|
||||
[-0.0020, 0.1203, 0.0284],
|
||||
[0.0855, 0.0118, 0.0283],
|
||||
[-0.0539, 0.0658, 0.1047],
|
||||
[-0.0057, 0.0116, 0.0700],
|
||||
[-0.0412, 0.0281, -0.0039],
|
||||
[0.1106, 0.1171, 0.1220],
|
||||
[-0.0248, 0.0682, -0.0481],
|
||||
[0.0815, 0.0846, 0.1207],
|
||||
[-0.0120, -0.0055, -0.0867],
|
||||
[-0.0749, -0.0634, -0.0456],
|
||||
[-0.1418, -0.1457, -0.1259]
|
||||
]
|
||||
self.taesd_decoder_name = "taesd3_decoder"
|
||||
|
||||
def process_in(self, latent):
|
||||
return (latent - self.shift_factor) * self.scale_factor
|
||||
|
||||
def process_out(self, latent):
|
||||
return (latent / self.scale_factor) + self.shift_factor
|
||||
|
||||
|
||||
class StableAudio1(LatentFormat):
|
||||
latent_channels = 64
|
||||
|
||||
|
||||
class Flux(SD3):
|
||||
def __init__(self):
|
||||
self.scale_factor = 0.3611
|
||||
self.shift_factor = 0.1159
|
||||
self.latent_rgb_factors = [
|
||||
[-0.0404, 0.0159, 0.0609],
|
||||
[0.0043, 0.0298, 0.0850],
|
||||
[0.0328, -0.0749, -0.0503],
|
||||
[-0.0245, 0.0085, 0.0549],
|
||||
[0.0966, 0.0894, 0.0530],
|
||||
[0.0035, 0.0399, 0.0123],
|
||||
[0.0583, 0.1184, 0.1262],
|
||||
[-0.0191, -0.0206, -0.0306],
|
||||
[-0.0324, 0.0055, 0.1001],
|
||||
[0.0955, 0.0659, -0.0545],
|
||||
[-0.0504, 0.0231, -0.0013],
|
||||
[0.0500, -0.0008, -0.0088],
|
||||
[0.0982, 0.0941, 0.0976],
|
||||
[-0.1233, -0.0280, -0.0897],
|
||||
[-0.0005, -0.0530, -0.0020],
|
||||
[-0.1273, -0.0932, -0.0680]
|
||||
]
|
||||
|
||||
def process_in(self, latent):
|
||||
return (latent - self.shift_factor) * self.scale_factor
|
||||
|
||||
def process_out(self, latent):
|
||||
return (latent / self.scale_factor) + self.shift_factor
|
||||
720
huggingface_guess/model_list.py
Normal file
@@ -0,0 +1,720 @@
|
||||
import torch
|
||||
|
||||
from enum import Enum
|
||||
from huggingface_guess import latent, utils, diffusers_convert
|
||||
|
||||
|
||||
class ModelType(Enum):
|
||||
EPS = 1
|
||||
V_PREDICTION = 2
|
||||
V_PREDICTION_EDM = 3
|
||||
STABLE_CASCADE = 4
|
||||
EDM = 5
|
||||
FLOW = 6
|
||||
V_PREDICTION_CONTINUOUS = 7
|
||||
FLUX = 8
|
||||
|
||||
|
||||
class BASE:
|
||||
huggingface_repo = None
|
||||
unet_config = {}
|
||||
unet_extra_config = {
|
||||
"num_heads": -1,
|
||||
"num_head_channels": 64,
|
||||
}
|
||||
|
||||
required_keys = {}
|
||||
|
||||
clip_prefix = []
|
||||
clip_vision_prefix = None
|
||||
noise_aug_config = None
|
||||
sampling_settings = {}
|
||||
latent_format = latent.LatentFormat
|
||||
vae_key_prefix = ["first_stage_model."]
|
||||
text_encoder_key_prefix = ["cond_stage_model."]
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
memory_usage_factor = 2.0
|
||||
|
||||
manual_cast_dtype = None
|
||||
|
||||
@classmethod
|
||||
def matches(s, unet_config, state_dict=None):
|
||||
for k in s.unet_config:
|
||||
if k not in unet_config or s.unet_config[k] != unet_config[k]:
|
||||
return False
|
||||
if state_dict is not None:
|
||||
for k in s.required_keys:
|
||||
if k not in state_dict:
|
||||
return False
|
||||
return True
|
||||
|
||||
def model_type(self, state_dict, prefix=""):
|
||||
return ModelType.EPS
|
||||
|
||||
def inpaint_model(self):
|
||||
return self.unet_config["in_channels"] > 4
|
||||
|
||||
def __init__(self, unet_config):
|
||||
self.unet_config = unet_config.copy()
|
||||
self.sampling_settings = self.sampling_settings.copy()
|
||||
self.latent_format = self.latent_format()
|
||||
for x in self.unet_extra_config:
|
||||
self.unet_config[x] = self.unet_extra_config[x]
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
|
||||
return state_dict
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def process_vae_state_dict(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {"": self.text_encoder_key_prefix[0]}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def process_clip_vision_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {}
|
||||
if self.clip_vision_prefix is not None:
|
||||
replace_prefix[""] = self.clip_vision_prefix
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def process_unet_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {"": "model.diffusion_model."}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def process_vae_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {"": self.vae_key_prefix[0]}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
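
# How matching works (illustrative): BASE.matches() only compares the keys a subclass
# declares in its own unet_config, so a detected config such as
#   {"context_dim": 768, "model_channels": 320, "use_linear_in_transformer": False,
#    "adm_in_channels": None, "use_temporal_attention": False, "in_channels": 4, ...}
# is accepted by the SD15 class below; extra detected keys are simply ignored.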


class SD15(BASE):
|
||||
huggingface_repo = "runwayml/stable-diffusion-v1-5"
|
||||
|
||||
unet_config = {
|
||||
"context_dim": 768,
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": False,
|
||||
"adm_in_channels": None,
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
unet_extra_config = {
|
||||
"num_heads": 8,
|
||||
"num_head_channels": -1,
|
||||
}
|
||||
|
||||
latent_format = latent.SD15
|
||||
memory_usage_factor = 1.0
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
k = list(state_dict.keys())
|
||||
for x in k:
|
||||
if x.startswith("cond_stage_model.transformer.") and not x.startswith("cond_stage_model.transformer.text_model."):
|
||||
y = x.replace("cond_stage_model.transformer.", "cond_stage_model.transformer.text_model.")
|
||||
state_dict[y] = state_dict.pop(x)
|
||||
|
||||
if 'cond_stage_model.transformer.text_model.embeddings.position_ids' in state_dict:
|
||||
ids = state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids']
|
||||
if ids.dtype == torch.float32:
|
||||
state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
|
||||
|
||||
replace_prefix = {}
|
||||
replace_prefix["cond_stage_model."] = "clip_l."
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"]
|
||||
for p in pop_keys:
|
||||
if p in state_dict:
|
||||
state_dict.pop(p)
|
||||
|
||||
replace_prefix = {"clip_l.": "cond_stage_model."}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return ['clip_l']
|
||||
|
||||
|
||||
class SD20(BASE):
|
||||
huggingface_repo = "stabilityai/stable-diffusion-2-1"
|
||||
|
||||
unet_config = {
|
||||
"context_dim": 1024,
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": True,
|
||||
"adm_in_channels": None,
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
unet_extra_config = {
|
||||
"num_heads": -1,
|
||||
"num_head_channels": 64,
|
||||
"attn_precision": torch.float32,
|
||||
}
|
||||
|
||||
latent_format = latent.SD15
|
||||
memory_usage_factor = 1.0
|
||||
|
||||
def model_type(self, state_dict, prefix=""):
|
||||
if self.unet_config["in_channels"] == 4: # SD2.0 inpainting models are not v prediction
|
||||
k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
|
||||
out = state_dict.get(k, None)
|
||||
if out is not None and torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
|
||||
return ModelType.V_PREDICTION
|
||||
return ModelType.EPS
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
replace_prefix = {}
|
||||
replace_prefix["conditioner.embedders.0.model."] = "clip_h." # SD2 in sgm format
|
||||
replace_prefix["cond_stage_model.model."] = "clip_h."
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
|
||||
state_dict = utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.")
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {}
|
||||
replace_prefix["clip_h"] = "cond_stage_model.model"
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
|
||||
return state_dict
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return ['clip_h']
|
||||
|
||||
|
||||
class SD21UnclipL(SD20):
|
||||
unet_config = {
|
||||
"context_dim": 1024,
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": True,
|
||||
"adm_in_channels": 1536,
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
clip_vision_prefix = "embedder.model.visual."
|
||||
noise_aug_config = {"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 768}
|
||||
|
||||
|
||||
class SD21UnclipH(SD20):
|
||||
unet_config = {
|
||||
"context_dim": 1024,
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": True,
|
||||
"adm_in_channels": 2048,
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
clip_vision_prefix = "embedder.model.visual."
|
||||
noise_aug_config = {"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1024}
|
||||
|
||||
|
||||
class SDXLRefiner(BASE):
|
||||
huggingface_repo = "stabilityai/stable-diffusion-xl-refiner-1.0"
|
||||
|
||||
unet_config = {
|
||||
"model_channels": 384,
|
||||
"use_linear_in_transformer": True,
|
||||
"context_dim": 1280,
|
||||
"adm_in_channels": 2560,
|
||||
"transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0],
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
latent_format = latent.SDXL
|
||||
memory_usage_factor = 1.0
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
keys_to_replace = {}
|
||||
replace_prefix = {}
|
||||
replace_prefix["conditioner.embedders.0.model."] = "clip_g."
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
|
||||
|
||||
state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.")
|
||||
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {}
|
||||
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
|
||||
if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g:
|
||||
state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
|
||||
replace_prefix["clip_g"] = "conditioner.embedders.0.model"
|
||||
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
|
||||
return state_dict_g
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return ["clip_g"]
|
||||
|
||||
|
||||
class SDXL(BASE):
|
||||
huggingface_repo = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": True,
|
||||
"transformer_depth": [0, 0, 2, 2, 10, 10],
|
||||
"context_dim": 2048,
|
||||
"adm_in_channels": 2816,
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
latent_format = latent.SDXL
|
||||
|
||||
memory_usage_factor = 0.7
|
||||
|
||||
def model_type(self, state_dict, prefix=""):
|
||||
if 'edm_mean' in state_dict and 'edm_std' in state_dict: # Playground V2.5
|
||||
self.latent_format = latent.SDXL_Playground_2_5()
|
||||
self.sampling_settings["sigma_data"] = 0.5
|
||||
self.sampling_settings["sigma_max"] = 80.0
|
||||
self.sampling_settings["sigma_min"] = 0.002
|
||||
return ModelType.EDM
|
||||
elif "edm_vpred.sigma_max" in state_dict:
|
||||
self.sampling_settings["sigma_max"] = float(state_dict["edm_vpred.sigma_max"].item())
|
||||
if "edm_vpred.sigma_min" in state_dict:
|
||||
self.sampling_settings["sigma_min"] = float(state_dict["edm_vpred.sigma_min"].item())
|
||||
return ModelType.V_PREDICTION_EDM
|
||||
elif "v_pred" in state_dict:
|
||||
return ModelType.V_PREDICTION
|
||||
else:
|
||||
return ModelType.EPS
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
keys_to_replace = {}
|
||||
replace_prefix = {}
|
||||
|
||||
replace_prefix["conditioner.embedders.0.transformer.text_model"] = "clip_l.transformer.text_model"
|
||||
replace_prefix["conditioner.embedders.1.model."] = "clip_g."
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
|
||||
|
||||
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
|
||||
state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.")
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {}
|
||||
keys_to_replace = {}
|
||||
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
|
||||
for k in state_dict:
|
||||
if k.startswith("clip_l"):
|
||||
state_dict_g[k] = state_dict[k]
|
||||
|
||||
state_dict_g["clip_l.transformer.text_model.embeddings.position_ids"] = torch.arange(77).expand((1, -1))
|
||||
pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"]
|
||||
for p in pop_keys:
|
||||
if p in state_dict_g:
|
||||
state_dict_g.pop(p)
|
||||
|
||||
replace_prefix["clip_g"] = "conditioner.embedders.1.model"
|
||||
replace_prefix["clip_l"] = "conditioner.embedders.0"
|
||||
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
|
||||
return state_dict_g
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return ['clip_l', 'clip_g']
|
||||
|
||||
|
||||
class SSD1B(SDXL):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": True,
|
||||
"transformer_depth": [0, 0, 2, 2, 4, 4],
|
||||
"context_dim": 2048,
|
||||
"adm_in_channels": 2816,
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
|
||||
class Segmind_Vega(SDXL):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": True,
|
||||
"transformer_depth": [0, 0, 1, 1, 2, 2],
|
||||
"context_dim": 2048,
|
||||
"adm_in_channels": 2816,
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
|
||||
class KOALA_700M(SDXL):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": True,
|
||||
"transformer_depth": [0, 2, 5],
|
||||
"context_dim": 2048,
|
||||
"adm_in_channels": 2816,
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
|
||||
class KOALA_1B(SDXL):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": True,
|
||||
"transformer_depth": [0, 2, 6],
|
||||
"context_dim": 2048,
|
||||
"adm_in_channels": 2816,
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
|
||||
class SVD_img2vid(BASE):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
"in_channels": 8,
|
||||
"use_linear_in_transformer": True,
|
||||
"transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
|
||||
"context_dim": 1024,
|
||||
"adm_in_channels": 768,
|
||||
"use_temporal_attention": True,
|
||||
"use_temporal_resblock": True
|
||||
}
|
||||
|
||||
unet_extra_config = {
|
||||
"num_heads": -1,
|
||||
"num_head_channels": 64,
|
||||
"attn_precision": torch.float32,
|
||||
}
|
||||
|
||||
clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual."
|
||||
|
||||
latent_format = latent.SD15
|
||||
|
||||
sampling_settings = {"sigma_max": 700.0, "sigma_min": 0.002}
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return None
|
||||
|
||||
|
||||
class SV3D_u(SVD_img2vid):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
"in_channels": 8,
|
||||
"use_linear_in_transformer": True,
|
||||
"transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
|
||||
"context_dim": 1024,
|
||||
"adm_in_channels": 256,
|
||||
"use_temporal_attention": True,
|
||||
"use_temporal_resblock": True
|
||||
}
|
||||
|
||||
vae_key_prefix = ["conditioner.embedders.1.encoder."]
|
||||
|
||||
|
||||
class SV3D_p(SV3D_u):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
"in_channels": 8,
|
||||
"use_linear_in_transformer": True,
|
||||
"transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
|
||||
"context_dim": 1024,
|
||||
"adm_in_channels": 1280,
|
||||
"use_temporal_attention": True,
|
||||
"use_temporal_resblock": True
|
||||
}
|
||||
|
||||
|
||||
class Stable_Zero123(BASE):
|
||||
unet_config = {
|
||||
"context_dim": 768,
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": False,
|
||||
"adm_in_channels": None,
|
||||
"use_temporal_attention": False,
|
||||
"in_channels": 8,
|
||||
}
|
||||
|
||||
unet_extra_config = {
|
||||
"num_heads": 8,
|
||||
"num_head_channels": -1,
|
||||
}
|
||||
|
||||
required_keys = {
|
||||
"cc_projection.weight": None,
|
||||
"cc_projection.bias": None,
|
||||
}
|
||||
|
||||
clip_vision_prefix = "cond_stage_model.model.visual."
|
||||
|
||||
latent_format = latent.SD15
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return None
|
||||
|
||||
|
||||
class SD_X4Upscaler(SD20):
|
||||
unet_config = {
|
||||
"context_dim": 1024,
|
||||
"model_channels": 256,
|
||||
'in_channels': 7,
|
||||
"use_linear_in_transformer": True,
|
||||
"adm_in_channels": None,
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
unet_extra_config = {
|
||||
"disable_self_attentions": [True, True, True, False],
|
||||
"num_classes": 1000,
|
||||
"num_heads": 8,
|
||||
"num_head_channels": -1,
|
||||
}
|
||||
|
||||
latent_format = latent.SD_X4
|
||||
|
||||
sampling_settings = {
|
||||
"linear_start": 0.0001,
|
||||
"linear_end": 0.02,
|
||||
}
|
||||
|
||||
|
||||
class Stable_Cascade_C(BASE):
|
||||
unet_config = {
|
||||
"stable_cascade_stage": 'c',
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
|
||||
latent_format = latent.SC_Prior
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 2.0,
|
||||
}
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoder."]
|
||||
clip_vision_prefix = "clip_l_vision."
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
key_list = list(state_dict.keys())
|
||||
for y in ["weight", "bias"]:
|
||||
suffix = "in_proj_{}".format(y)
|
||||
keys = filter(lambda a: a.endswith(suffix), key_list)
|
||||
for k_from in keys:
|
||||
weights = state_dict.pop(k_from)
|
||||
prefix = k_from[:-(len(suffix) + 1)]
|
||||
shape_from = weights.shape[0] // 3
|
||||
for x in range(3):
|
||||
p = ["to_q", "to_k", "to_v"]
|
||||
k_to = "{}.{}.{}".format(prefix, p[x], y)
|
||||
state_dict[k_to] = weights[shape_from * x:shape_from * (x + 1)]
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
|
||||
if "clip_g.text_projection" in state_dict:
|
||||
state_dict["clip_g.transformer.text_projection.weight"] = state_dict.pop("clip_g.text_projection").transpose(0, 1)
|
||||
return state_dict
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return ['clip_g']
|
||||
|
||||
|
||||
class Stable_Cascade_B(Stable_Cascade_C):
|
||||
unet_config = {
|
||||
"stable_cascade_stage": 'b',
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
|
||||
latent_format = latent.SC_B
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 1.0,
|
||||
}
|
||||
|
||||
clip_vision_prefix = None
|
||||
|
||||
|
||||
class SD15_instructpix2pix(SD15):
|
||||
unet_config = {
|
||||
"context_dim": 768,
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": False,
|
||||
"adm_in_channels": None,
|
||||
"use_temporal_attention": False,
|
||||
"in_channels": 8,
|
||||
}
|
||||
|
||||
|
||||
class SDXL_instructpix2pix(SDXL):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": True,
|
||||
"transformer_depth": [0, 0, 2, 2, 10, 10],
|
||||
"context_dim": 2048,
|
||||
"adm_in_channels": 2816,
|
||||
"use_temporal_attention": False,
|
||||
"in_channels": 8,
|
||||
}
|
||||
|
||||
|
||||
class SD3(BASE):
|
||||
huggingface_repo = "stabilityai/stable-diffusion-3-medium-diffusers"
|
||||
|
||||
unet_config = {
|
||||
"in_channels": 16,
|
||||
"pos_embed_scaling_factor": None,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 3.0,
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent.SD3
|
||||
|
||||
memory_usage_factor = 1.2
|
||||
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
clip_l = False
|
||||
clip_g = False
|
||||
t5 = False
|
||||
dtype_t5 = None
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
if "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
|
||||
clip_l = True
|
||||
if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
|
||||
clip_g = True
|
||||
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
|
||||
if t5_key in state_dict:
|
||||
t5 = True
|
||||
dtype_t5 = state_dict[t5_key].dtype
|
||||
|
||||
return dict(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5)
|
||||
|
||||
|
||||
class StableAudio(BASE):
|
||||
unet_config = {
|
||||
"audio_model": "dit1.0",
|
||||
}
|
||||
|
||||
sampling_settings = {"sigma_max": 500.0, "sigma_min": 0.03}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent.StableAudio1
|
||||
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
vae_key_prefix = ["pretransform.model."]
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
for k in list(state_dict.keys()):
|
||||
if k.endswith(".cross_attend_norm.beta") or k.endswith(".ff_norm.beta") or k.endswith(".pre_norm.beta"): # These weights are all zero
|
||||
state_dict.pop(k)
|
||||
return state_dict
|
||||
|
||||
def process_unet_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {"": "model.model."}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return ['sa_t5']
|
||||
|
||||
|
||||
class AuraFlow(BASE):
|
||||
unet_config = {
|
||||
"cond_seq_dim": 2048,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"multiplier": 1.0,
|
||||
"shift": 1.73,
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent.SDXL
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return ['aura_t5']
|
||||
|
||||
|
||||
class HunyuanDiT(BASE):
|
||||
unet_config = {
|
||||
"image_model": "hydit",
|
||||
}
|
||||
|
||||
unet_extra_config = {
|
||||
"attn_precision": torch.float32,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"linear_start": 0.00085,
|
||||
"linear_end": 0.018,
|
||||
}
|
||||
|
||||
latent_format = latent.SDXL
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return ['hunyuan_t5']
|
||||
|
||||
|
||||
class HunyuanDiT1(HunyuanDiT):
|
||||
unet_config = {
|
||||
"image_model": "hydit1",
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
|
||||
sampling_settings = {
|
||||
"linear_start": 0.00085,
|
||||
"linear_end": 0.03,
|
||||
}
|
||||
|
||||
|
||||
class Flux(BASE):
    unet_config = {
        "image_model": "flux",
        "guidance_embed": True,
    }

    sampling_settings = {
    }

    unet_extra_config = {}
    latent_format = latent.Flux

    memory_usage_factor = 2.6

    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def clip_target(self, state_dict={}):
        dtype_t5 = None
        pref = self.text_encoder_key_prefix[0]
        t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
        if t5_key in state_dict:
            dtype_t5 = state_dict[t5_key].dtype
        return dict(t5=True, dtype_t5=dtype_t5)


class FluxSchnell(Flux):
    unet_config = {
        "image_model": "flux",
        "guidance_embed": False,
    }

    sampling_settings = {
        "multiplier": 1.0,
        "shift": 1.0,
    }


models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell]

models += [SVD_img2vid]
462
huggingface_guess/utils.py
Normal file
@@ -0,0 +1,462 @@
import torch
import math
import struct


def calculate_parameters(sd, prefix=""):
    params = 0
    for k in sd.keys():
        if k.startswith(prefix):
            w = sd[k]
            params += w.nelement()
    return params


def weight_dtype(sd, prefix=""):
    dtypes = {}
    for k in sd.keys():
        if k.startswith(prefix):
            w = sd[k]
            dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1

    return max(dtypes, key=dtypes.get)


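# Illustrative sketch (not part of the original file): the two helpers above can be
# used to summarise a checkpoint before deciding how to load it. The key names,
# shapes and dtypes below are made up purely for demonstration.
def _example_parameter_stats():
    sd = {
        "model.diffusion_model.input_blocks.0.0.weight": torch.zeros(320, 4, 3, 3, dtype=torch.float16),
        "model.diffusion_model.input_blocks.0.0.bias": torch.zeros(320, dtype=torch.float16),
        "first_stage_model.encoder.conv_in.weight": torch.zeros(128, 3, 3, 3),
    }
    unet_params = calculate_parameters(sd, prefix="model.diffusion_model.")  # 11840
    unet_dtype = weight_dtype(sd, prefix="model.diffusion_model.")  # torch.float16
    return unet_params, unet_dtype

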
def state_dict_key_replace(state_dict, keys_to_replace):
    for x in keys_to_replace:
        if x in state_dict:
            state_dict[keys_to_replace[x]] = state_dict.pop(x)
    return state_dict


def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
    if filter_keys:
        out = {}
    else:
        out = state_dict
    for rp in replace_prefix:
        replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys())))
        for x in replace:
            w = state_dict.pop(x[0])
            out[x[1]] = w
    return out


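# Illustrative sketch (not part of the original file): renaming checkpoint prefixes
# with the helper above. The key names are hypothetical.
def _example_prefix_replace():
    sd = {
        "model.diffusion_model.out.2.weight": torch.zeros(4),
        "first_stage_model.decoder.conv_out.bias": torch.zeros(3),
    }
    # Strip the LDM UNet prefix in place; keys without the prefix are left untouched.
    sd = state_dict_prefix_replace(sd, {"model.diffusion_model.": ""})
    # With filter_keys=True a new dict containing only the matching (renamed) keys is returned.
    unet_only = state_dict_prefix_replace(dict(sd), {"out.": "unet.out."}, filter_keys=True)
    return sd, unet_only

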
def transformers_convert(sd, prefix_from, prefix_to, number):
    keys_to_replace = {
        "{}positional_embedding": "{}embeddings.position_embedding.weight",
        "{}token_embedding.weight": "{}embeddings.token_embedding.weight",
        "{}ln_final.weight": "{}final_layer_norm.weight",
        "{}ln_final.bias": "{}final_layer_norm.bias",
    }

    for k in keys_to_replace:
        x = k.format(prefix_from)
        if x in sd:
            sd[keys_to_replace[k].format(prefix_to)] = sd.pop(x)

    resblock_to_replace = {
        "ln_1": "layer_norm1",
        "ln_2": "layer_norm2",
        "mlp.c_fc": "mlp.fc1",
        "mlp.c_proj": "mlp.fc2",
        "attn.out_proj": "self_attn.out_proj",
    }

    for resblock in range(number):
        for x in resblock_to_replace:
            for y in ["weight", "bias"]:
                k = "{}transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
                k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
                if k in sd:
                    sd[k_to] = sd.pop(k)

        for y in ["weight", "bias"]:
            k_from = "{}transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
            if k_from in sd:
                weights = sd.pop(k_from)
                shape_from = weights.shape[0] // 3
                for x in range(3):
                    p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
                    k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
                    sd[k_to] = weights[shape_from * x:shape_from * (x + 1)]

    return sd


def clip_text_transformers_convert(sd, prefix_from, prefix_to):
    sd = transformers_convert(sd, prefix_from, "{}text_model.".format(prefix_to), 32)

    tp = "{}text_projection.weight".format(prefix_from)
    if tp in sd:
        sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp)

    tp = "{}text_projection".format(prefix_from)
    if tp in sd:
        sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp).transpose(0, 1).contiguous()
    return sd


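# Illustrative sketch (not part of the original file): converting a CLIP text encoder
# from the OpenAI-style key layout to the transformers layout with the helper above.
# The shapes use a toy hidden size of 8 instead of a real CLIP width.
def _example_clip_conversion():
    sd = {
        "ln_final.weight": torch.ones(8),
        "transformer.resblocks.0.ln_1.weight": torch.ones(8),
        "transformer.resblocks.0.attn.in_proj_weight": torch.zeros(24, 8),
    }
    sd = transformers_convert(sd, "", "text_model.", number=1)
    # "ln_final.weight"                     -> "text_model.final_layer_norm.weight"
    # "transformer.resblocks.0.ln_1.weight" -> "text_model.encoder.layers.0.layer_norm1.weight"
    # the fused in_proj weight is split into q_proj / k_proj / v_proj of shape (8, 8)
    return sd

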
UNET_MAP_ATTENTIONS = {
    "proj_in.weight",
    "proj_in.bias",
    "proj_out.weight",
    "proj_out.bias",
    "norm.weight",
    "norm.bias",
}

TRANSFORMER_BLOCKS = {
    "norm1.weight",
    "norm1.bias",
    "norm2.weight",
    "norm2.bias",
    "norm3.weight",
    "norm3.bias",
    "attn1.to_q.weight",
    "attn1.to_k.weight",
    "attn1.to_v.weight",
    "attn1.to_out.0.weight",
    "attn1.to_out.0.bias",
    "attn2.to_q.weight",
    "attn2.to_k.weight",
    "attn2.to_v.weight",
    "attn2.to_out.0.weight",
    "attn2.to_out.0.bias",
    "ff.net.0.proj.weight",
    "ff.net.0.proj.bias",
    "ff.net.2.weight",
    "ff.net.2.bias",
}

UNET_MAP_RESNET = {
    "in_layers.2.weight": "conv1.weight",
    "in_layers.2.bias": "conv1.bias",
    "emb_layers.1.weight": "time_emb_proj.weight",
    "emb_layers.1.bias": "time_emb_proj.bias",
    "out_layers.3.weight": "conv2.weight",
    "out_layers.3.bias": "conv2.bias",
    "skip_connection.weight": "conv_shortcut.weight",
    "skip_connection.bias": "conv_shortcut.bias",
    "in_layers.0.weight": "norm1.weight",
    "in_layers.0.bias": "norm1.bias",
    "out_layers.0.weight": "norm2.weight",
    "out_layers.0.bias": "norm2.bias",
}

UNET_MAP_BASIC = {
    ("label_emb.0.0.weight", "class_embedding.linear_1.weight"),
    ("label_emb.0.0.bias", "class_embedding.linear_1.bias"),
    ("label_emb.0.2.weight", "class_embedding.linear_2.weight"),
    ("label_emb.0.2.bias", "class_embedding.linear_2.bias"),
    ("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
    ("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
    ("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
    ("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
    ("input_blocks.0.0.weight", "conv_in.weight"),
    ("input_blocks.0.0.bias", "conv_in.bias"),
    ("out.0.weight", "conv_norm_out.weight"),
    ("out.0.bias", "conv_norm_out.bias"),
    ("out.2.weight", "conv_out.weight"),
    ("out.2.bias", "conv_out.bias"),
    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
    ("time_embed.2.bias", "time_embedding.linear_2.bias")
}


def unet_to_diffusers(unet_config):
    if "num_res_blocks" not in unet_config:
        return {}
    num_res_blocks = unet_config["num_res_blocks"]
    channel_mult = unet_config["channel_mult"]
    transformer_depth = unet_config["transformer_depth"][:]
    transformer_depth_output = unet_config["transformer_depth_output"][:]
    num_blocks = len(channel_mult)

    transformers_mid = unet_config.get("transformer_depth_middle", None)

    diffusers_unet_map = {}
    for x in range(num_blocks):
        n = 1 + (num_res_blocks[x] + 1) * x
        for i in range(num_res_blocks[x]):
            for b in UNET_MAP_RESNET:
                diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
            num_transformers = transformer_depth.pop(0)
            if num_transformers > 0:
                for b in UNET_MAP_ATTENTIONS:
                    diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b)
                for t in range(num_transformers):
                    for b in TRANSFORMER_BLOCKS:
                        diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
            n += 1
        for k in ["weight", "bias"]:
            diffusers_unet_map["down_blocks.{}.downsamplers.0.conv.{}".format(x, k)] = "input_blocks.{}.0.op.{}".format(n, k)

    i = 0
    for b in UNET_MAP_ATTENTIONS:
        diffusers_unet_map["mid_block.attentions.{}.{}".format(i, b)] = "middle_block.1.{}".format(b)
    for t in range(transformers_mid):
        for b in TRANSFORMER_BLOCKS:
            diffusers_unet_map["mid_block.attentions.{}.transformer_blocks.{}.{}".format(i, t, b)] = "middle_block.1.transformer_blocks.{}.{}".format(t, b)

    for i, n in enumerate([0, 2]):
        for b in UNET_MAP_RESNET:
            diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)

    num_res_blocks = list(reversed(num_res_blocks))
    for x in range(num_blocks):
        n = (num_res_blocks[x] + 1) * x
        l = num_res_blocks[x] + 1
        for i in range(l):
            c = 0
            for b in UNET_MAP_RESNET:
                diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b)
            c += 1
            num_transformers = transformer_depth_output.pop()
            if num_transformers > 0:
                c += 1
                for b in UNET_MAP_ATTENTIONS:
                    diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b)
                for t in range(num_transformers):
                    for b in TRANSFORMER_BLOCKS:
                        diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
            if i == l - 1:
                for k in ["weight", "bias"]:
                    diffusers_unet_map["up_blocks.{}.upsamplers.0.conv.{}".format(x, k)] = "output_blocks.{}.{}.conv.{}".format(n, c, k)
            n += 1

    for k in UNET_MAP_BASIC:
        diffusers_unet_map[k[1]] = k[0]

    return diffusers_unet_map


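# Illustrative sketch (not part of the original file): building a diffusers -> LDM key
# map for a deliberately tiny, made-up UNet layout. Real configs (e.g. SD1.x/SDXL)
# have more blocks, but the structure of the resulting mapping is the same.
def _example_unet_map():
    toy_config = {
        "num_res_blocks": [1, 1],
        "channel_mult": [1, 2],
        "transformer_depth": [1, 1],
        "transformer_depth_output": [1, 1, 1, 1],
        "transformer_depth_middle": 1,
    }
    mapping = unet_to_diffusers(toy_config)
    # e.g. mapping["down_blocks.0.resnets.0.conv1.weight"] == "input_blocks.1.0.in_layers.2.weight"
    return mapping

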
def swap_scale_shift(weight):
    shift, scale = weight.chunk(2, dim=0)
    new_weight = torch.cat([scale, shift], dim=0)
    return new_weight


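# Illustrative sketch (not part of the original file): the two checkpoint layouts
# disagree on whether the shift or the scale half of an adaLN modulation weight comes
# first, so swap_scale_shift simply swaps the two halves along dim 0.
def _example_swap_scale_shift():
    w = torch.tensor([0.0, 1.0, 2.0, 3.0])  # first half, second half
    return swap_scale_shift(w)               # tensor([2., 3., 0., 1.])

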
MMDIT_MAP_BASIC = {
    ("context_embedder.bias", "context_embedder.bias"),
    ("context_embedder.weight", "context_embedder.weight"),
    ("t_embedder.mlp.0.bias", "time_text_embed.timestep_embedder.linear_1.bias"),
    ("t_embedder.mlp.0.weight", "time_text_embed.timestep_embedder.linear_1.weight"),
    ("t_embedder.mlp.2.bias", "time_text_embed.timestep_embedder.linear_2.bias"),
    ("t_embedder.mlp.2.weight", "time_text_embed.timestep_embedder.linear_2.weight"),
    ("x_embedder.proj.bias", "pos_embed.proj.bias"),
    ("x_embedder.proj.weight", "pos_embed.proj.weight"),
    ("y_embedder.mlp.0.bias", "time_text_embed.text_embedder.linear_1.bias"),
    ("y_embedder.mlp.0.weight", "time_text_embed.text_embedder.linear_1.weight"),
    ("y_embedder.mlp.2.bias", "time_text_embed.text_embedder.linear_2.bias"),
    ("y_embedder.mlp.2.weight", "time_text_embed.text_embedder.linear_2.weight"),
    ("pos_embed", "pos_embed.pos_embed"),
    ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
    ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
    ("final_layer.linear.bias", "proj_out.bias"),
    ("final_layer.linear.weight", "proj_out.weight"),
}

MMDIT_MAP_BLOCK = {
    ("context_block.adaLN_modulation.1.bias", "norm1_context.linear.bias"),
    ("context_block.adaLN_modulation.1.weight", "norm1_context.linear.weight"),
    ("context_block.attn.proj.bias", "attn.to_add_out.bias"),
    ("context_block.attn.proj.weight", "attn.to_add_out.weight"),
    ("context_block.mlp.fc1.bias", "ff_context.net.0.proj.bias"),
    ("context_block.mlp.fc1.weight", "ff_context.net.0.proj.weight"),
    ("context_block.mlp.fc2.bias", "ff_context.net.2.bias"),
    ("context_block.mlp.fc2.weight", "ff_context.net.2.weight"),
    ("x_block.adaLN_modulation.1.bias", "norm1.linear.bias"),
    ("x_block.adaLN_modulation.1.weight", "norm1.linear.weight"),
    ("x_block.attn.proj.bias", "attn.to_out.0.bias"),
    ("x_block.attn.proj.weight", "attn.to_out.0.weight"),
    ("x_block.mlp.fc1.bias", "ff.net.0.proj.bias"),
    ("x_block.mlp.fc1.weight", "ff.net.0.proj.weight"),
    ("x_block.mlp.fc2.bias", "ff.net.2.bias"),
    ("x_block.mlp.fc2.weight", "ff.net.2.weight"),
}


def mmdit_to_diffusers(mmdit_config, output_prefix=""):
    key_map = {}

    depth = mmdit_config.get("depth", 0)
    num_blocks = mmdit_config.get("num_blocks", depth)
    for i in range(num_blocks):
        block_from = "transformer_blocks.{}".format(i)
        block_to = "{}joint_blocks.{}".format(output_prefix, i)

        offset = depth * 64

        for end in ("weight", "bias"):
            k = "{}.attn.".format(block_from)
            qkv = "{}.x_block.attn.qkv.{}".format(block_to, end)
            key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, offset))
            key_map["{}to_k.{}".format(k, end)] = (qkv, (0, offset, offset))
            key_map["{}to_v.{}".format(k, end)] = (qkv, (0, offset * 2, offset))

            qkv = "{}.context_block.attn.qkv.{}".format(block_to, end)
            key_map["{}add_q_proj.{}".format(k, end)] = (qkv, (0, 0, offset))
            key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, offset, offset))
            key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, offset * 2, offset))

        for k in MMDIT_MAP_BLOCK:
            key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])

    map_basic = MMDIT_MAP_BASIC.copy()
    map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.bias".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.bias".format(depth - 1), swap_scale_shift))
    map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.weight".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.weight".format(depth - 1), swap_scale_shift))

    for k in map_basic:
        if len(k) > 2:
            key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
        else:
            key_map[k[1]] = "{}{}".format(output_prefix, k[0])

    return key_map


def auraflow_to_diffusers(mmdit_config, output_prefix=""):
    n_double_layers = mmdit_config.get("n_double_layers", 0)
    n_layers = mmdit_config.get("n_layers", 0)

    key_map = {}
    for i in range(n_layers):
        if i < n_double_layers:
            index = i
            prefix_from = "joint_transformer_blocks"
            prefix_to = "{}double_layers".format(output_prefix)
            block_map = {
                "attn.to_q.weight": "attn.w2q.weight",
                "attn.to_k.weight": "attn.w2k.weight",
                "attn.to_v.weight": "attn.w2v.weight",
                "attn.to_out.0.weight": "attn.w2o.weight",
                "attn.add_q_proj.weight": "attn.w1q.weight",
                "attn.add_k_proj.weight": "attn.w1k.weight",
                "attn.add_v_proj.weight": "attn.w1v.weight",
                "attn.to_add_out.weight": "attn.w1o.weight",
                "ff.linear_1.weight": "mlpX.c_fc1.weight",
                "ff.linear_2.weight": "mlpX.c_fc2.weight",
                "ff.out_projection.weight": "mlpX.c_proj.weight",
                "ff_context.linear_1.weight": "mlpC.c_fc1.weight",
                "ff_context.linear_2.weight": "mlpC.c_fc2.weight",
                "ff_context.out_projection.weight": "mlpC.c_proj.weight",
                "norm1.linear.weight": "modX.1.weight",
                "norm1_context.linear.weight": "modC.1.weight",
            }
        else:
            index = i - n_double_layers
            prefix_from = "single_transformer_blocks"
            prefix_to = "{}single_layers".format(output_prefix)

            block_map = {
                "attn.to_q.weight": "attn.w1q.weight",
                "attn.to_k.weight": "attn.w1k.weight",
                "attn.to_v.weight": "attn.w1v.weight",
                "attn.to_out.0.weight": "attn.w1o.weight",
                "norm1.linear.weight": "modCX.1.weight",
                "ff.linear_1.weight": "mlp.c_fc1.weight",
                "ff.linear_2.weight": "mlp.c_fc2.weight",
                "ff.out_projection.weight": "mlp.c_proj.weight"
            }

        for k in block_map:
            key_map["{}.{}.{}".format(prefix_from, index, k)] = "{}.{}.{}".format(prefix_to, index, block_map[k])

    MAP_BASIC = {
        ("positional_encoding", "pos_embed.pos_embed"),
        ("register_tokens", "register_tokens"),
        ("t_embedder.mlp.0.weight", "time_step_proj.linear_1.weight"),
        ("t_embedder.mlp.0.bias", "time_step_proj.linear_1.bias"),
        ("t_embedder.mlp.2.weight", "time_step_proj.linear_2.weight"),
        ("t_embedder.mlp.2.bias", "time_step_proj.linear_2.bias"),
        ("cond_seq_linear.weight", "context_embedder.weight"),
        ("init_x_linear.weight", "pos_embed.proj.weight"),
        ("init_x_linear.bias", "pos_embed.proj.bias"),
        ("final_linear.weight", "proj_out.weight"),
        ("modF.1.weight", "norm_out.linear.weight", swap_scale_shift),
    }

    for k in MAP_BASIC:
        if len(k) > 2:
            key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
        else:
            key_map[k[1]] = "{}{}".format(output_prefix, k[0])

    return key_map


def repeat_to_batch_size(tensor, batch_size, dim=0):
    if tensor.shape[dim] > batch_size:
        return tensor.narrow(dim, 0, batch_size)
    elif tensor.shape[dim] < batch_size:
        return tensor.repeat(dim * [1] + [math.ceil(batch_size / tensor.shape[dim])] + [1] * (len(tensor.shape) - 1 - dim)).narrow(dim, 0, batch_size)
    return tensor


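# Illustrative sketch (not part of the original file): padding or trimming a batch
# dimension with the helper above.
def _example_repeat_to_batch_size():
    cond = torch.arange(6, dtype=torch.float32).reshape(2, 3)
    bigger = repeat_to_batch_size(cond, 5)   # shape (5, 3): rows 0, 1, 0, 1, 0
    smaller = repeat_to_batch_size(cond, 1)  # shape (1, 3): first row only
    return bigger, smaller

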
def resize_to_batch_size(tensor, batch_size):
    in_batch_size = tensor.shape[0]
    if in_batch_size == batch_size:
        return tensor

    if batch_size <= 1:
        return tensor[:batch_size]

    output = torch.empty([batch_size] + list(tensor.shape)[1:], dtype=tensor.dtype, device=tensor.device)
    if batch_size < in_batch_size:
        scale = (in_batch_size - 1) / (batch_size - 1)
        for i in range(batch_size):
            output[i] = tensor[min(round(i * scale), in_batch_size - 1)]
    else:
        scale = in_batch_size / batch_size
        for i in range(batch_size):
            output[i] = tensor[min(math.floor((i + 0.5) * scale), in_batch_size - 1)]

    return output


def convert_sd_to(state_dict, dtype):
    keys = list(state_dict.keys())
    for k in keys:
        state_dict[k] = state_dict[k].to(dtype)
    return state_dict


def safetensors_header(safetensors_path, max_size=100 * 1024 * 1024):
    with open(safetensors_path, "rb") as f:
        header = f.read(8)
        length_of_header = struct.unpack('<Q', header)[0]
        if length_of_header > max_size:
            return None
        return f.read(length_of_header)


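# Illustrative sketch (not part of the original file): the bytes returned above are the
# raw JSON header of a .safetensors file, so tensor names and shapes can be inspected
# without loading any weights. The file path below is hypothetical.
def _example_read_header(path="model.safetensors"):
    import json
    raw = safetensors_header(path)
    if raw is None:
        return None  # header larger than max_size
    header = json.loads(raw)
    return {k: v.get("shape") for k, v in header.items() if k != "__metadata__"}

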
def set_attr(obj, attr, value):
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    prev = getattr(obj, attrs[-1])
    setattr(obj, attrs[-1], value)
    return prev


def set_attr_param(obj, attr, value):
    return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))


def copy_to_param(obj, attr, value):
    # inplace update tensor instead of replacing it
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    prev = getattr(obj, attrs[-1])
    prev.data.copy_(value)


def get_attr(obj, attr):
    attrs = attr.split(".")
    for name in attrs:
        obj = getattr(obj, name)
    return obj
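

# Illustrative sketch (not part of the original file): the attribute helpers above walk
# dotted module paths, which is how patched weights can be written back into a loaded
# torch.nn module. The tiny model below is made up for demonstration.
def _example_attr_helpers():
    model = torch.nn.Sequential(torch.nn.Linear(2, 2))
    old_weight = set_attr_param(model, "0.weight", torch.zeros(2, 2))  # returns the previous Parameter
    copy_to_param(model, "0.bias", torch.ones(2))                      # in-place copy, keeps the Parameter object
    return old_weight, get_attr(model, "0.bias")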