mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 08:29:45 +00:00
Removed all submodules. Submodule free now, yay.
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -179,4 +179,5 @@ cython_debug/
|
|||||||
.vscode/settings.json
|
.vscode/settings.json
|
||||||
.DS_Store
|
.DS_Store
|
||||||
._.DS_Store
|
._.DS_Store
|
||||||
aitk_db.db
|
aitk_db.db
|
||||||
|
/notes.md
|
||||||
4
.gitmodules
vendored
4
.gitmodules
vendored
@@ -1,4 +0,0 @@
|
|||||||
[submodule "repositories/sd-scripts"]
|
|
||||||
path = repositories/sd-scripts
|
|
||||||
url = https://github.com/kohya-ss/sd-scripts.git
|
|
||||||
commit = b78c0e2a69e52ce6c79abc6c8c82d1a9cabcf05c
|
|
||||||
|
|||||||
@@ -37,7 +37,6 @@ Linux:
|
|||||||
```bash
|
```bash
|
||||||
git clone https://github.com/ostris/ai-toolkit.git
|
git clone https://github.com/ostris/ai-toolkit.git
|
||||||
cd ai-toolkit
|
cd ai-toolkit
|
||||||
git submodule update --init --recursive
|
|
||||||
python3 -m venv venv
|
python3 -m venv venv
|
||||||
source venv/bin/activate
|
source venv/bin/activate
|
||||||
# install torch first
|
# install torch first
|
||||||
@@ -49,7 +48,6 @@ Windows:
|
|||||||
```bash
|
```bash
|
||||||
git clone https://github.com/ostris/ai-toolkit.git
|
git clone https://github.com/ostris/ai-toolkit.git
|
||||||
cd ai-toolkit
|
cd ai-toolkit
|
||||||
git submodule update --init --recursive
|
|
||||||
python -m venv venv
|
python -m venv venv
|
||||||
.\venv\Scripts\activate
|
.\venv\Scripts\activate
|
||||||
pip install --no-cache-dir torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu126
|
pip install --no-cache-dir torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu126
|
||||||
|
|||||||
@@ -1,12 +1,5 @@
|
|||||||
from jobs import BaseJob
|
from jobs import BaseJob
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import List
|
|
||||||
from jobs.process import GenerateProcess
|
|
||||||
from toolkit.paths import REPOS_ROOT
|
|
||||||
|
|
||||||
import sys
|
|
||||||
|
|
||||||
sys.path.append(REPOS_ROOT)
|
|
||||||
|
|
||||||
process_dict = {
|
process_dict = {
|
||||||
'to_folder': 'GenerateProcess',
|
'to_folder': 'GenerateProcess',
|
||||||
|
|||||||
@@ -7,12 +7,7 @@ from collections import OrderedDict
|
|||||||
from typing import List
|
from typing import List
|
||||||
from jobs.process import BaseExtractProcess, TrainFineTuneProcess
|
from jobs.process import BaseExtractProcess, TrainFineTuneProcess
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import yaml
|
|
||||||
from toolkit.paths import REPOS_ROOT
|
|
||||||
|
|
||||||
import sys
|
|
||||||
|
|
||||||
sys.path.append(REPOS_ROOT)
|
|
||||||
|
|
||||||
process_dict = {
|
process_dict = {
|
||||||
'vae': 'TrainVAEProcess',
|
'vae': 'TrainVAEProcess',
|
||||||
|
|||||||
Submodule repositories/sd-scripts deleted from b78c0e2a69
@@ -11,13 +11,9 @@ import random
|
|||||||
from transformers import CLIPImageProcessor
|
from transformers import CLIPImageProcessor
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
from toolkit.paths import SD_SCRIPTS_ROOT
|
|
||||||
import torchvision.transforms.functional
|
import torchvision.transforms.functional
|
||||||
from toolkit.image_utils import save_tensors, show_img, show_tensors
|
from toolkit.image_utils import save_tensors, show_img, show_tensors
|
||||||
|
|
||||||
sys.path.append(SD_SCRIPTS_ROOT)
|
|
||||||
|
|
||||||
from library.model_util import load_vae
|
|
||||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||||
from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets, \
|
from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets, \
|
||||||
trigger_dataloader_setup_epoch
|
trigger_dataloader_setup_epoch
|
||||||
|
|||||||
@@ -19,16 +19,12 @@ from toolkit.models.te_adapter import TEAdapter
|
|||||||
from toolkit.models.te_aug_adapter import TEAugAdapter
|
from toolkit.models.te_aug_adapter import TEAugAdapter
|
||||||
from toolkit.models.vd_adapter import VisionDirectAdapter
|
from toolkit.models.vd_adapter import VisionDirectAdapter
|
||||||
from toolkit.models.redux import ReduxImageEncoder
|
from toolkit.models.redux import ReduxImageEncoder
|
||||||
from toolkit.paths import REPOS_ROOT
|
|
||||||
from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
|
from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
|
||||||
from toolkit.saving import load_ip_adapter_model, load_custom_adapter_model
|
from toolkit.saving import load_ip_adapter_model, load_custom_adapter_model
|
||||||
from toolkit.train_tools import get_torch_dtype
|
from toolkit.train_tools import get_torch_dtype
|
||||||
from toolkit.models.pixtral_vision import PixtralVisionEncoderCompatible, PixtralVisionImagePreprocessorCompatible
|
from toolkit.models.pixtral_vision import PixtralVisionEncoderCompatible, PixtralVisionImagePreprocessorCompatible
|
||||||
import random
|
import random
|
||||||
|
|
||||||
from toolkit.util.mask import generate_random_mask
|
from toolkit.util.mask import generate_random_mask
|
||||||
|
|
||||||
sys.path.append(REPOS_ROOT)
|
|
||||||
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
|
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from toolkit.config_modules import AdapterConfig, AdapterTypes, TrainConfig
|
from toolkit.config_modules import AdapterConfig, AdapterTypes, TrainConfig
|
||||||
|
|||||||
1221
toolkit/kohya_lora.py
Normal file
1221
toolkit/kohya_lora.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -14,16 +14,11 @@ from toolkit.models.lokr import LokrModule
|
|||||||
from .config_modules import NetworkConfig
|
from .config_modules import NetworkConfig
|
||||||
from .lorm import count_parameters
|
from .lorm import count_parameters
|
||||||
from .network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin, ExtractableModuleMixin
|
from .network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin, ExtractableModuleMixin
|
||||||
from .paths import SD_SCRIPTS_ROOT
|
|
||||||
|
|
||||||
sys.path.append(SD_SCRIPTS_ROOT)
|
from toolkit.kohya_lora import LoRANetwork
|
||||||
|
|
||||||
from networks.lora import LoRANetwork, get_block_index
|
|
||||||
from toolkit.models.DoRA import DoRAModule
|
from toolkit.models.DoRA import DoRAModule
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from torch.utils.checkpoint import checkpoint
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from toolkit.stable_diffusion_model import StableDiffusion
|
from toolkit.stable_diffusion_model import StableDiffusion
|
||||||
|
|
||||||
@@ -389,15 +384,6 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
|||||||
if lora_name in modules_dim:
|
if lora_name in modules_dim:
|
||||||
dim = modules_dim[lora_name]
|
dim = modules_dim[lora_name]
|
||||||
alpha = modules_alpha[lora_name]
|
alpha = modules_alpha[lora_name]
|
||||||
elif is_unet and block_dims is not None:
|
|
||||||
# U-Netでblock_dims指定あり
|
|
||||||
block_idx = get_block_index(lora_name)
|
|
||||||
if is_linear or is_conv2d_1x1:
|
|
||||||
dim = block_dims[block_idx]
|
|
||||||
alpha = block_alphas[block_idx]
|
|
||||||
elif conv_block_dims is not None:
|
|
||||||
dim = conv_block_dims[block_idx]
|
|
||||||
alpha = conv_block_alphas[block_idx]
|
|
||||||
else:
|
else:
|
||||||
# 通常、すべて対象とする
|
# 通常、すべて対象とする
|
||||||
if is_linear or is_conv2d_1x1:
|
if is_linear or is_conv2d_1x1:
|
||||||
|
|||||||
@@ -7,9 +7,6 @@ import weakref
|
|||||||
from typing import Union, TYPE_CHECKING
|
from typing import Union, TYPE_CHECKING
|
||||||
|
|
||||||
from diffusers import Transformer2DModel
|
from diffusers import Transformer2DModel
|
||||||
from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPVisionModelWithProjection
|
|
||||||
from toolkit.paths import REPOS_ROOT
|
|
||||||
sys.path.append(REPOS_ROOT)
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from toolkit.config_modules import GenerateImageConfig, ModelConfig
|
|||||||
from toolkit.dequantize import patch_dequantization_on_save
|
from toolkit.dequantize import patch_dequantization_on_save
|
||||||
from toolkit.models.base_model import BaseModel
|
from toolkit.models.base_model import BaseModel
|
||||||
from toolkit.prompt_utils import PromptEmbeds
|
from toolkit.prompt_utils import PromptEmbeds
|
||||||
from toolkit.paths import REPOS_ROOT
|
|
||||||
from transformers import AutoTokenizer, UMT5EncoderModel
|
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||||
from diffusers import AutoencoderKLWan, WanPipeline, WanTransformer3DModel
|
from diffusers import AutoencoderKLWan, WanPipeline, WanTransformer3DModel
|
||||||
import os
|
import os
|
||||||
@@ -34,7 +33,6 @@ from diffusers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler
|
|||||||
from typing import TYPE_CHECKING, List
|
from typing import TYPE_CHECKING, List
|
||||||
from toolkit.accelerator import unwrap_model
|
from toolkit.accelerator import unwrap_model
|
||||||
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
|
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
|
||||||
from torchvision.transforms import Resize, ToPILImage
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
|
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from toolkit.accelerator import unwrap_model
|
|||||||
from toolkit.basic import flush
|
from toolkit.basic import flush
|
||||||
from toolkit.config_modules import GenerateImageConfig, ModelConfig
|
from toolkit.config_modules import GenerateImageConfig, ModelConfig
|
||||||
from toolkit.prompt_utils import PromptEmbeds
|
from toolkit.prompt_utils import PromptEmbeds
|
||||||
from toolkit.paths import REPOS_ROOT
|
|
||||||
from transformers import AutoTokenizer, UMT5EncoderModel
|
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||||
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, WanTransformer3DModel
|
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, WanTransformer3DModel
|
||||||
import os
|
import os
|
||||||
|
|||||||
@@ -2,8 +2,6 @@ import os
|
|||||||
|
|
||||||
TOOLKIT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
TOOLKIT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config')
|
CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config')
|
||||||
SD_SCRIPTS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories", "sd-scripts")
|
|
||||||
REPOS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories")
|
|
||||||
KEYMAPS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "keymaps")
|
KEYMAPS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "keymaps")
|
||||||
ORIG_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "orig_configs")
|
ORIG_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "orig_configs")
|
||||||
DIFFUSERS_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "diffusers_configs")
|
DIFFUSERS_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "diffusers_configs")
|
||||||
|
|||||||
@@ -8,11 +8,8 @@ from torch.nn import Parameter
|
|||||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||||
|
|
||||||
from toolkit.basic import adain
|
from toolkit.basic import adain
|
||||||
from toolkit.paths import REPOS_ROOT
|
|
||||||
from toolkit.saving import load_ip_adapter_model
|
from toolkit.saving import load_ip_adapter_model
|
||||||
from toolkit.train_tools import get_torch_dtype
|
from toolkit.train_tools import get_torch_dtype
|
||||||
|
|
||||||
sys.path.append(REPOS_ROOT)
|
|
||||||
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
|
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from toolkit.config_modules import AdapterConfig
|
from toolkit.config_modules import AdapterConfig
|
||||||
|
|||||||
@@ -26,13 +26,12 @@ from toolkit.clip_vision_adapter import ClipVisionAdapter
|
|||||||
from toolkit.custom_adapter import CustomAdapter
|
from toolkit.custom_adapter import CustomAdapter
|
||||||
from toolkit.dequantize import patch_dequantization_on_save
|
from toolkit.dequantize import patch_dequantization_on_save
|
||||||
from toolkit.ip_adapter import IPAdapter
|
from toolkit.ip_adapter import IPAdapter
|
||||||
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
|
from toolkit.util.vae import load_vae
|
||||||
convert_vae_state_dict, load_vae
|
|
||||||
from toolkit import train_tools
|
from toolkit import train_tools
|
||||||
from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch
|
from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch
|
||||||
from toolkit.metadata import get_meta_for_safetensors
|
from toolkit.metadata import get_meta_for_safetensors
|
||||||
from toolkit.models.decorator import Decorator
|
from toolkit.models.decorator import Decorator
|
||||||
from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
|
from toolkit.paths import KEYMAPS_ROOT
|
||||||
from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds
|
from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds
|
||||||
from toolkit.reference_adapter import ReferenceAdapter
|
from toolkit.reference_adapter import ReferenceAdapter
|
||||||
from toolkit.sampler import get_sampler
|
from toolkit.sampler import get_sampler
|
||||||
|
|||||||
@@ -6,11 +6,6 @@ import time
|
|||||||
from typing import TYPE_CHECKING, Union, List
|
from typing import TYPE_CHECKING, Union, List
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from torch.cuda.amp import GradScaler
|
|
||||||
|
|
||||||
from toolkit.paths import SD_SCRIPTS_ROOT
|
|
||||||
|
|
||||||
sys.path.append(SD_SCRIPTS_ROOT)
|
|
||||||
|
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
DDPMScheduler,
|
DDPMScheduler,
|
||||||
|
|||||||
20
toolkit/util/vae.py
Normal file
20
toolkit/util/vae.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from diffusers import AutoencoderKL
|
||||||
|
|
||||||
|
|
||||||
|
def load_vae(vae_path, dtype):
    """Load an ``AutoencoderKL`` from ``vae_path`` and cast it to ``dtype``.

    Two layouts are attempted, in order:

    1. ``vae_path`` itself is handed to ``AutoencoderKL.from_pretrained``.
    2. If that raises, the load is retried with ``subfolder="vae"``.

    Args:
        vae_path: Location passed to ``AutoencoderKL.from_pretrained``.
            An object exposing a ``vae_path`` attribute is also accepted
            for the fallback attempt (kept for backward compatibility
            with the previous implementation, which read
            ``vae_path.vae_path`` there).
        dtype: Value forwarded as ``torch_dtype`` and then passed to
            ``vae.to``.

    Returns:
        The loaded VAE after ``vae.to(dtype)`` has been called on it.

    Raises:
        ValueError: If both load attempts fail. The error raised by the
            fallback attempt is attached as ``__cause__``.
    """
    try:
        vae = AutoencoderKL.from_pretrained(
            vae_path,
            torch_dtype=dtype,
        )
    except Exception:
        # The fallback previously dereferenced `vae_path.vae_path`
        # unconditionally. For a plain string that raised AttributeError,
        # so the `subfolder="vae"` retry could never succeed. Resolve the
        # location defensively so both strings and objects carrying a
        # `vae_path` attribute work.
        fallback_path = getattr(vae_path, "vae_path", vae_path)
        try:
            vae = AutoencoderKL.from_pretrained(
                fallback_path,
                subfolder="vae",
                torch_dtype=dtype,
            )
        except Exception as e:
            # Same exception type and message as before; chained so the
            # underlying traceback is preserved for debugging.
            raise ValueError(f"Failed to load VAE from {vae_path}: {e}") from e
    # Kept from the original: explicit cast in addition to `torch_dtype`.
    vae.to(dtype)
    return vae
|
||||||
Reference in New Issue
Block a user