Removed all submodules. Submodule free now, yay.

Jaret Burkett
2025-04-18 10:39:15 -06:00
parent bd2de5b74e
commit bfe29e2151
18 changed files with 1246 additions and 62 deletions
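The recurring change in the Python files below is the removal of the sys.path bootstrap that made the kohya-ss sd-scripts submodule importable; the one piece still needed, LoRANetwork, is now vendored as toolkit/kohya_lora.py. A minimal before/after sketch, using only names that appear in the diffs in this commit:

```python
# Old pattern (removed throughout this commit): put the sd-scripts submodule
# on sys.path, then import from it.
#
#   import sys
#   from toolkit.paths import SD_SCRIPTS_ROOT
#   sys.path.append(SD_SCRIPTS_ROOT)
#   from networks.lora import LoRANetwork  # lives in repositories/sd-scripts
#
# New pattern: the needed code is vendored, so it is an ordinary package import.
from toolkit.kohya_lora import LoRANetwork
```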

3
.gitignore vendored
View File

@@ -179,4 +179,5 @@ cython_debug/
 .vscode/settings.json
 .DS_Store
 ._.DS_Store
 aitk_db.db
+/notes.md

4
.gitmodules vendored
View File

@@ -1,4 +0,0 @@
-[submodule "repositories/sd-scripts"]
-	path = repositories/sd-scripts
-	url = https://github.com/kohya-ss/sd-scripts.git
-	commit = b78c0e2a69e52ce6c79abc6c8c82d1a9cabcf05c

View File

@@ -37,7 +37,6 @@ Linux:
 ```bash
 git clone https://github.com/ostris/ai-toolkit.git
 cd ai-toolkit
-git submodule update --init --recursive
 python3 -m venv venv
 source venv/bin/activate
 # install torch first
@@ -49,7 +48,6 @@ Windows:
 ```bash
 git clone https://github.com/ostris/ai-toolkit.git
 cd ai-toolkit
-git submodule update --init --recursive
 python -m venv venv
 .\venv\Scripts\activate
 pip install --no-cache-dir torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu126

View File

@@ -1,12 +1,5 @@
 from jobs import BaseJob
 from collections import OrderedDict
-from typing import List
-from jobs.process import GenerateProcess
-from toolkit.paths import REPOS_ROOT
-import sys
-sys.path.append(REPOS_ROOT)
 process_dict = {
     'to_folder': 'GenerateProcess',

View File

@@ -7,12 +7,7 @@ from collections import OrderedDict
 from typing import List
 from jobs.process import BaseExtractProcess, TrainFineTuneProcess
 from datetime import datetime
-import yaml
-from toolkit.paths import REPOS_ROOT
-import sys
-sys.path.append(REPOS_ROOT)
 process_dict = {
     'vae': 'TrainVAEProcess',

View File

@@ -11,13 +11,9 @@ import random
 from transformers import CLIPImageProcessor
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-from toolkit.paths import SD_SCRIPTS_ROOT
 import torchvision.transforms.functional
 from toolkit.image_utils import save_tensors, show_img, show_tensors
-sys.path.append(SD_SCRIPTS_ROOT)
-from library.model_util import load_vae
 from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
 from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets, \
     trigger_dataloader_setup_epoch

View File

@@ -19,16 +19,12 @@ from toolkit.models.te_adapter import TEAdapter
 from toolkit.models.te_aug_adapter import TEAugAdapter
 from toolkit.models.vd_adapter import VisionDirectAdapter
 from toolkit.models.redux import ReduxImageEncoder
-from toolkit.paths import REPOS_ROOT
 from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
 from toolkit.saving import load_ip_adapter_model, load_custom_adapter_model
 from toolkit.train_tools import get_torch_dtype
 from toolkit.models.pixtral_vision import PixtralVisionEncoderCompatible, PixtralVisionImagePreprocessorCompatible
 import random
 from toolkit.util.mask import generate_random_mask
-sys.path.append(REPOS_ROOT)
 from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
 from collections import OrderedDict
 from toolkit.config_modules import AdapterConfig, AdapterTypes, TrainConfig

1221
toolkit/kohya_lora.py Normal file

File diff suppressed because it is too large

View File

@@ -14,16 +14,11 @@ from toolkit.models.lokr import LokrModule
 from .config_modules import NetworkConfig
 from .lorm import count_parameters
 from .network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin, ExtractableModuleMixin
-from .paths import SD_SCRIPTS_ROOT
-sys.path.append(SD_SCRIPTS_ROOT)
-from networks.lora import LoRANetwork, get_block_index
+from toolkit.kohya_lora import LoRANetwork
 from toolkit.models.DoRA import DoRAModule
 from typing import TYPE_CHECKING
-from torch.utils.checkpoint import checkpoint
 if TYPE_CHECKING:
     from toolkit.stable_diffusion_model import StableDiffusion
@@ -389,15 +384,6 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         if lora_name in modules_dim:
             dim = modules_dim[lora_name]
             alpha = modules_alpha[lora_name]
-        elif is_unet and block_dims is not None:
-            # block_dims specified for the U-Net
-            block_idx = get_block_index(lora_name)
-            if is_linear or is_conv2d_1x1:
-                dim = block_dims[block_idx]
-                alpha = block_alphas[block_idx]
-            elif conv_block_dims is not None:
-                dim = conv_block_dims[block_idx]
-                alpha = conv_block_alphas[block_idx]
         else:
             # normally, all modules are targeted
             if is_linear or is_conv2d_1x1:
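The dropped elif branch relied on get_block_index from sd-scripts' networks.lora, which the vendored toolkit.kohya_lora import above no longer provides. The per-module override kept in the first branch still works, so block-wise ranks can be approximated through it. A hedged sketch (the module names below are hypothetical examples, not taken from this repository):

```python
# Per-layer ranks via the surviving modules_dim / modules_alpha lookup:
# any lora_name found in modules_dim takes that dim/alpha, everything else
# falls through to the uniform default branch shown above.
modules_dim = {
    "lora_unet_down_blocks_0_attentions_0_proj_in": 8,   # hypothetical name
    "lora_unet_mid_block_attentions_0_proj_in": 16,      # hypothetical name
}
modules_alpha = {name: dim // 2 for name, dim in modules_dim.items()}
```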

View File

@@ -7,9 +7,6 @@ import weakref
 from typing import Union, TYPE_CHECKING
 from diffusers import Transformer2DModel
-from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPVisionModelWithProjection
-from toolkit.paths import REPOS_ROOT
-sys.path.append(REPOS_ROOT)
 if TYPE_CHECKING:

View File

@@ -8,7 +8,6 @@ from toolkit.config_modules import GenerateImageConfig, ModelConfig
 from toolkit.dequantize import patch_dequantization_on_save
 from toolkit.models.base_model import BaseModel
 from toolkit.prompt_utils import PromptEmbeds
-from toolkit.paths import REPOS_ROOT
 from transformers import AutoTokenizer, UMT5EncoderModel
 from diffusers import AutoencoderKLWan, WanPipeline, WanTransformer3DModel
 import os
@@ -34,7 +33,6 @@ from diffusers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler
 from typing import TYPE_CHECKING, List
 from toolkit.accelerator import unwrap_model
 from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
-from torchvision.transforms import Resize, ToPILImage
 from tqdm import tqdm
 import torch.nn.functional as F
 from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput

View File

@@ -6,7 +6,6 @@ from toolkit.accelerator import unwrap_model
 from toolkit.basic import flush
 from toolkit.config_modules import GenerateImageConfig, ModelConfig
 from toolkit.prompt_utils import PromptEmbeds
-from toolkit.paths import REPOS_ROOT
 from transformers import AutoTokenizer, UMT5EncoderModel
 from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, WanTransformer3DModel
 import os

View File

@@ -2,8 +2,6 @@ import os
 TOOLKIT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config')
-SD_SCRIPTS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories", "sd-scripts")
-REPOS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories")
 KEYMAPS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "keymaps")
 ORIG_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "orig_configs")
 DIFFUSERS_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "diffusers_configs")

View File

@@ -8,11 +8,8 @@ from torch.nn import Parameter
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 from toolkit.basic import adain
-from toolkit.paths import REPOS_ROOT
 from toolkit.saving import load_ip_adapter_model
 from toolkit.train_tools import get_torch_dtype
-sys.path.append(REPOS_ROOT)
 from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
 from collections import OrderedDict
 from toolkit.config_modules import AdapterConfig

View File

@@ -26,13 +26,12 @@ from toolkit.clip_vision_adapter import ClipVisionAdapter
 from toolkit.custom_adapter import CustomAdapter
 from toolkit.dequantize import patch_dequantization_on_save
 from toolkit.ip_adapter import IPAdapter
-from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
-    convert_vae_state_dict, load_vae
+from toolkit.util.vae import load_vae
 from toolkit import train_tools
 from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch
 from toolkit.metadata import get_meta_for_safetensors
 from toolkit.models.decorator import Decorator
-from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
+from toolkit.paths import KEYMAPS_ROOT
 from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds
 from toolkit.reference_adapter import ReferenceAdapter
 from toolkit.sampler import get_sampler

View File

@@ -6,11 +6,6 @@ import time
 from typing import TYPE_CHECKING, Union, List
 import sys
-from torch.cuda.amp import GradScaler
-from toolkit.paths import SD_SCRIPTS_ROOT
-sys.path.append(SD_SCRIPTS_ROOT)
 from diffusers import (
     DDPMScheduler,

20
toolkit/util/vae.py Normal file
View File

@@ -0,0 +1,20 @@
+from diffusers import AutoencoderKL
+
+
+def load_vae(vae_path, dtype):
+    try:
+        vae = AutoencoderKL.from_pretrained(
+            vae_path,
+            torch_dtype=dtype,
+        )
+    except Exception as e:
+        try:
+            vae = AutoencoderKL.from_pretrained(
+                vae_path,
+                subfolder="vae",
+                torch_dtype=dtype,
+            )
+        except Exception as e:
+            raise ValueError(f"Failed to load VAE from {vae_path}: {e}")
+    vae.to(dtype)
+    return vae
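For reference, a usage sketch of the new helper; the repository id below is a hypothetical example and is not part of this commit:

```python
import torch
from toolkit.util.vae import load_vae

# Loads the VAE directly from the given path/repo id; if that fails, load_vae
# retries the same location with subfolder="vae" before raising ValueError.
vae = load_vae("stabilityai/sd-vae-ft-mse", torch.float16)  # hypothetical repo id
vae.eval()
```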