mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 18:21:16 +00:00
Added ability to do cfg during training. Various bug fixes
This commit is contained in:
@@ -156,6 +156,7 @@ class AdapterConfig:
|
||||
self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip') # clip vit vit_hybrid, safe
|
||||
self.safe_reducer_channels: int = kwargs.get('safe_reducer_channels', 512)
|
||||
self.safe_channels: int = kwargs.get('safe_channels', 2048)
|
||||
self.safe_tokens: int = kwargs.get('safe_tokens', 8)
|
||||
|
||||
# clip vision
|
||||
self.trigger = kwargs.get('trigger', 'tri993r')
|
||||
@@ -270,6 +271,7 @@ class TrainConfig:
|
||||
raise ValueError(f"train_turbo is only supported with euler and wuler_a noise schedulers")
|
||||
|
||||
self.dynamic_noise_offset = kwargs.get('dynamic_noise_offset', False)
|
||||
self.do_cfg = kwargs.get('do_cfg', False)
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import torch
|
||||
from typing import Literal
|
||||
from typing import Literal, Optional
|
||||
|
||||
from toolkit.basic import value_map
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||
@@ -193,6 +193,7 @@ def get_direct_guidance_loss(
|
||||
batch: 'DataLoaderBatchDTO',
|
||||
noise: torch.Tensor,
|
||||
sd: 'StableDiffusion',
|
||||
unconditional_embeds: Optional[PromptEmbeds] = None,
|
||||
**kwargs
|
||||
):
|
||||
with torch.no_grad():
|
||||
@@ -222,9 +223,14 @@ def get_direct_guidance_loss(
|
||||
|
||||
# sd.network.multiplier = network_weight_list
|
||||
# do our prediction with LoRA active on the scaled guidance latents
|
||||
if unconditional_embeds is not None:
|
||||
unconditional_embeds = unconditional_embeds.to(device, dtype=dtype).detach()
|
||||
unconditional_embeds = concat_prompt_embeds([unconditional_embeds, unconditional_embeds])
|
||||
|
||||
prediction = sd.predict_noise(
|
||||
latents=torch.cat([unconditional_noisy_latents, conditional_noisy_latents]).to(device, dtype=dtype).detach(),
|
||||
conditional_embeddings=concat_prompt_embeds([conditional_embeds,conditional_embeds]).to(device, dtype=dtype).detach(),
|
||||
unconditional_embeddings=unconditional_embeds,
|
||||
timestep=torch.cat([timesteps, timesteps]),
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs # adapter residuals in here
|
||||
@@ -482,12 +488,14 @@ def get_guidance_loss(
|
||||
batch: 'DataLoaderBatchDTO',
|
||||
noise: torch.Tensor,
|
||||
sd: 'StableDiffusion',
|
||||
unconditional_embeds: Optional[PromptEmbeds] = None,
|
||||
**kwargs
|
||||
):
|
||||
# TODO add others and process individual batch items separately
|
||||
guidance_type: GuidanceType = batch.file_items[0].dataset_config.guidance_type
|
||||
|
||||
if guidance_type == "targeted":
|
||||
assert unconditional_embeds is None, "Unconditional embeds are not supported for targeted guidance"
|
||||
return get_targeted_guidance_loss(
|
||||
noisy_latents,
|
||||
conditional_embeds,
|
||||
@@ -501,6 +509,7 @@ def get_guidance_loss(
|
||||
**kwargs
|
||||
)
|
||||
elif guidance_type == "polarity":
|
||||
assert unconditional_embeds is None, "Unconditional embeds are not supported for polarity guidance"
|
||||
return get_guided_loss_polarity(
|
||||
noisy_latents,
|
||||
conditional_embeds,
|
||||
@@ -515,6 +524,7 @@ def get_guidance_loss(
|
||||
)
|
||||
|
||||
elif guidance_type == "targeted_polarity":
|
||||
assert unconditional_embeds is None, "Unconditional embeds are not supported for targeted polarity guidance"
|
||||
return get_targeted_polarity_loss(
|
||||
noisy_latents,
|
||||
conditional_embeds,
|
||||
@@ -538,6 +548,7 @@ def get_guidance_loss(
|
||||
batch,
|
||||
noise,
|
||||
sd,
|
||||
unconditional_embeds=unconditional_embeds,
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -184,7 +184,7 @@ class IPAdapter(torch.nn.Module):
|
||||
self.clip_image_processor = SAFEImageProcessor()
|
||||
self.image_encoder = SAFEVisionModel(
|
||||
in_channels=3,
|
||||
num_tokens=8,
|
||||
num_tokens=self.config.safe_tokens,
|
||||
num_vectors=sd.unet.config['cross_attention_dim'],
|
||||
reducer_channels=self.config.safe_reducer_channels,
|
||||
channels=self.config.safe_channels,
|
||||
@@ -234,8 +234,8 @@ class IPAdapter(torch.nn.Module):
|
||||
dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
|
||||
embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch == "convnext" else self.image_encoder.config.hidden_sizes[-1]
|
||||
|
||||
if self.config.image_encoder_arch == 'safe':
|
||||
embedding_dim = self.config.safe_channels
|
||||
# if self.config.image_encoder_arch == 'safe':
|
||||
# embedding_dim = self.config.safe_tokens
|
||||
# size mismatch for latents: copying a param with shape torch.Size([1, 16, 1280]) from checkpoint, the shape in current model is torch.Size([1, 16, 2048]).
|
||||
# size mismatch for latents: copying a param with shape torch.Size([1, 32, 2048]) from checkpoint, the shape in current model is torch.Size([1, 16, 1280])
|
||||
# ip-adapter-plus
|
||||
|
||||
Reference in New Issue
Block a user