mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
Built base interfaces for a DTO to handle batch information transport for the dataloader
This commit is contained in:
@@ -86,6 +86,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
if embedding_raw is not None:
|
if embedding_raw is not None:
|
||||||
self.embed_config = EmbeddingConfig(**embedding_raw)
|
self.embed_config = EmbeddingConfig(**embedding_raw)
|
||||||
|
|
||||||
|
|
||||||
|
# check to see if we have a latest save
|
||||||
|
latest_save_path = self.get_latest_save_path()
|
||||||
|
|
||||||
|
if latest_save_path is not None:
|
||||||
|
print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
|
||||||
|
self.model_config.name_or_path = latest_save_path
|
||||||
|
|
||||||
self.sd = StableDiffusion(
|
self.sd = StableDiffusion(
|
||||||
device=self.device,
|
device=self.device,
|
||||||
model_config=self.model_config,
|
model_config=self.model_config,
|
||||||
@@ -113,7 +121,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
# zero-pad 9 digits
|
# zero-pad 9 digits
|
||||||
step_num = f"_{str(step).zfill(9)}"
|
step_num = f"_{str(step).zfill(9)}"
|
||||||
|
|
||||||
filename = f"[time]_{step_num}_[count].png"
|
filename = f"[time]_{step_num}_[count].{self.sample_config.ext}"
|
||||||
|
|
||||||
output_path = os.path.join(sample_folder, filename)
|
output_path = os.path.join(sample_folder, filename)
|
||||||
|
|
||||||
@@ -142,6 +150,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
num_inference_steps=sample_config.sample_steps,
|
num_inference_steps=sample_config.sample_steps,
|
||||||
network_multiplier=sample_config.network_multiplier,
|
network_multiplier=sample_config.network_multiplier,
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
|
output_ext=sample_config.ext,
|
||||||
))
|
))
|
||||||
|
|
||||||
# send to be generated
|
# send to be generated
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import time
|
|||||||
from typing import List, Optional, Literal
|
from typing import List, Optional, Literal
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
ImgExt = Literal['jpg', 'png', 'webp']
|
||||||
|
|
||||||
class SaveConfig:
|
class SaveConfig:
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
@@ -31,6 +32,7 @@ class SampleConfig:
|
|||||||
self.sample_steps = kwargs.get('sample_steps', 20)
|
self.sample_steps = kwargs.get('sample_steps', 20)
|
||||||
self.network_multiplier = kwargs.get('network_multiplier', 1)
|
self.network_multiplier = kwargs.get('network_multiplier', 1)
|
||||||
self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
|
self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
|
||||||
|
self.ext: ImgExt = kwargs.get('format', 'jpg')
|
||||||
|
|
||||||
|
|
||||||
class NetworkConfig:
|
class NetworkConfig:
|
||||||
@@ -158,7 +160,10 @@ class SliderConfig:
|
|||||||
|
|
||||||
|
|
||||||
class DatasetConfig:
|
class DatasetConfig:
|
||||||
caption_type: Literal["txt", "caption"] = 'txt'
|
"""
|
||||||
|
Dataset config for sd-datasets
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.type = kwargs.get('type', 'image') # sd, slider, reference
|
self.type = kwargs.get('type', 'image') # sd, slider, reference
|
||||||
@@ -172,6 +177,10 @@ class DatasetConfig:
|
|||||||
self.buckets: bool = kwargs.get('buckets', False)
|
self.buckets: bool = kwargs.get('buckets', False)
|
||||||
self.bucket_tolerance: int = kwargs.get('bucket_tolerance', 64)
|
self.bucket_tolerance: int = kwargs.get('bucket_tolerance', 64)
|
||||||
self.is_reg: bool = kwargs.get('is_reg', False)
|
self.is_reg: bool = kwargs.get('is_reg', False)
|
||||||
|
self.network_weight: float = float(kwargs.get('network_weight', 1.0))
|
||||||
|
self.token_dropout_rate: float = float(kwargs.get('token_dropout_rate', 0.0))
|
||||||
|
self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False)
|
||||||
|
self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0))
|
||||||
|
|
||||||
|
|
||||||
class GenerateImageConfig:
|
class GenerateImageConfig:
|
||||||
@@ -191,7 +200,7 @@ class GenerateImageConfig:
|
|||||||
# the tag [time] will be replaced with milliseconds since epoch
|
# the tag [time] will be replaced with milliseconds since epoch
|
||||||
output_path: str = None, # full image path
|
output_path: str = None, # full image path
|
||||||
output_folder: str = None, # folder to save image in if output_path is not specified
|
output_folder: str = None, # folder to save image in if output_path is not specified
|
||||||
output_ext: str = 'png', # extension to save image as if output_path is not specified
|
output_ext: str = ImgExt, # extension to save image as if output_path is not specified
|
||||||
output_tail: str = '', # tail to add to output filename
|
output_tail: str = '', # tail to add to output filename
|
||||||
add_prompt_file: bool = False, # add a prompt file with generated image
|
add_prompt_file: bool = False, # add a prompt file with generated image
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -15,6 +15,8 @@ import albumentations as A
|
|||||||
from toolkit import image_utils
|
from toolkit import image_utils
|
||||||
from toolkit.config_modules import DatasetConfig
|
from toolkit.config_modules import DatasetConfig
|
||||||
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin
|
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin
|
||||||
|
from toolkit.data_transfer_object.data_loader import FileItemDTO
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ImageDataset(Dataset, CaptionMixin):
|
class ImageDataset(Dataset, CaptionMixin):
|
||||||
@@ -296,20 +298,6 @@ def print_once(msg):
|
|||||||
printed_messages.append(msg)
|
printed_messages.append(msg)
|
||||||
|
|
||||||
|
|
||||||
class FileItem:
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
self.path = kwargs.get('path', None)
|
|
||||||
self.width = kwargs.get('width', None)
|
|
||||||
self.height = kwargs.get('height', None)
|
|
||||||
# we scale first, then crop
|
|
||||||
self.scale_to_width = kwargs.get('scale_to_width', self.width)
|
|
||||||
self.scale_to_height = kwargs.get('scale_to_height', self.height)
|
|
||||||
# crop values are from scaled size
|
|
||||||
self.crop_x = kwargs.get('crop_x', 0)
|
|
||||||
self.crop_y = kwargs.get('crop_y', 0)
|
|
||||||
self.crop_width = kwargs.get('crop_width', self.scale_to_width)
|
|
||||||
self.crop_height = kwargs.get('crop_height', self.scale_to_height)
|
|
||||||
|
|
||||||
|
|
||||||
class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
|
class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
|
||||||
|
|
||||||
@@ -325,7 +313,7 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
|
|||||||
# we always random crop if random scale is enabled
|
# we always random crop if random scale is enabled
|
||||||
self.random_crop = self.random_scale if self.random_scale else dataset_config.random_crop
|
self.random_crop = self.random_scale if self.random_scale else dataset_config.random_crop
|
||||||
self.resolution = dataset_config.resolution
|
self.resolution = dataset_config.resolution
|
||||||
self.file_list: List['FileItem'] = []
|
self.file_list: List['FileItemDTO'] = []
|
||||||
|
|
||||||
# get the file list
|
# get the file list
|
||||||
file_list = [
|
file_list = [
|
||||||
@@ -344,14 +332,16 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
|
|||||||
f'This process is faster for png, jpeg')
|
f'This process is faster for png, jpeg')
|
||||||
img = Image.open(file)
|
img = Image.open(file)
|
||||||
h, w = img.size
|
h, w = img.size
|
||||||
|
# TODO allow smaller images
|
||||||
if int(min(h, w) * self.scale) >= self.resolution:
|
if int(min(h, w) * self.scale) >= self.resolution:
|
||||||
self.file_list.append(
|
self.file_list.append(
|
||||||
FileItem(
|
FileItemDTO(
|
||||||
path=file,
|
path=file,
|
||||||
width=w,
|
width=w,
|
||||||
height=h,
|
height=h,
|
||||||
scale_to_width=int(w * self.scale),
|
scale_to_width=int(w * self.scale),
|
||||||
scale_to_height=int(h * self.scale),
|
scale_to_height=int(h * self.scale),
|
||||||
|
dataset_config=dataset_config
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|||||||
36
toolkit/data_transfer_object/data_loader.py
Normal file
36
toolkit/data_transfer_object/data_loader.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
from typing import TYPE_CHECKING
|
||||||
|
import torch
|
||||||
|
import random
|
||||||
|
|
||||||
|
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from toolkit.config_modules import DatasetConfig
|
||||||
|
|
||||||
|
|
||||||
|
class FileItemDTO(CaptionProcessingDTOMixin):
    """Data-transfer object describing one dataset file.

    Every field is read from keyword arguments; anything not supplied falls
    back to the default shown below. Caption handling (``get_caption``) comes
    from ``CaptionProcessingDTOMixin``.

    Keyword Args:
        path: path of the file.
        caption_path: path of the caption file, if any.
        raw_caption: unprocessed caption text, if any.
        width / height: size of the source image.
        scale_to_width / scale_to_height: size to scale to before cropping;
            default to ``width`` / ``height``.
        crop_x / crop_y: crop offset, measured in the scaled image; default 0.
        crop_width / crop_height: crop size; default to the scaled size.
        dataset_config: the ``DatasetConfig`` this file belongs to.
    """

    def __init__(self, **kwargs):
        self.path = kwargs.get('path', None)
        self.caption_path: str = kwargs.get('caption_path', None)
        self.raw_caption: str = kwargs.get('raw_caption', None)
        self.width: int = kwargs.get('width', None)
        self.height: int = kwargs.get('height', None)
        # we scale first, then crop
        self.scale_to_width: int = kwargs.get('scale_to_width', self.width)
        self.scale_to_height: int = kwargs.get('scale_to_height', self.height)
        # crop values are from scaled size
        self.crop_x: int = kwargs.get('crop_x', 0)
        self.crop_y: int = kwargs.get('crop_y', 0)
        self.crop_width: int = kwargs.get('crop_width', self.scale_to_width)
        self.crop_height: int = kwargs.get('crop_height', self.scale_to_height)

        # process config
        self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)

        # dataset_config is optional (defaults to None), so do not dereference
        # it unconditionally; fall back to 1.0, the same default DatasetConfig
        # uses for network_weight.
        # NOTE(review): the doubled "network_network_weight" name looks like a
        # typo for "network_weight"; kept as-is so existing readers keep working.
        if self.dataset_config is not None:
            self.network_network_weight: float = self.dataset_config.network_weight
        else:
            self.network_network_weight: float = 1.0
|
||||||
|
|
||||||
|
|
||||||
|
class DataLoaderBatchDTO:
    """Data-transfer object carrying a file item and its dataset config.

    Both values are taken from keyword arguments and are ``None`` when the
    corresponding keyword is not supplied.
    """

    def __init__(self, **kwargs):
        # pull both entries out first, then publish them as attributes
        item: 'FileItemDTO' = kwargs.get('file_item')
        config: 'DatasetConfig' = kwargs.get('dataset_config')
        self.file_item = item
        self.dataset_config = config
|
||||||
@@ -1,9 +1,19 @@
|
|||||||
import os
|
import os
|
||||||
|
import random
|
||||||
from typing import TYPE_CHECKING, List, Dict
|
from typing import TYPE_CHECKING, List, Dict
|
||||||
|
|
||||||
|
from toolkit.prompt_utils import inject_trigger_into_prompt
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from toolkit.data_loader import AiToolkitDataset
|
||||||
|
from toolkit.data_transfer_object.data_loader import FileItemDTO
|
||||||
|
|
||||||
|
|
||||||
|
# def get_associated_caption_from_img_path(img_path):
|
||||||
|
|
||||||
|
|
||||||
class CaptionMixin:
|
class CaptionMixin:
|
||||||
def get_caption_item(self, index):
|
def get_caption_item(self: 'AiToolkitDataset', index):
|
||||||
if not hasattr(self, 'caption_type'):
|
if not hasattr(self, 'caption_type'):
|
||||||
raise Exception('caption_type not found on class instance')
|
raise Exception('caption_type not found on class instance')
|
||||||
if not hasattr(self, 'file_list'):
|
if not hasattr(self, 'file_list'):
|
||||||
@@ -48,7 +58,7 @@ class CaptionMixin:
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from toolkit.config_modules import DatasetConfig
|
from toolkit.config_modules import DatasetConfig
|
||||||
from toolkit.data_loader import FileItem
|
from toolkit.data_transfer_object.data_loader import FileItemDTO
|
||||||
|
|
||||||
|
|
||||||
class Bucket:
|
class Bucket:
|
||||||
@@ -63,14 +73,14 @@ class BucketsMixin:
|
|||||||
self.buckets: Dict[str, Bucket] = {}
|
self.buckets: Dict[str, Bucket] = {}
|
||||||
self.batch_indices: List[List[int]] = []
|
self.batch_indices: List[List[int]] = []
|
||||||
|
|
||||||
def build_batch_indices(self):
|
def build_batch_indices(self: 'AiToolkitDataset'):
|
||||||
for key, bucket in self.buckets.items():
|
for key, bucket in self.buckets.items():
|
||||||
for start_idx in range(0, len(bucket.file_list_idx), self.batch_size):
|
for start_idx in range(0, len(bucket.file_list_idx), self.batch_size):
|
||||||
end_idx = min(start_idx + self.batch_size, len(bucket.file_list_idx))
|
end_idx = min(start_idx + self.batch_size, len(bucket.file_list_idx))
|
||||||
batch = bucket.file_list_idx[start_idx:end_idx]
|
batch = bucket.file_list_idx[start_idx:end_idx]
|
||||||
self.batch_indices.append(batch)
|
self.batch_indices.append(batch)
|
||||||
|
|
||||||
def setup_buckets(self):
|
def setup_buckets(self: 'AiToolkitDataset'):
|
||||||
if not hasattr(self, 'file_list'):
|
if not hasattr(self, 'file_list'):
|
||||||
raise Exception(f'file_list not found on class instance {self.__class__.__name__}')
|
raise Exception(f'file_list not found on class instance {self.__class__.__name__}')
|
||||||
if not hasattr(self, 'dataset_config'):
|
if not hasattr(self, 'dataset_config'):
|
||||||
@@ -79,7 +89,7 @@ class BucketsMixin:
|
|||||||
config: 'DatasetConfig' = self.dataset_config
|
config: 'DatasetConfig' = self.dataset_config
|
||||||
resolution = config.resolution
|
resolution = config.resolution
|
||||||
bucket_tolerance = config.bucket_tolerance
|
bucket_tolerance = config.bucket_tolerance
|
||||||
file_list: List['FileItem'] = self.file_list
|
file_list: List['FileItemDTO'] = self.file_list
|
||||||
|
|
||||||
# make sure out resolution is divisible by bucket_tolerance
|
# make sure out resolution is divisible by bucket_tolerance
|
||||||
if resolution % bucket_tolerance != 0:
|
if resolution % bucket_tolerance != 0:
|
||||||
@@ -146,3 +156,48 @@ class BucketsMixin:
|
|||||||
print(f'{len(self.buckets)} buckets made')
|
print(f'{len(self.buckets)} buckets made')
|
||||||
|
|
||||||
# file buckets made
|
# file buckets made
|
||||||
|
|
||||||
|
|
||||||
|
class CaptionProcessingDTOMixin:
    """Mixin that turns an item's ``raw_caption`` into a training caption.

    The host object must expose ``raw_caption`` and ``dataset_config``; the
    config supplies ``caption_dropout_rate``, ``shuffle_tokens`` and
    ``token_dropout_rate``.
    """

    def get_caption(
            self: 'FileItemDTO',
            trigger=None,
            to_replace_list=None,
            add_if_not_present=True
    ):
        """Return the processed caption for this item.

        Steps, in order: optional whole-caption dropout (returns ``''``),
        split on commas into trimmed non-empty tokens, optional shuffle,
        optional per-token dropout, re-join with ``', '`` and finally pass the
        result through ``inject_trigger_into_prompt`` together with
        ``trigger``, ``to_replace_list`` and ``add_if_not_present``.
        """
        text = '' if self.raw_caption is None else self.raw_caption
        config = self.dataset_config

        # whole-caption dropout: one random draw in [0, 1), only when enabled
        if config.caption_dropout_rate > 0 and random.random() < config.caption_dropout_rate:
            return ''

        # comma separated tokens, whitespace trimmed, blanks discarded
        tokens = [piece for piece in (part.strip() for part in text.split(',')) if piece]

        if config.shuffle_tokens:
            random.shuffle(tokens)

        # per-token dropout: a token survives when its draw exceeds the rate
        if config.token_dropout_rate > 0:
            tokens = [tok for tok in tokens if random.random() > config.token_dropout_rate]

        return inject_trigger_into_prompt(', '.join(tokens), trigger, to_replace_list, add_if_not_present)
|
||||||
|
|||||||
@@ -1,12 +1,11 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Optional, TYPE_CHECKING, List
|
from typing import Optional, TYPE_CHECKING, List, Union, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from safetensors.torch import load_file, save_file
|
from safetensors.torch import load_file, save_file
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import random
|
import random
|
||||||
|
|
||||||
from toolkit.stable_diffusion_model import PromptEmbeds
|
|
||||||
from toolkit.train_tools import get_torch_dtype
|
from toolkit.train_tools import get_torch_dtype
|
||||||
import itertools
|
import itertools
|
||||||
|
|
||||||
@@ -19,6 +18,27 @@ class ACTION_TYPES_SLIDER:
|
|||||||
ENHANCE_NEGATIVE = 1
|
ENHANCE_NEGATIVE = 1
|
||||||
|
|
||||||
|
|
||||||
|
class PromptEmbeds:
    """Holder for encoded prompt tensors.

    ``text_embeds`` is always set. ``pooled_embeds`` is set only when the
    constructor receives a list or tuple (the original comments mark that as
    the "xl" case, and the bare-tensor form as "sdv1.x, sdv2.x").
    """
    text_embeds: torch.Tensor
    pooled_embeds: Union[torch.Tensor, None]

    def __init__(self, args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor]) -> None:
        # a list/tuple carries (text, pooled); anything else is the text embedding itself
        is_pair = isinstance(args, (list, tuple))
        self.text_embeds = args[0] if is_pair else args
        self.pooled_embeds = args[1] if is_pair else None

    def to(self, *args, **kwargs):
        """Apply ``.to(*args, **kwargs)`` to the held tensors in place and return ``self``."""
        self.text_embeds = self.text_embeds.to(*args, **kwargs)
        pooled = self.pooled_embeds
        if pooled is not None:
            self.pooled_embeds = pooled.to(*args, **kwargs)
        return self
|
||||||
|
|
||||||
|
|
||||||
class EncodedPromptPair:
|
class EncodedPromptPair:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -465,3 +485,37 @@ def build_latent_image_batch_for_prompt_pair(
|
|||||||
latent_list.append(neg_latent)
|
latent_list.append(neg_latent)
|
||||||
|
|
||||||
return torch.cat(latent_list, dim=0)
|
return torch.cat(latent_list, dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
def inject_trigger_into_prompt(prompt, trigger=None, to_replace_list=None, add_if_not_present=True):
    """Replace placeholder tags in ``prompt`` with ``trigger``.

    Args:
        prompt: the prompt text to process.
        trigger: the text to substitute in. When ``None`` the prompt is
            returned untouched.
        to_replace_list: extra placeholder strings to replace, in addition to
            the built-in ``"[name]"`` and ``"[trigger]"``. The caller's list is
            never modified.
        add_if_not_present: when true and the trigger does not occur in the
            prompt after replacement, prepend ``trigger + " "``.

    Returns:
        The prompt with placeholders substituted (and the trigger prepended
        when requested and absent).
    """
    if trigger is None:
        return prompt

    default_replacements = ["[name]", "[trigger]"]

    # Build a fresh list instead of `+=`, which extended the caller's list in
    # place and made it grow by two entries on every call.
    placeholders = list(to_replace_list or []) + default_replacements
    # remove duplicates while keeping a deterministic (first-seen) order
    placeholders = list(dict.fromkeys(placeholders))

    # replace them all
    output_prompt = prompt
    for placeholder in placeholders:
        output_prompt = output_prompt.replace(placeholder, trigger)

    if trigger == "":
        # An empty trigger "occurs" len(prompt) + 1 times under str.count,
        # which used to emit a meaningless warning; there is nothing to count
        # or prepend here.
        return output_prompt

    # see how many times the trigger is in the prompt
    num_instances = output_prompt.count(trigger)

    if num_instances == 0 and add_if_not_present:
        # add it to the beginning of the prompt
        output_prompt = trigger + " " + output_prompt

    if num_instances > 1:
        print(
            f"Warning: {trigger} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")

    return output_prompt
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from toolkit import train_tools
|
|||||||
from toolkit.config_modules import ModelConfig, GenerateImageConfig
|
from toolkit.config_modules import ModelConfig, GenerateImageConfig
|
||||||
from toolkit.metadata import get_meta_for_safetensors
|
from toolkit.metadata import get_meta_for_safetensors
|
||||||
from toolkit.paths import REPOS_ROOT
|
from toolkit.paths import REPOS_ROOT
|
||||||
|
from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds
|
||||||
from toolkit.saving import save_ldm_model_from_diffusers
|
from toolkit.saving import save_ldm_model_from_diffusers
|
||||||
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
||||||
import torch
|
import torch
|
||||||
@@ -63,25 +64,7 @@ UNET_IN_CHANNELS = 4 # Stable Diffusion の in_channels は 4 で固定。XLも
|
|||||||
VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
|
VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
|
||||||
|
|
||||||
|
|
||||||
class PromptEmbeds:
|
|
||||||
text_embeds: torch.Tensor
|
|
||||||
pooled_embeds: Union[torch.Tensor, None]
|
|
||||||
|
|
||||||
def __init__(self, args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor]) -> None:
|
|
||||||
if isinstance(args, list) or isinstance(args, tuple):
|
|
||||||
# xl
|
|
||||||
self.text_embeds = args[0]
|
|
||||||
self.pooled_embeds = args[1]
|
|
||||||
else:
|
|
||||||
# sdv1.x, sdv2.x
|
|
||||||
self.text_embeds = args
|
|
||||||
self.pooled_embeds = None
|
|
||||||
|
|
||||||
def to(self, *args, **kwargs):
|
|
||||||
self.text_embeds = self.text_embeds.to(*args, **kwargs)
|
|
||||||
if self.pooled_embeds is not None:
|
|
||||||
self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs)
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
# if is type checking
|
# if is type checking
|
||||||
@@ -708,38 +691,12 @@ class StableDiffusion:
|
|||||||
raise ValueError(f"Unknown weight name: {name}")
|
raise ValueError(f"Unknown weight name: {name}")
|
||||||
|
|
||||||
def inject_trigger_into_prompt(self, prompt, trigger=None, to_replace_list=None, add_if_not_present=True):
|
def inject_trigger_into_prompt(self, prompt, trigger=None, to_replace_list=None, add_if_not_present=True):
|
||||||
if trigger is None:
|
return inject_trigger_into_prompt(
|
||||||
return prompt
|
prompt,
|
||||||
output_prompt = prompt
|
trigger=trigger,
|
||||||
default_replacements = ["[name]", "[trigger]"]
|
to_replace_list=to_replace_list,
|
||||||
num_times_trigger_exists = prompt.count(trigger)
|
add_if_not_present=add_if_not_present,
|
||||||
|
)
|
||||||
replace_with = trigger
|
|
||||||
if to_replace_list is None:
|
|
||||||
to_replace_list = default_replacements
|
|
||||||
else:
|
|
||||||
to_replace_list += default_replacements
|
|
||||||
|
|
||||||
# remove duplicates
|
|
||||||
to_replace_list = list(set(to_replace_list))
|
|
||||||
|
|
||||||
# replace them all
|
|
||||||
for to_replace in to_replace_list:
|
|
||||||
# replace it
|
|
||||||
output_prompt = output_prompt.replace(to_replace, replace_with)
|
|
||||||
|
|
||||||
# see how many times replace_with is in the prompt
|
|
||||||
num_instances = output_prompt.count(replace_with)
|
|
||||||
|
|
||||||
if num_instances == 0 and add_if_not_present:
|
|
||||||
# add it to the beginning of the prompt
|
|
||||||
output_prompt = replace_with + " " + output_prompt
|
|
||||||
|
|
||||||
if num_instances > 1:
|
|
||||||
print(
|
|
||||||
f"Warning: {trigger} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
|
|
||||||
|
|
||||||
return output_prompt
|
|
||||||
|
|
||||||
def state_dict(self, vae=True, text_encoder=True, unet=True):
|
def state_dict(self, vae=True, text_encoder=True, unet=True):
|
||||||
state_dict = OrderedDict()
|
state_dict = OrderedDict()
|
||||||
|
|||||||
Reference in New Issue
Block a user