mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 14:39:50 +00:00
Bug fixes, work on maing IP adapters more customizable.
This commit is contained in:
@@ -827,7 +827,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds)
|
||||
|
||||
prior_pred = None
|
||||
if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or (self.do_prior_prediction and not is_reg):
|
||||
if ((has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction) and not is_reg:
|
||||
with self.timer('prior predict'):
|
||||
prior_pred = self.get_prior_prediction(
|
||||
noisy_latents=noisy_latents,
|
||||
|
||||
@@ -151,6 +151,16 @@ class AdapterConfig:
|
||||
|
||||
self.num_tokens: int = num_tokens
|
||||
self.train_image_encoder: bool = kwargs.get('train_image_encoder', False)
|
||||
self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip') # clip vit vit_hybrid
|
||||
|
||||
|
||||
|
||||
class ClipTokenMakerConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.image_encoder_path: str = kwargs.get('image_encoder_path', None)
|
||||
self.num_tokens: int = kwargs.get('num_tokens', 8)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -393,7 +403,7 @@ class DatasetConfig:
|
||||
None) # focus mask (black and white. White has higher loss than black)
|
||||
self.unconditional_path: str = kwargs.get('unconditional_path', None) # path where matching unconditional images are located
|
||||
self.invert_mask: bool = kwargs.get('invert_mask', False) # invert mask
|
||||
self.mask_min_value: float = kwargs.get('mask_min_value', 0.01) # min value for . 0 - 1
|
||||
self.mask_min_value: float = kwargs.get('mask_min_value', 0.0) # min value for . 0 - 1
|
||||
self.poi: Union[str, None] = kwargs.get('poi',
|
||||
None) # if one is set and in json data, will be used as auto crop scale point of interes
|
||||
self.num_repeats: int = kwargs.get('num_repeats', 1) # number of times to repeat dataset
|
||||
@@ -402,6 +412,8 @@ class DatasetConfig:
|
||||
# cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory
|
||||
self.cache_latents_to_disk: bool = kwargs.get('cache_latents_to_disk', False)
|
||||
|
||||
self.standardize_images: bool = kwargs.get('standardize_images', False)
|
||||
|
||||
# https://albumentations.ai/docs/api_reference/augmentations/transforms
|
||||
# augmentations are returned as a separate image and cannot currently be cached
|
||||
self.augmentations: List[dict] = kwargs.get('augmentations', None)
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import List, TYPE_CHECKING
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from PIL.ImageOps import exif_transpose
|
||||
from torchvision import transforms
|
||||
@@ -24,6 +25,45 @@ if TYPE_CHECKING:
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
|
||||
|
||||
class RescaleTransform:
|
||||
"""Transform to rescale images to the range [-1, 1]."""
|
||||
|
||||
def __call__(self, image):
|
||||
return image * 2 - 1
|
||||
|
||||
|
||||
class NormalizeSDXLTransform:
|
||||
"""
|
||||
Transforms the range from 0 to 1 to SDXL mean and std per channel based on avgs over thousands of images
|
||||
|
||||
Mean: tensor([ 0.0002, -0.1034, -0.1879])
|
||||
Standard Deviation: tensor([0.5436, 0.5116, 0.5033])
|
||||
"""
|
||||
|
||||
def __call__(self, image):
|
||||
return transforms.Normalize(
|
||||
mean=[0.0002, -0.1034, -0.1879],
|
||||
std=[0.5436, 0.5116, 0.5033],
|
||||
)(image)
|
||||
|
||||
|
||||
class NormalizeSD15Transform:
|
||||
"""
|
||||
Transforms the range from 0 to 1 to SDXL mean and std per channel based on avgs over thousands of images
|
||||
|
||||
Mean: tensor([-0.1600, -0.2450, -0.3227])
|
||||
Standard Deviation: tensor([0.5319, 0.4997, 0.5139])
|
||||
|
||||
"""
|
||||
|
||||
def __call__(self, image):
|
||||
return transforms.Normalize(
|
||||
mean=[-0.1600, -0.2450, -0.3227],
|
||||
std=[0.5319, 0.4997, 0.5139],
|
||||
)(image)
|
||||
|
||||
|
||||
|
||||
class ImageDataset(Dataset, CaptionMixin):
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
@@ -63,7 +103,7 @@ class ImageDataset(Dataset, CaptionMixin):
|
||||
|
||||
self.transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1]
|
||||
RescaleTransform(),
|
||||
])
|
||||
|
||||
def get_config(self, key, default=None, required=False):
|
||||
@@ -200,7 +240,7 @@ class PairedImageDataset(Dataset):
|
||||
|
||||
self.transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1]
|
||||
RescaleTransform(),
|
||||
])
|
||||
|
||||
def get_all_prompts(self):
|
||||
@@ -368,6 +408,23 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
|
||||
# repeat the list
|
||||
file_list = file_list * self.dataset_config.num_repeats
|
||||
|
||||
if self.dataset_config.standardize_images:
|
||||
if self.sd.is_xl or self.sd.is_vega or self.sd.is_ssd:
|
||||
NormalizeMethod = NormalizeSDXLTransform
|
||||
else:
|
||||
NormalizeMethod = NormalizeSD15Transform
|
||||
|
||||
self.transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
RescaleTransform(),
|
||||
NormalizeMethod(),
|
||||
])
|
||||
else:
|
||||
self.transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
RescaleTransform(),
|
||||
])
|
||||
|
||||
# this might take a while
|
||||
print(f" - Preprocessing image dimensions")
|
||||
bad_count = 0
|
||||
@@ -375,7 +432,8 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
|
||||
try:
|
||||
file_item = FileItemDTO(
|
||||
path=file,
|
||||
dataset_config=dataset_config
|
||||
dataset_config=dataset_config,
|
||||
dataloader_transforms=self.transform,
|
||||
)
|
||||
self.file_list.append(file_item)
|
||||
except Exception as e:
|
||||
@@ -411,10 +469,6 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
|
||||
if self.dataset_config.flip_x or self.dataset_config.flip_y:
|
||||
print(f" - Found {len(self.file_list)} images after adding flips")
|
||||
|
||||
self.transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1]
|
||||
])
|
||||
|
||||
self.setup_epoch()
|
||||
|
||||
|
||||
@@ -48,6 +48,7 @@ class FileItemDTO(
|
||||
h, w = img.size
|
||||
self.width: int = w
|
||||
self.height: int = h
|
||||
self.dataloader_transforms = kwargs.get('dataloader_transforms', None)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# self.caption_path: str = kwargs.get('caption_path', None)
|
||||
@@ -64,6 +65,7 @@ class FileItemDTO(
|
||||
self.flip_y: bool = kwargs.get('flip_x', False)
|
||||
self.augments: List[str] = self.dataset_config.augments
|
||||
|
||||
|
||||
self.network_weight: float = self.dataset_config.network_weight
|
||||
self.is_reg = self.dataset_config.is_reg
|
||||
self.tensor: Union[torch.Tensor, None] = None
|
||||
|
||||
@@ -56,6 +56,30 @@ transforms_dict = {
|
||||
caption_ext_list = ['txt', 'json', 'caption']
|
||||
|
||||
|
||||
def standardize_images(images):
|
||||
"""
|
||||
Standardize the given batch of images using the specified mean and std.
|
||||
Expects values of 0 - 1
|
||||
|
||||
Args:
|
||||
images (torch.Tensor): A batch of images in the shape of (N, C, H, W),
|
||||
where N is the number of images, C is the number of channels,
|
||||
H is the height, and W is the width.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Standardized images.
|
||||
"""
|
||||
mean = [0.48145466, 0.4578275, 0.40821073]
|
||||
std = [0.26862954, 0.26130258, 0.27577711]
|
||||
|
||||
# Define the normalization transform
|
||||
normalize = transforms.Normalize(mean=mean, std=std)
|
||||
|
||||
# Apply normalization to each image in the batch
|
||||
standardized_images = torch.stack([normalize(img) for img in images])
|
||||
|
||||
return standardized_images
|
||||
|
||||
def clean_caption(caption):
|
||||
# remove any newlines
|
||||
caption = caption.replace('\n', ', ')
|
||||
@@ -520,8 +544,13 @@ class ControlFileItemDTOMixin:
|
||||
))
|
||||
else:
|
||||
raise Exception("Control images not supported for non-bucket datasets")
|
||||
|
||||
self.control_tensor = transforms.ToTensor()(img)
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
if self.aug_replay_spatial_transforms:
|
||||
self.control_tensor = self.augment_spatial_control(img, transform=transform)
|
||||
else:
|
||||
self.control_tensor = transform(img)
|
||||
|
||||
def cleanup_control(self: 'FileItemDTO'):
|
||||
self.control_tensor = None
|
||||
@@ -624,6 +653,8 @@ class AugmentationFileItemDTOMixin:
|
||||
self.unaugmented_tensor: Union[torch.Tensor, None] = None
|
||||
# self.augmentations: Union[None, List[Augments]] = None
|
||||
self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
|
||||
self.aug_transform: Union[None, A.Compose] = None
|
||||
self.aug_replay_spatial_transforms = None
|
||||
self.build_augmentation_transform()
|
||||
|
||||
def build_augmentation_transform(self: 'FileItemDTO'):
|
||||
@@ -643,7 +674,8 @@ class AugmentationFileItemDTOMixin:
|
||||
# add the method to the list
|
||||
augmentation_list.append(method(**aug.params))
|
||||
|
||||
self.aug_transform = A.Compose(augmentation_list)
|
||||
# add additional targets so we can augment the control image
|
||||
self.aug_transform = A.ReplayCompose(augmentation_list, additional_targets={'image2': 'image'})
|
||||
|
||||
def augment_image(self: 'FileItemDTO', img: Image, transform: Union[None, transforms.Compose], ):
|
||||
|
||||
@@ -659,7 +691,17 @@ class AugmentationFileItemDTOMixin:
|
||||
open_cv_image = open_cv_image[:, :, ::-1].copy()
|
||||
|
||||
# apply augmentations
|
||||
augmented = self.aug_transform(image=open_cv_image)["image"]
|
||||
transformed = self.aug_transform(image=open_cv_image)
|
||||
augmented = transformed["image"]
|
||||
|
||||
# save just the spatial transforms for controls and masks
|
||||
augmented_params = transformed["replay"]
|
||||
spatial_transforms = ['Rotate', 'Flip', 'HorizontalFlip', 'VerticalFlip', 'Resize', 'Crop', 'RandomCrop',
|
||||
'ElasticTransform', 'GridDistortion', 'OpticalDistortion']
|
||||
# only store the spatial transforms
|
||||
augmented_params['transforms'] = [t for t in augmented_params['transforms'] if t['__class_fullname__'].split('.')[-1] in spatial_transforms]
|
||||
|
||||
self.aug_replay_spatial_transforms = augmented_params
|
||||
|
||||
# convert back to RGB tensor
|
||||
augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB)
|
||||
@@ -671,6 +713,38 @@ class AugmentationFileItemDTOMixin:
|
||||
|
||||
return augmented_tensor
|
||||
|
||||
# augment control images spatially consistent with transforms done to the main image
|
||||
def augment_spatial_control(self: 'FileItemDTO', img: Image, transform: Union[None, transforms.Compose] ):
|
||||
if self.aug_replay_spatial_transforms is None:
|
||||
# no transforms
|
||||
return transform(img)
|
||||
|
||||
# save colorspace to convert back to
|
||||
colorspace = img.mode
|
||||
|
||||
# convert to rgb
|
||||
img = img.convert('RGB')
|
||||
|
||||
open_cv_image = np.array(img)
|
||||
# Convert RGB to BGR
|
||||
open_cv_image = open_cv_image[:, :, ::-1].copy()
|
||||
|
||||
# Replay transforms
|
||||
transformed = A.ReplayCompose.replay(self.aug_replay_spatial_transforms, image=open_cv_image)
|
||||
augmented = transformed["image"]
|
||||
|
||||
# convert back to RGB tensor
|
||||
augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# convert to PIL image
|
||||
augmented = Image.fromarray(augmented)
|
||||
|
||||
# convert back to original colorspace
|
||||
augmented = augmented.convert(colorspace)
|
||||
|
||||
augmented_tensor = transforms.ToTensor()(augmented) if transform is None else transform(augmented)
|
||||
return augmented_tensor
|
||||
|
||||
def cleanup_control(self: 'FileItemDTO'):
|
||||
self.unaugmented_tensor = None
|
||||
|
||||
@@ -760,7 +834,13 @@ class MaskFileItemDTOMixin:
|
||||
else:
|
||||
raise Exception("Mask images not supported for non-bucket datasets")
|
||||
|
||||
self.mask_tensor = transforms.ToTensor()(img)
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
if self.aug_replay_spatial_transforms:
|
||||
self.mask_tensor = self.augment_spatial_control(img, transform=transform)
|
||||
else:
|
||||
self.mask_tensor = transform(img)
|
||||
self.mask_tensor = value_map(self.mask_tensor, 0, 1.0, self.mask_min_value, 1.0)
|
||||
# convert to grayscale
|
||||
|
||||
@@ -776,12 +856,7 @@ class UnconditionalFileItemDTOMixin:
|
||||
self.unconditional_path: Union[str, None] = None
|
||||
self.unconditional_tensor: Union[torch.Tensor, None] = None
|
||||
self.unconditional_latent: Union[torch.Tensor, None] = None
|
||||
self.unconditional_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
self.unconditional_transforms = self.dataloader_transforms
|
||||
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
|
||||
|
||||
if dataset_config.unconditional_path is not None:
|
||||
@@ -835,7 +910,10 @@ class UnconditionalFileItemDTOMixin:
|
||||
else:
|
||||
raise Exception("Unconditional images are not supported for non-bucket datasets")
|
||||
|
||||
self.unconditional_tensor = self.unconditional_transforms(img)
|
||||
if self.aug_replay_spatial_transforms:
|
||||
self.unconditional_tensor = self.augment_spatial_control(img, transform=self.unconditional_transforms)
|
||||
else:
|
||||
self.unconditional_tensor = self.unconditional_transforms(img)
|
||||
|
||||
def cleanup_unconditional(self: 'FileItemDTO'):
|
||||
self.unconditional_tensor = None
|
||||
|
||||
@@ -26,8 +26,15 @@ if TYPE_CHECKING:
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
CLIPVisionModelWithProjection,
|
||||
CLIPVisionModel
|
||||
CLIPVisionModel,
|
||||
AutoImageProcessor,
|
||||
ConvNextModel,
|
||||
ConvNextForImageClassification,
|
||||
ConvNextImageProcessor
|
||||
)
|
||||
from transformers import ViTHybridImageProcessor, ViTHybridForImageClassification
|
||||
|
||||
from transformers import ViTFeatureExtractor, ViTForImageClassification
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
@@ -153,13 +160,51 @@ class IPAdapter(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.config = adapter_config
|
||||
self.sd_ref: weakref.ref = weakref.ref(sd)
|
||||
try:
|
||||
self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
|
||||
except EnvironmentError:
|
||||
self.clip_image_processor = CLIPImageProcessor()
|
||||
self.device = self.sd_ref().unet.device
|
||||
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(adapter_config.image_encoder_path,
|
||||
ignore_mismatched_sizes=True)
|
||||
if self.config.image_encoder_arch == 'clip':
|
||||
try:
|
||||
self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
|
||||
except EnvironmentError:
|
||||
self.clip_image_processor = CLIPImageProcessor()
|
||||
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
||||
adapter_config.image_encoder_path,
|
||||
ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
elif self.config.image_encoder_arch == 'vit':
|
||||
try:
|
||||
self.clip_image_processor = ViTFeatureExtractor.from_pretrained(adapter_config.image_encoder_path)
|
||||
except EnvironmentError:
|
||||
self.clip_image_processor = ViTFeatureExtractor()
|
||||
self.image_encoder = ViTForImageClassification.from_pretrained(adapter_config.image_encoder_path).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
elif self.config.image_encoder_arch == 'convnext':
|
||||
try:
|
||||
self.clip_image_processor = ConvNextImageProcessor.from_pretrained(adapter_config.image_encoder_path)
|
||||
except EnvironmentError:
|
||||
print(f"could not load image processor from {adapter_config.image_encoder_path}")
|
||||
self.clip_image_processor = ConvNextImageProcessor(
|
||||
size=320,
|
||||
image_mean=[0.48145466, 0.4578275, 0.40821073],
|
||||
image_std=[0.26862954, 0.26130258, 0.27577711],
|
||||
)
|
||||
self.image_encoder = ConvNextForImageClassification.from_pretrained(
|
||||
adapter_config.image_encoder_path,
|
||||
use_safetensors=True,
|
||||
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
elif self.config.image_encoder_arch == 'vit-hybrid':
|
||||
try:
|
||||
self.clip_image_processor = ViTHybridImageProcessor.from_pretrained(adapter_config.image_encoder_path)
|
||||
except EnvironmentError:
|
||||
print(f"could not load image processor from {adapter_config.image_encoder_path}")
|
||||
self.clip_image_processor = ViTHybridImageProcessor(
|
||||
size=320,
|
||||
image_mean=[0.48145466, 0.4578275, 0.40821073],
|
||||
image_std=[0.26862954, 0.26130258, 0.27577711],
|
||||
)
|
||||
self.image_encoder = ViTHybridForImageClassification.from_pretrained(
|
||||
adapter_config.image_encoder_path,
|
||||
use_safetensors=True,
|
||||
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
else:
|
||||
raise ValueError(f"unknown image encoder arch: {adapter_config.image_encoder_arch}")
|
||||
self.current_scale = 1.0
|
||||
self.is_active = True
|
||||
if adapter_config.type == 'ip':
|
||||
@@ -181,7 +226,7 @@ class IPAdapter(torch.nn.Module):
|
||||
dim_head=64,
|
||||
heads=heads,
|
||||
num_queries=self.config.num_tokens, # usually 16
|
||||
embedding_dim=self.image_encoder.config.hidden_size,
|
||||
embedding_dim=self.image_encoder.config.hidden_size if not self.config.image_encoder_arch == "convnext" else self.image_encoder.config.hidden_sizes[-1],
|
||||
output_dim=sd.unet.config['cross_attention_dim'],
|
||||
ff_mult=4
|
||||
)
|
||||
@@ -239,6 +284,10 @@ class IPAdapter(torch.nn.Module):
|
||||
|
||||
self.set_scale(1.0)
|
||||
|
||||
if self.config.train_image_encoder:
|
||||
self.image_encoder.train()
|
||||
self.image_encoder.requires_grad_(True)
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
super().to(*args, **kwargs)
|
||||
self.image_encoder.to(*args, **kwargs)
|
||||
@@ -280,7 +329,7 @@ class IPAdapter(torch.nn.Module):
|
||||
if isinstance(pil_image, Image.Image):
|
||||
pil_image = [pil_image]
|
||||
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
|
||||
clip_image = clip_image.to(self.device, dtype=torch.float16)
|
||||
clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
if drop:
|
||||
clip_image = clip_image * 0
|
||||
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
|
||||
@@ -307,15 +356,18 @@ class IPAdapter(torch.nn.Module):
|
||||
do_resize=True,
|
||||
do_rescale=False,
|
||||
).pixel_values
|
||||
clip_image = clip_image.to(self.device, dtype=torch.float16).detach()
|
||||
clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
|
||||
if drop:
|
||||
clip_image = clip_image * 0
|
||||
with torch.set_grad_enabled(is_training):
|
||||
if is_training:
|
||||
self.image_encoder.train()
|
||||
clip_output = self.image_encoder(clip_image.requires_grad_(True)
|
||||
, output_hidden_states=True)
|
||||
else:
|
||||
self.image_encoder.eval()
|
||||
clip_output = self.image_encoder(clip_image, output_hidden_states=True)
|
||||
clip_output = self.image_encoder(clip_image, output_hidden_states=True)
|
||||
|
||||
clip_image_embeds = clip_output.hidden_states[-2]
|
||||
return clip_image_embeds
|
||||
|
||||
@@ -332,8 +384,16 @@ class IPAdapter(torch.nn.Module):
|
||||
yield from self.image_proj_model.parameters(recurse)
|
||||
if self.config.train_image_encoder:
|
||||
yield from self.image_encoder.parameters(recurse)
|
||||
# if self.config.train_image_encoder:
|
||||
# yield from self.image_encoder.parameters(recurse)
|
||||
# self.image_encoder.train()
|
||||
# else:
|
||||
# for attn_processor in self.adapter_modules:
|
||||
# yield from attn_processor.parameters(recurse)
|
||||
# yield from self.image_proj_model.parameters(recurse)
|
||||
|
||||
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
|
||||
strict = False
|
||||
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
|
||||
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
|
||||
if self.config.train_image_encoder and 'image_encoder' in state_dict:
|
||||
|
||||
@@ -39,10 +39,11 @@ class PromptEmbeds:
|
||||
return self
|
||||
|
||||
def detach(self):
|
||||
self.text_embeds = self.text_embeds.detach()
|
||||
if self.pooled_embeds is not None:
|
||||
self.pooled_embeds = self.pooled_embeds.detach()
|
||||
return self
|
||||
new_embeds = self.clone()
|
||||
new_embeds.text_embeds = new_embeds.text_embeds.detach()
|
||||
if new_embeds.pooled_embeds is not None:
|
||||
new_embeds.pooled_embeds = new_embeds.pooled_embeds.detach()
|
||||
return new_embeds
|
||||
|
||||
def clone(self):
|
||||
if self.pooled_embeds is not None:
|
||||
|
||||
Reference in New Issue
Block a user