Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 08:29:45 +00:00

Commit: Removed all submodules. Submodule free now, yay.
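For context, dropping a vendored submodule such as repositories/sd-scripts generally takes the sequence sketched below. The exact commands run for this commit are not recorded here, so treat this as an illustrative recipe, not the author's method.

```bash
# Illustrative recipe only; the commands actually used for this commit are not recorded.
git submodule deinit -f repositories/sd-scripts   # unregister and empty the working-tree copy
git rm -f repositories/sd-scripts                 # drop the gitlink and its .gitmodules entry
rm -rf .git/modules/repositories/sd-scripts       # clear git's cached clone of the submodule
git commit -m "Removed all submodules. Submodule free now, yay."
```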
.gitignore (vendored) | 3
@@ -179,4 +179,5 @@ cython_debug/
 .vscode/settings.json
 .DS_Store
 ._.DS_Store
 aitk_db.db
+/notes.md
.gitmodules (vendored) | 4
@@ -1,4 +0,0 @@
-[submodule "repositories/sd-scripts"]
-	path = repositories/sd-scripts
-	url = https://github.com/kohya-ss/sd-scripts.git
-	commit = b78c0e2a69e52ce6c79abc6c8c82d1a9cabcf05c
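A quick way to verify the result (standard git commands, not part of the diff): with no submodules registered, `git submodule status` prints nothing and `.gitmodules` no longer exists.

```bash
git submodule status                           # empty output once no submodules are registered
test ! -f .gitmodules && echo "submodule free"
```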
@@ -37,7 +37,6 @@ Linux:
 ```bash
 git clone https://github.com/ostris/ai-toolkit.git
 cd ai-toolkit
-git submodule update --init --recursive
 python3 -m venv venv
 source venv/bin/activate
 # install torch first
@@ -49,7 +48,6 @@ Windows:
 ```bash
 git clone https://github.com/ostris/ai-toolkit.git
 cd ai-toolkit
-git submodule update --init --recursive
 python -m venv venv
 .\venv\Scripts\activate
 pip install --no-cache-dir torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu126
@@ -1,12 +1,5 @@
 from jobs import BaseJob
 from collections import OrderedDict
 from typing import List
 from jobs.process import GenerateProcess
-from toolkit.paths import REPOS_ROOT
-
-import sys
-
-sys.path.append(REPOS_ROOT)
-
-
 process_dict = {
     'to_folder': 'GenerateProcess',
@@ -7,12 +7,7 @@ from collections import OrderedDict
 from typing import List
 from jobs.process import BaseExtractProcess, TrainFineTuneProcess
 from datetime import datetime
 import yaml
-from toolkit.paths import REPOS_ROOT
-
-import sys
-
-sys.path.append(REPOS_ROOT)
 
 process_dict = {
     'vae': 'TrainVAEProcess',
Submodule repositories/sd-scripts deleted from b78c0e2a69
@@ -11,13 +11,9 @@ import random
 from transformers import CLIPImageProcessor
 
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-from toolkit.paths import SD_SCRIPTS_ROOT
 import torchvision.transforms.functional
 from toolkit.image_utils import save_tensors, show_img, show_tensors
 
-sys.path.append(SD_SCRIPTS_ROOT)
-
-from library.model_util import load_vae
 from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
 from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets, \
     trigger_dataloader_setup_epoch
@@ -19,16 +19,12 @@ from toolkit.models.te_adapter import TEAdapter
 from toolkit.models.te_aug_adapter import TEAugAdapter
 from toolkit.models.vd_adapter import VisionDirectAdapter
 from toolkit.models.redux import ReduxImageEncoder
-from toolkit.paths import REPOS_ROOT
 from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
 from toolkit.saving import load_ip_adapter_model, load_custom_adapter_model
 from toolkit.train_tools import get_torch_dtype
 from toolkit.models.pixtral_vision import PixtralVisionEncoderCompatible, PixtralVisionImagePreprocessorCompatible
 import random
 
 from toolkit.util.mask import generate_random_mask
-
-sys.path.append(REPOS_ROOT)
-
 from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
 from collections import OrderedDict
 from toolkit.config_modules import AdapterConfig, AdapterTypes, TrainConfig
toolkit/kohya_lora.py (new file) | 1221
File diff suppressed because it is too large
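For downstream code the vendoring amounts to an import swap, as the next hunk shows; sketched in isolation:

```python
# Before: resolved from the sd-scripts submodule, reachable only after sys.path hacking
# from networks.lora import LoRANetwork
# After: the vendored copy ships inside the toolkit package itself
from toolkit.kohya_lora import LoRANetwork
```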
@@ -14,16 +14,11 @@ from toolkit.models.lokr import LokrModule
 from .config_modules import NetworkConfig
 from .lorm import count_parameters
 from .network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin, ExtractableModuleMixin
-from .paths import SD_SCRIPTS_ROOT
-
-sys.path.append(SD_SCRIPTS_ROOT)
-
-from networks.lora import LoRANetwork, get_block_index
+from toolkit.kohya_lora import LoRANetwork
 from toolkit.models.DoRA import DoRAModule
 from typing import TYPE_CHECKING
 
 from torch.utils.checkpoint import checkpoint
 
 if TYPE_CHECKING:
     from toolkit.stable_diffusion_model import StableDiffusion
@@ -389,15 +384,6 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
             if lora_name in modules_dim:
                 dim = modules_dim[lora_name]
                 alpha = modules_alpha[lora_name]
-            elif is_unet and block_dims is not None:
-                # block_dims specified for the U-Net
-                block_idx = get_block_index(lora_name)
-                if is_linear or is_conv2d_1x1:
-                    dim = block_dims[block_idx]
-                    alpha = block_alphas[block_idx]
-                elif conv_block_dims is not None:
-                    dim = conv_block_dims[block_idx]
-                    alpha = conv_block_alphas[block_idx]
             else:
                 # normal case: target all modules
                 if is_linear or is_conv2d_1x1:
@@ -7,9 +7,6 @@ import weakref
 from typing import Union, TYPE_CHECKING
 
 from diffusers import Transformer2DModel
 from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPVisionModelWithProjection
-from toolkit.paths import REPOS_ROOT
-sys.path.append(REPOS_ROOT)
-
 
 if TYPE_CHECKING:
@@ -8,7 +8,6 @@ from toolkit.config_modules import GenerateImageConfig, ModelConfig
 from toolkit.dequantize import patch_dequantization_on_save
 from toolkit.models.base_model import BaseModel
 from toolkit.prompt_utils import PromptEmbeds
-from toolkit.paths import REPOS_ROOT
 from transformers import AutoTokenizer, UMT5EncoderModel
 from diffusers import AutoencoderKLWan, WanPipeline, WanTransformer3DModel
 import os
@@ -34,7 +33,6 @@ from diffusers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler
 from typing import TYPE_CHECKING, List
 from toolkit.accelerator import unwrap_model
 from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
 from torchvision.transforms import Resize, ToPILImage
 from tqdm import tqdm
 import torch.nn.functional as F
 from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
@@ -6,7 +6,6 @@ from toolkit.accelerator import unwrap_model
 from toolkit.basic import flush
 from toolkit.config_modules import GenerateImageConfig, ModelConfig
 from toolkit.prompt_utils import PromptEmbeds
-from toolkit.paths import REPOS_ROOT
 from transformers import AutoTokenizer, UMT5EncoderModel
 from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, WanTransformer3DModel
 import os
@@ -2,8 +2,6 @@ import os
 
 TOOLKIT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config')
-SD_SCRIPTS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories", "sd-scripts")
-REPOS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories")
 KEYMAPS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "keymaps")
 ORIG_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "orig_configs")
 DIFFUSERS_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "diffusers_configs")
@@ -8,11 +8,8 @@ from torch.nn import Parameter
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 
 from toolkit.basic import adain
-from toolkit.paths import REPOS_ROOT
 from toolkit.saving import load_ip_adapter_model
 from toolkit.train_tools import get_torch_dtype
-
-sys.path.append(REPOS_ROOT)
 from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
 from collections import OrderedDict
 from toolkit.config_modules import AdapterConfig
@@ -26,13 +26,12 @@ from toolkit.clip_vision_adapter import ClipVisionAdapter
 from toolkit.custom_adapter import CustomAdapter
 from toolkit.dequantize import patch_dequantization_on_save
 from toolkit.ip_adapter import IPAdapter
-from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
-    convert_vae_state_dict, load_vae
+from toolkit.util.vae import load_vae
 from toolkit import train_tools
 from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch
 from toolkit.metadata import get_meta_for_safetensors
 from toolkit.models.decorator import Decorator
-from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
+from toolkit.paths import KEYMAPS_ROOT
 from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds
 from toolkit.reference_adapter import ReferenceAdapter
 from toolkit.sampler import get_sampler
@@ -6,11 +6,6 @@ import time
 from typing import TYPE_CHECKING, Union, List
 import sys
 
 from torch.cuda.amp import GradScaler
-from toolkit.paths import SD_SCRIPTS_ROOT
-
-sys.path.append(SD_SCRIPTS_ROOT)
-
-
 from diffusers import (
     DDPMScheduler,
toolkit/util/vae.py (new file) | 20
@@ -0,0 +1,20 @@
+from diffusers import AutoencoderKL
+
+
+def load_vae(vae_path, dtype):
+    try:
+        vae = AutoencoderKL.from_pretrained(
+            vae_path,
+            torch_dtype=dtype,
+        )
+    except Exception:
+        try:
+            vae = AutoencoderKL.from_pretrained(
+                vae_path,
+                subfolder="vae",
+                torch_dtype=dtype,
+            )
+        except Exception as e:
+            raise ValueError(f"Failed to load VAE from {vae_path}: {e}")
+    vae.to(dtype)
+    return vae
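A minimal usage sketch for the new helper; the repo id below is just a well-known standalone AutoencoderKL checkpoint, not something this commit references.

```python
import torch

from toolkit.util.vae import load_vae

# A standalone AutoencoderKL repo; a full pipeline repo with a "vae"
# subfolder would exercise the fallback branch instead.
vae = load_vae("stabilityai/sd-vae-ft-mse", torch.float16)
print(vae.config.latent_channels)  # 4 for SD-style VAEs
```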