mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Added initial direct vision pixtral support
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -175,3 +175,4 @@ cython_debug/
|
|||||||
!/extensions/example
|
!/extensions/example
|
||||||
/temp
|
/temp
|
||||||
/wandb
|
/wandb
|
||||||
|
.vscode/settings.json
|
||||||
@@ -206,6 +206,8 @@ class AdapterConfig:
|
|||||||
self.ilora_down: bool = kwargs.get('ilora_down', True)
|
self.ilora_down: bool = kwargs.get('ilora_down', True)
|
||||||
self.ilora_mid: bool = kwargs.get('ilora_mid', True)
|
self.ilora_mid: bool = kwargs.get('ilora_mid', True)
|
||||||
self.ilora_up: bool = kwargs.get('ilora_up', True)
|
self.ilora_up: bool = kwargs.get('ilora_up', True)
|
||||||
|
|
||||||
|
self.pixtral_max_image_size: int = kwargs.get('pixtral_max_image_size', 512)
|
||||||
|
|
||||||
self.flux_only_double: bool = kwargs.get('flux_only_double', False)
|
self.flux_only_double: bool = kwargs.get('flux_only_double', False)
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from toolkit.paths import REPOS_ROOT
|
|||||||
from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
|
from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
|
||||||
from toolkit.saving import load_ip_adapter_model, load_custom_adapter_model
|
from toolkit.saving import load_ip_adapter_model, load_custom_adapter_model
|
||||||
from toolkit.train_tools import get_torch_dtype
|
from toolkit.train_tools import get_torch_dtype
|
||||||
|
from toolkit.models.pixtral_vision import PixtralVisionEncoderCompatible, PixtralVisionImagePreprocessorCompatible
|
||||||
|
|
||||||
sys.path.append(REPOS_ROOT)
|
sys.path.append(REPOS_ROOT)
|
||||||
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
|
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
|
||||||
@@ -257,6 +258,13 @@ class CustomAdapter(torch.nn.Module):
|
|||||||
self.vision_encoder = SiglipVisionModel.from_pretrained(
|
self.vision_encoder = SiglipVisionModel.from_pretrained(
|
||||||
adapter_config.image_encoder_path,
|
adapter_config.image_encoder_path,
|
||||||
ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||||
|
elif self.config.image_encoder_arch == 'pixtral':
|
||||||
|
self.image_processor = PixtralVisionImagePreprocessorCompatible(
|
||||||
|
max_image_size=self.config.pixtral_max_image_size,
|
||||||
|
)
|
||||||
|
self.vision_encoder = PixtralVisionEncoderCompatible.from_pretrained(
|
||||||
|
adapter_config.image_encoder_path,
|
||||||
|
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||||
elif self.config.image_encoder_arch == 'vit':
|
elif self.config.image_encoder_arch == 'vit':
|
||||||
try:
|
try:
|
||||||
self.image_processor = ViTFeatureExtractor.from_pretrained(adapter_config.image_encoder_path)
|
self.image_processor = ViTFeatureExtractor.from_pretrained(adapter_config.image_encoder_path)
|
||||||
@@ -700,9 +708,11 @@ class CustomAdapter(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
return prompt_embeds
|
return prompt_embeds
|
||||||
|
|
||||||
def get_empty_clip_image(self, batch_size: int) -> torch.Tensor:
|
def get_empty_clip_image(self, batch_size: int, shape=None) -> torch.Tensor:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
tensors_0_1 = torch.rand([batch_size, 3, self.input_size, self.input_size], device=self.device)
|
if shape is None:
|
||||||
|
shape = [batch_size, 3, self.input_size, self.input_size]
|
||||||
|
tensors_0_1 = torch.rand(shape, device=self.device)
|
||||||
noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
|
noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
|
||||||
dtype=get_torch_dtype(self.sd_ref().dtype))
|
dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||||
tensors_0_1 = tensors_0_1 * noise_scale
|
tensors_0_1 = tensors_0_1 * noise_scale
|
||||||
@@ -761,7 +771,7 @@ class CustomAdapter(torch.nn.Module):
|
|||||||
batch_size = clip_image.shape[0]
|
batch_size = clip_image.shape[0]
|
||||||
if self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
|
if self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
|
||||||
# add an unconditional so we can save it
|
# add an unconditional so we can save it
|
||||||
unconditional = self.get_empty_clip_image(batch_size).to(
|
unconditional = self.get_empty_clip_image(batch_size, shape=clip_image.shape).to(
|
||||||
clip_image.device, dtype=clip_image.dtype
|
clip_image.device, dtype=clip_image.dtype
|
||||||
)
|
)
|
||||||
clip_image = torch.cat([unconditional, clip_image], dim=0)
|
clip_image = torch.cat([unconditional, clip_image], dim=0)
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
|||||||
from toolkit.basic import flush, value_map
|
from toolkit.basic import flush, value_map
|
||||||
from toolkit.buckets import get_bucket_for_image_size, get_resolution
|
from toolkit.buckets import get_bucket_for_image_size, get_resolution
|
||||||
from toolkit.metadata import get_meta_for_safetensors
|
from toolkit.metadata import get_meta_for_safetensors
|
||||||
|
from toolkit.models.pixtral_vision import PixtralVisionImagePreprocessorCompatible
|
||||||
from toolkit.prompt_utils import inject_trigger_into_prompt
|
from toolkit.prompt_utils import inject_trigger_into_prompt
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from PIL import Image, ImageFilter, ImageOps
|
from PIL import Image, ImageFilter, ImageOps
|
||||||
@@ -733,6 +734,7 @@ class ClipImageFileItemDTOMixin:
|
|||||||
return self._clip_vision_embeddings_path
|
return self._clip_vision_embeddings_path
|
||||||
|
|
||||||
def load_clip_image(self: 'FileItemDTO'):
|
def load_clip_image(self: 'FileItemDTO'):
|
||||||
|
is_dynamic_size_and_aspect = isinstance(self.clip_image_processor, PixtralVisionImagePreprocessorCompatible)
|
||||||
if self.is_vision_clip_cached:
|
if self.is_vision_clip_cached:
|
||||||
self.clip_image_embeds = load_file(self.get_clip_vision_embeddings_path())
|
self.clip_image_embeds = load_file(self.get_clip_vision_embeddings_path())
|
||||||
|
|
||||||
@@ -759,8 +761,24 @@ class ClipImageFileItemDTOMixin:
|
|||||||
if self.flip_y:
|
if self.flip_y:
|
||||||
# do a flip
|
# do a flip
|
||||||
img = img.transpose(Image.FLIP_TOP_BOTTOM)
|
img = img.transpose(Image.FLIP_TOP_BOTTOM)
|
||||||
|
|
||||||
if img.width != img.height:
|
if is_dynamic_size_and_aspect:
|
||||||
|
# just match the bucket size for now
|
||||||
|
if self.dataset_config.buckets:
|
||||||
|
# scale and crop based on file item
|
||||||
|
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
|
||||||
|
# img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
|
||||||
|
# crop
|
||||||
|
img = img.crop((
|
||||||
|
self.crop_x,
|
||||||
|
self.crop_y,
|
||||||
|
self.crop_x + self.crop_width,
|
||||||
|
self.crop_y + self.crop_height
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
raise Exception("Control images not supported for non-bucket datasets")
|
||||||
|
|
||||||
|
elif img.width != img.height:
|
||||||
min_size = min(img.width, img.height)
|
min_size = min(img.width, img.height)
|
||||||
if self.dataset_config.square_crop:
|
if self.dataset_config.square_crop:
|
||||||
# center crop to a square
|
# center crop to a square
|
||||||
|
|||||||
@@ -33,7 +33,8 @@ class FeedForward(nn.Module):
|
|||||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) # type: ignore
|
# type: ignore
|
||||||
|
return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
|
||||||
|
|
||||||
|
|
||||||
def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
@@ -103,15 +104,18 @@ class Attention(nn.Module):
|
|||||||
else:
|
else:
|
||||||
cache.update(xk, xv)
|
cache.update(xk, xv)
|
||||||
key, val = cache.key, cache.value
|
key, val = cache.key, cache.value
|
||||||
key = key.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim)
|
key = key.view(seqlen_sum * cache.max_seq_len,
|
||||||
val = val.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim)
|
self.n_kv_heads, self.head_dim)
|
||||||
|
val = val.view(seqlen_sum * cache.max_seq_len,
|
||||||
|
self.n_kv_heads, self.head_dim)
|
||||||
|
|
||||||
# Repeat keys and values to match number of query heads
|
# Repeat keys and values to match number of query heads
|
||||||
key, val = repeat_kv(key, val, self.repeats, dim=1)
|
key, val = repeat_kv(key, val, self.repeats, dim=1)
|
||||||
|
|
||||||
# xformers requires (B=1, S, H, D)
|
# xformers requires (B=1, S, H, D)
|
||||||
xq, key, val = xq[None, ...], key[None, ...], val[None, ...]
|
xq, key, val = xq[None, ...], key[None, ...], val[None, ...]
|
||||||
output = memory_efficient_attention(xq, key, val, mask if cache is None else cache.mask)
|
output = memory_efficient_attention(
|
||||||
|
xq, key, val, mask if cache is None else cache.mask)
|
||||||
output = output.view(seqlen_sum, self.n_heads * self.head_dim)
|
output = output.view(seqlen_sum, self.n_heads * self.head_dim)
|
||||||
|
|
||||||
assert isinstance(output, torch.Tensor)
|
assert isinstance(output, torch.Tensor)
|
||||||
@@ -260,8 +264,8 @@ class PixtralVisionEncoder(nn.Module):
|
|||||||
assert head_dim % 2 == 0, "ROPE requires even head_dim"
|
assert head_dim % 2 == 0, "ROPE requires even head_dim"
|
||||||
self._freqs_cis: Optional[torch.Tensor] = None
|
self._freqs_cis: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def from_pretrained(pretrained_model_name_or_path: str) -> 'PixtralVisionEncoder':
|
def from_pretrained(cls, pretrained_model_name_or_path: str) -> 'PixtralVisionEncoder':
|
||||||
if os.path.isdir(pretrained_model_name_or_path):
|
if os.path.isdir(pretrained_model_name_or_path):
|
||||||
model_folder = pretrained_model_name_or_path
|
model_folder = pretrained_model_name_or_path
|
||||||
else:
|
else:
|
||||||
@@ -275,11 +279,12 @@ class PixtralVisionEncoder(nn.Module):
|
|||||||
with open(os.path.join(model_folder, "config.json"), "r") as f:
|
with open(os.path.join(model_folder, "config.json"), "r") as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
|
|
||||||
model = PixtralVisionEncoder(**config)
|
model = cls(**config)
|
||||||
|
|
||||||
# see if there is a state_dict
|
# see if there is a state_dict
|
||||||
if os.path.exists(os.path.join(model_folder, "model.safetensors")):
|
if os.path.exists(os.path.join(model_folder, "model.safetensors")):
|
||||||
state_dict = load_file(os.path.join(model_folder, "model.safetensors"))
|
state_dict = load_file(os.path.join(
|
||||||
|
model_folder, "model.safetensors"))
|
||||||
model.load_state_dict(state_dict)
|
model.load_state_dict(state_dict)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
@@ -319,14 +324,17 @@ class PixtralVisionEncoder(nn.Module):
|
|||||||
image_features: tensor of token features for all tokens of all images of
|
image_features: tensor of token features for all tokens of all images of
|
||||||
shape (N_toks, D)
|
shape (N_toks, D)
|
||||||
"""
|
"""
|
||||||
assert isinstance(images, list), f"Expected list of images, got {type(images)}"
|
assert isinstance(
|
||||||
|
images, list), f"Expected list of images, got {type(images)}"
|
||||||
assert all(len(img.shape) == 3 for img in
|
assert all(len(img.shape) == 3 for img in
|
||||||
images), f"Expected images with shape (C, H, W), got {[img.shape for img in images]}"
|
images), f"Expected images with shape (C, H, W), got {[img.shape for img in images]}"
|
||||||
# pass images through initial convolution independently
|
# pass images through initial convolution independently
|
||||||
patch_embeds_list = [self.patch_conv(img.unsqueeze(0)).squeeze(0) for img in images]
|
patch_embeds_list = [self.patch_conv(
|
||||||
|
img.unsqueeze(0)).squeeze(0) for img in images]
|
||||||
|
|
||||||
# flatten to a single sequence
|
# flatten to a single sequence
|
||||||
patch_embeds = torch.cat([p.flatten(1).permute(1, 0) for p in patch_embeds_list], dim=0)
|
patch_embeds = torch.cat([p.flatten(1).permute(1, 0)
|
||||||
|
for p in patch_embeds_list], dim=0)
|
||||||
patch_embeds = self.ln_pre(patch_embeds)
|
patch_embeds = self.ln_pre(patch_embeds)
|
||||||
|
|
||||||
# positional embeddings
|
# positional embeddings
|
||||||
@@ -355,7 +363,8 @@ class VisionLanguageAdapter(nn.Module):
|
|||||||
self.w_out = nn.Linear(out_dim, out_dim, bias=True)
|
self.w_out = nn.Linear(out_dim, out_dim, bias=True)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
return self.w_out(self.gelu(self.w_in(x))) # type: ignore[no-any-return]
|
# type: ignore[no-any-return]
|
||||||
|
return self.w_out(self.gelu(self.w_in(x)))
|
||||||
|
|
||||||
|
|
||||||
class VisionTransformerBlocks(nn.Module):
|
class VisionTransformerBlocks(nn.Module):
|
||||||
@@ -401,7 +410,8 @@ def normalize(image: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> tor
|
|||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: Normalized image with shape (C, H, W).
|
torch.Tensor: Normalized image with shape (C, H, W).
|
||||||
"""
|
"""
|
||||||
assert image.shape[0] == len(mean) == len(std), f"{image.shape=}, {mean.shape=}, {std.shape=}"
|
assert image.shape[0] == len(mean) == len(
|
||||||
|
std), f"{image.shape=}, {mean.shape=}, {std.shape=}"
|
||||||
|
|
||||||
# Reshape mean and std to (C, 1, 1) for broadcasting
|
# Reshape mean and std to (C, 1, 1) for broadcasting
|
||||||
mean = mean.view(-1, 1, 1)
|
mean = mean.view(-1, 1, 1)
|
||||||
@@ -473,10 +483,12 @@ class PixtralVisionImagePreprocessor:
|
|||||||
"""
|
"""
|
||||||
# should not have batch
|
# should not have batch
|
||||||
if len(image.shape) == 4:
|
if len(image.shape) == 4:
|
||||||
raise ValueError(f"Expected image with shape (C, H, W), got {image.shape}")
|
raise ValueError(
|
||||||
|
f"Expected image with shape (C, H, W), got {image.shape}")
|
||||||
|
|
||||||
if image.min() < 0.0 or image.max() > 1.0:
|
if image.min() < 0.0 or image.max() > 1.0:
|
||||||
raise ValueError(f"image tensor values must be between 0 and 1. Got min: {image.min()}, max: {image.max()}")
|
raise ValueError(
|
||||||
|
f"image tensor values must be between 0 and 1. Got min: {image.min()}, max: {image.max()}")
|
||||||
|
|
||||||
w, h = self._image_to_num_tokens(image)
|
w, h = self._image_to_num_tokens(image)
|
||||||
assert w > 0
|
assert w > 0
|
||||||
@@ -490,3 +502,98 @@ class PixtralVisionImagePreprocessor:
|
|||||||
processed_image = transform_image(image, new_image_size)
|
processed_image = transform_image(image, new_image_size)
|
||||||
|
|
||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
|
class PixtralVisionImagePreprocessorCompatibleReturn:
|
||||||
|
def __init__(self, pixel_values) -> None:
|
||||||
|
self.pixel_values = pixel_values
|
||||||
|
|
||||||
|
|
||||||
|
# Compatable version with ai toolkit flow
|
||||||
|
class PixtralVisionImagePreprocessorCompatible(PixtralVisionImagePreprocessor):
|
||||||
|
def __init__(self, image_patch_size=16, max_image_size=1024) -> None:
|
||||||
|
super().__init__(
|
||||||
|
image_patch_size=image_patch_size,
|
||||||
|
max_image_size=max_image_size
|
||||||
|
)
|
||||||
|
self.size = {
|
||||||
|
'height': max_image_size,
|
||||||
|
'width': max_image_size
|
||||||
|
}
|
||||||
|
self.image_mean = DATASET_MEAN
|
||||||
|
self.image_std = DATASET_STD
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
images,
|
||||||
|
return_tensors="pt",
|
||||||
|
do_resize=True,
|
||||||
|
do_rescale=False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
out_stack = []
|
||||||
|
if len(images.shape) == 3:
|
||||||
|
images = images.unsqueeze(0)
|
||||||
|
for i in range(images.shape[0]):
|
||||||
|
image = images[i]
|
||||||
|
processed_image = super().__call__(image)
|
||||||
|
out_stack.append(processed_image)
|
||||||
|
|
||||||
|
output = torch.stack(out_stack, dim=0)
|
||||||
|
return PixtralVisionImagePreprocessorCompatibleReturn(output)
|
||||||
|
|
||||||
|
|
||||||
|
class PixtralVisionEncoderCompatibleReturn:
|
||||||
|
def __init__(self, hidden_states) -> None:
|
||||||
|
self.hidden_states = hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class PixtralVisionEncoderCompatibleConfig:
|
||||||
|
def __init__(self):
|
||||||
|
self.image_size = 1024
|
||||||
|
self.hidden_size = 1024
|
||||||
|
self.patch_size = 16
|
||||||
|
|
||||||
|
|
||||||
|
class PixtralVisionEncoderCompatible(PixtralVisionEncoder):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int = 1024,
|
||||||
|
num_channels: int = 3,
|
||||||
|
image_size: int = 1024,
|
||||||
|
patch_size: int = 16,
|
||||||
|
intermediate_size: int = 4096,
|
||||||
|
num_hidden_layers: int = 24,
|
||||||
|
num_attention_heads: int = 16,
|
||||||
|
rope_theta: float = 1e4, # for rope-2D
|
||||||
|
image_token_id: int = 10,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
num_channels=num_channels,
|
||||||
|
image_size=image_size,
|
||||||
|
patch_size=patch_size,
|
||||||
|
intermediate_size=intermediate_size,
|
||||||
|
num_hidden_layers=num_hidden_layers,
|
||||||
|
num_attention_heads=num_attention_heads,
|
||||||
|
rope_theta=rope_theta,
|
||||||
|
image_token_id=image_token_id,
|
||||||
|
)
|
||||||
|
self.config = PixtralVisionEncoderCompatibleConfig()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
images,
|
||||||
|
output_hidden_states=True,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
out_stack = []
|
||||||
|
if len(images.shape) == 3:
|
||||||
|
images = images.unsqueeze(0)
|
||||||
|
for i in range(images.shape[0]):
|
||||||
|
image = images[i]
|
||||||
|
# must be in an array
|
||||||
|
image_output = super().forward([image])
|
||||||
|
out_stack.append(image_output)
|
||||||
|
|
||||||
|
output = torch.stack(out_stack, dim=0)
|
||||||
|
return PixtralVisionEncoderCompatibleReturn([output])
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from collections import OrderedDict
|
|||||||
|
|
||||||
from diffusers import Transformer2DModel, FluxTransformer2DModel
|
from diffusers import Transformer2DModel, FluxTransformer2DModel
|
||||||
from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPVisionModelWithProjection
|
from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPVisionModelWithProjection
|
||||||
|
from toolkit.models.pixtral_vision import PixtralVisionEncoder, PixtralVisionImagePreprocessor
|
||||||
|
|
||||||
from toolkit.config_modules import AdapterConfig
|
from toolkit.config_modules import AdapterConfig
|
||||||
from toolkit.paths import REPOS_ROOT
|
from toolkit.paths import REPOS_ROOT
|
||||||
|
|||||||
@@ -1461,6 +1461,7 @@ class StableDiffusion:
|
|||||||
detach_unconditional=False,
|
detach_unconditional=False,
|
||||||
rescale_cfg=None,
|
rescale_cfg=None,
|
||||||
return_conditional_pred=False,
|
return_conditional_pred=False,
|
||||||
|
guidance_embedding_scale=1.0,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
conditional_pred = None
|
conditional_pred = None
|
||||||
@@ -1736,10 +1737,12 @@ class StableDiffusion:
|
|||||||
txt_ids = torch.zeros(bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch)
|
txt_ids = torch.zeros(bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch)
|
||||||
|
|
||||||
# # handle guidance
|
# # handle guidance
|
||||||
guidance_scale = 1.0 # ?
|
|
||||||
if self.unet.config.guidance_embeds:
|
if self.unet.config.guidance_embeds:
|
||||||
guidance = torch.tensor([guidance_scale], device=self.device_torch)
|
if isinstance(guidance_scale, list):
|
||||||
guidance = guidance.expand(latents.shape[0])
|
guidance = torch.tensor(guidance_scale, device=self.device_torch)
|
||||||
|
else:
|
||||||
|
guidance = torch.tensor([guidance_scale], device=self.device_torch)
|
||||||
|
guidance = guidance.expand(latents.shape[0])
|
||||||
else:
|
else:
|
||||||
guidance = None
|
guidance = None
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user