Mirror of https://github.com/ostris/ai-toolkit.git
Initial training script for PhotoMaker training. Needs a little more work.
@@ -19,6 +19,7 @@ from tqdm import tqdm
 from torchvision.transforms import Resize, transforms
 from toolkit.clip_vision_adapter import ClipVisionAdapter
+from toolkit.custom_adapter import CustomAdapter
 from toolkit.ip_adapter import IPAdapter
 from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
     convert_vae_state_dict, load_vae
@@ -483,6 +484,13 @@ class StableDiffusion:
                         transforms.ToTensor(),
                     ])
                     validation_image = transform(validation_image)
+                if isinstance(self.adapter, CustomAdapter):
+                    # todo allow loading multiple
+                    transform = transforms.Compose([
+                        transforms.ToTensor(),
+                    ])
+                    validation_image = transform(validation_image)
+                    self.adapter.num_images = 1
                 if isinstance(self.adapter, ReferenceAdapter):
                     # need -1 to 1
                     validation_image = transforms.ToTensor()(validation_image)
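Note: torchvision's ToTensor already rescales a PIL image to a float tensor in [0, 1], which matches the tensors_0_1 naming used later in this diff, while the ReferenceAdapter branch wants [-1, 1]. A minimal sketch of the two conventions (the explicit rescale line is an illustration, not part of this commit):

from PIL import Image
from torchvision import transforms

# assumption for illustration: any RGB image file on disk
img = Image.open("face.png").convert("RGB")

# ToTensor: uint8 HWC in [0, 255] -> float32 CHW in [0.0, 1.0]
tensor_0_1 = transforms.ToTensor()(img)

# zero-centered variant for adapters that expect [-1, 1]
tensor_neg1_1 = tensor_0_1 * 2.0 - 1.0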
@@ -501,6 +509,19 @@ class StableDiffusion:
                     conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
                     self.adapter(conditional_clip_embeds)

+                if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
+                    # handle conditioning the prompts
+                    gen_config.prompt = self.adapter.condition_prompt(
+                        gen_config.prompt,
+                        is_unconditional=False,
+                    )
+                    gen_config.prompt_2 = gen_config.prompt
+                    gen_config.negative_prompt = self.adapter.condition_prompt(
+                        gen_config.negative_prompt,
+                        is_unconditional=True,
+                    )
+                    gen_config.negative_prompt_2 = gen_config.negative_prompt
+
                 # encode the prompt ourselves so we can do fun stuff with embeddings
                 conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)

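Note: condition_prompt rewrites the prompt text before it is encoded, once for the positive prompt and once for the negative one. The diff does not show what CustomAdapter does inside, so the following is a hypothetical sketch of one plausible behavior for a PhotoMaker-style adapter, where a trigger word marks the token that will receive identity embeddings (the trigger handling below is an assumption, not the repo's implementation):

def condition_prompt(prompt: str, is_unconditional: bool, trigger: str = "img") -> str:
    # hypothetical: PhotoMaker-style adapters key identity injection off a
    # trigger word; the negative (unconditional) prompt should carry none
    words = prompt.split()
    if is_unconditional:
        return " ".join(w for w in words if w != trigger)
    if trigger not in words:
        words.append(trigger)  # assumption: append the trigger if missing
    return " ".join(words)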
@@ -524,6 +545,21 @@ class StableDiffusion:
                     conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
                     unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds)

+                if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
+                    conditional_embeds = self.adapter.condition_encoded_embeds(
+                        tensors_0_1=validation_image,
+                        prompt_embeds=conditional_embeds,
+                        is_training=False,
+                        has_been_preprocessed=False,
+                    )
+                    unconditional_embeds = self.adapter.condition_encoded_embeds(
+                        tensors_0_1=validation_image,
+                        prompt_embeds=unconditional_embeds,
+                        is_training=False,
+                        has_been_preprocessed=False,
+                        is_unconditional=True,
+                    )
+
                 if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0:
                     # if we have a refiner loaded, set the denoising end at the refiner start
                     extra['denoising_end'] = gen_config.refiner_start_at
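Note: condition_encoded_embeds is where the reference image actually alters the already-encoded prompt embeddings, with the positive and negative branches distinguished by is_unconditional. The adapter internals are not in this diff; below is a rough sketch of one common design (trigger-token replacement). Every name and shape here is an assumption, and the id_proj module is a stand-in for a real vision encoder plus projection:

import torch

def condition_encoded_embeds(
    tensors_0_1: torch.Tensor,    # reference image batch, values in [0, 1]
    prompt_embeds: torch.Tensor,  # (batch, seq_len, dim) text embeddings
    token_mask: torch.Tensor,     # (batch, seq_len) bool, True at trigger tokens
    id_proj: torch.nn.Module,     # stand-in for a vision encoder + projection
    is_unconditional: bool = False,
) -> torch.Tensor:
    # sketch: leave the negative branch untouched; on the positive branch,
    # overwrite trigger-token embeddings with image-derived ID embeddings
    if is_unconditional:
        return prompt_embeds
    id_embeds = id_proj(tensors_0_1.flatten(1))  # (batch, dim)
    out = prompt_embeds.clone()
    out[token_mask] = id_embeds.repeat_interleave(token_mask.sum(dim=1), dim=0)
    return out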
@@ -1468,6 +1504,9 @@ class StableDiffusion:
         elif isinstance(self.adapter, ClipVisionAdapter):
             requires_grad = self.adapter.embedder.training
             adapter_device = self.adapter.device
+        elif isinstance(self.adapter, CustomAdapter):
+            requires_grad = self.adapter.training
+            adapter_device = self.adapter.device
         elif isinstance(self.adapter, ReferenceAdapter):
             # todo update this!!
             requires_grad = True
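Note: the last hunk extends an isinstance chain that snapshots each adapter type's training flag and device, presumably so both can be restored after generation flips modules into eval mode. A condensed, self-contained sketch of that save-and-restore pattern (the context-manager framing is an assumption about intent, not the repo's code):

import torch
from contextlib import contextmanager

@contextmanager
def adapter_eval_mode(adapter: torch.nn.Module):
    # remember the training flag and device, switch to eval for sampling,
    # then restore both; mirrors the bookkeeping in the hunk above
    was_training = adapter.training
    device = next(adapter.parameters()).device
    try:
        adapter.eval()
        yield adapter
    finally:
        adapter.train(was_training)
        adapter.to(device)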