mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
Removed all submodules. Submodule free now, yay.
This commit is contained in:
@@ -19,16 +19,12 @@ from toolkit.models.te_adapter import TEAdapter
|
||||
from toolkit.models.te_aug_adapter import TEAugAdapter
|
||||
from toolkit.models.vd_adapter import VisionDirectAdapter
|
||||
from toolkit.models.redux import ReduxImageEncoder
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
|
||||
from toolkit.saving import load_ip_adapter_model, load_custom_adapter_model
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
from toolkit.models.pixtral_vision import PixtralVisionEncoderCompatible, PixtralVisionImagePreprocessorCompatible
|
||||
import random
|
||||
|
||||
from toolkit.util.mask import generate_random_mask
|
||||
|
||||
sys.path.append(REPOS_ROOT)
|
||||
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
|
||||
from collections import OrderedDict
|
||||
from toolkit.config_modules import AdapterConfig, AdapterTypes, TrainConfig
|
||||
|
||||
1221
toolkit/kohya_lora.py
Normal file
1221
toolkit/kohya_lora.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -14,16 +14,11 @@ from toolkit.models.lokr import LokrModule
|
||||
from .config_modules import NetworkConfig
|
||||
from .lorm import count_parameters
|
||||
from .network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin, ExtractableModuleMixin
|
||||
from .paths import SD_SCRIPTS_ROOT
|
||||
|
||||
sys.path.append(SD_SCRIPTS_ROOT)
|
||||
|
||||
from networks.lora import LoRANetwork, get_block_index
|
||||
from toolkit.kohya_lora import LoRANetwork
|
||||
from toolkit.models.DoRA import DoRAModule
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
|
||||
@@ -389,15 +384,6 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
if lora_name in modules_dim:
|
||||
dim = modules_dim[lora_name]
|
||||
alpha = modules_alpha[lora_name]
|
||||
elif is_unet and block_dims is not None:
|
||||
# U-Netでblock_dims指定あり
|
||||
block_idx = get_block_index(lora_name)
|
||||
if is_linear or is_conv2d_1x1:
|
||||
dim = block_dims[block_idx]
|
||||
alpha = block_alphas[block_idx]
|
||||
elif conv_block_dims is not None:
|
||||
dim = conv_block_dims[block_idx]
|
||||
alpha = conv_block_alphas[block_idx]
|
||||
else:
|
||||
# 通常、すべて対象とする
|
||||
if is_linear or is_conv2d_1x1:
|
||||
|
||||
@@ -7,9 +7,6 @@ import weakref
|
||||
from typing import Union, TYPE_CHECKING
|
||||
|
||||
from diffusers import Transformer2DModel
|
||||
from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPVisionModelWithProjection
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
sys.path.append(REPOS_ROOT)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@@ -8,7 +8,6 @@ from toolkit.config_modules import GenerateImageConfig, ModelConfig
|
||||
from toolkit.dequantize import patch_dequantization_on_save
|
||||
from toolkit.models.base_model import BaseModel
|
||||
from toolkit.prompt_utils import PromptEmbeds
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
from diffusers import AutoencoderKLWan, WanPipeline, WanTransformer3DModel
|
||||
import os
|
||||
@@ -34,7 +33,6 @@ from diffusers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler
|
||||
from typing import TYPE_CHECKING, List
|
||||
from toolkit.accelerator import unwrap_model
|
||||
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
|
||||
from torchvision.transforms import Resize, ToPILImage
|
||||
from tqdm import tqdm
|
||||
import torch.nn.functional as F
|
||||
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
|
||||
|
||||
@@ -6,7 +6,6 @@ from toolkit.accelerator import unwrap_model
|
||||
from toolkit.basic import flush
|
||||
from toolkit.config_modules import GenerateImageConfig, ModelConfig
|
||||
from toolkit.prompt_utils import PromptEmbeds
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, WanTransformer3DModel
|
||||
import os
|
||||
|
||||
@@ -2,8 +2,6 @@ import os
|
||||
|
||||
TOOLKIT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config')
|
||||
SD_SCRIPTS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories", "sd-scripts")
|
||||
REPOS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories")
|
||||
KEYMAPS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "keymaps")
|
||||
ORIG_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "orig_configs")
|
||||
DIFFUSERS_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "diffusers_configs")
|
||||
|
||||
@@ -8,11 +8,8 @@ from torch.nn import Parameter
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from toolkit.basic import adain
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
from toolkit.saving import load_ip_adapter_model
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
|
||||
sys.path.append(REPOS_ROOT)
|
||||
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
|
||||
from collections import OrderedDict
|
||||
from toolkit.config_modules import AdapterConfig
|
||||
|
||||
@@ -26,13 +26,12 @@ from toolkit.clip_vision_adapter import ClipVisionAdapter
|
||||
from toolkit.custom_adapter import CustomAdapter
|
||||
from toolkit.dequantize import patch_dequantization_on_save
|
||||
from toolkit.ip_adapter import IPAdapter
|
||||
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
|
||||
convert_vae_state_dict, load_vae
|
||||
from toolkit.util.vae import load_vae
|
||||
from toolkit import train_tools
|
||||
from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch
|
||||
from toolkit.metadata import get_meta_for_safetensors
|
||||
from toolkit.models.decorator import Decorator
|
||||
from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
|
||||
from toolkit.paths import KEYMAPS_ROOT
|
||||
from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds
|
||||
from toolkit.reference_adapter import ReferenceAdapter
|
||||
from toolkit.sampler import get_sampler
|
||||
|
||||
@@ -6,11 +6,6 @@ import time
|
||||
from typing import TYPE_CHECKING, Union, List
|
||||
import sys
|
||||
|
||||
from torch.cuda.amp import GradScaler
|
||||
|
||||
from toolkit.paths import SD_SCRIPTS_ROOT
|
||||
|
||||
sys.path.append(SD_SCRIPTS_ROOT)
|
||||
|
||||
from diffusers import (
|
||||
DDPMScheduler,
|
||||
|
||||
20
toolkit/util/vae.py
Normal file
20
toolkit/util/vae.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from diffusers import AutoencoderKL
|
||||
|
||||
|
||||
def load_vae(vae_path, dtype):
    """Load an ``AutoencoderKL`` from ``vae_path`` and cast it to ``dtype``.

    Two layouts are attempted in order:

    1. ``vae_path`` is passed to ``AutoencoderKL.from_pretrained`` as-is.
    2. If that fails, the same location is retried with ``subfolder="vae"``.

    Args:
        vae_path: Location handed to ``AutoencoderKL.from_pretrained``. An
            object exposing a ``vae_path`` attribute is also accepted for the
            fallback attempt, in which case that attribute is used.
        dtype: Passed as ``torch_dtype`` when loading and to ``vae.to``
            afterwards.

    Returns:
        The loaded VAE after ``vae.to(dtype)``.

    Raises:
        ValueError: If both load attempts fail. The error from the second
            attempt is chained as the cause.
    """
    try:
        vae = AutoencoderKL.from_pretrained(
            vae_path,
            torch_dtype=dtype,
        )
    except Exception:
        # Best-effort fallback: retry the same location with a "vae" subfolder.
        # The original code read `vae_path.vae_path`, which raises
        # AttributeError for a plain string path and made this fallback
        # unreachable; getattr keeps working for objects that do carry a
        # `vae_path` attribute while also supporting plain paths.
        fallback_path = getattr(vae_path, "vae_path", vae_path)
        try:
            vae = AutoencoderKL.from_pretrained(
                fallback_path,
                subfolder="vae",
                torch_dtype=dtype,
            )
        except Exception as e:
            # Chain explicitly so the traceback shows the underlying failure.
            raise ValueError(f"Failed to load VAE from {vae_path}: {e}") from e
    vae.to(dtype)
    return vae
|
||||
Reference in New Issue
Block a user