mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-01 00:33:56 +00:00
Handle inpainting training for control_lora adapter
This commit is contained in:
@@ -1670,11 +1670,12 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
)
|
||||
|
||||
else:
|
||||
with self.timer('predict_unet'):
|
||||
if unconditional_embeds is not None:
|
||||
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
|
||||
if self.adapter and isinstance(self.adapter, CustomAdapter):
|
||||
if unconditional_embeds is not None:
|
||||
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
|
||||
if self.adapter and isinstance(self.adapter, CustomAdapter):
|
||||
with self.timer('condition_noisy_latents'):
|
||||
noisy_latents = self.adapter.condition_noisy_latents(noisy_latents, batch)
|
||||
with self.timer('predict_unet'):
|
||||
noise_pred = self.predict_noise(
|
||||
noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
||||
timesteps=timesteps,
|
||||
|
||||
@@ -252,9 +252,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
test_image_paths = []
|
||||
if self.adapter_config is not None and self.adapter_config.test_img_path is not None:
|
||||
test_image_path_list = self.adapter_config.test_img_path.split(',')
|
||||
test_image_path_list = [p.strip() for p in test_image_path_list]
|
||||
test_image_path_list = [p for p in test_image_path_list if p != '']
|
||||
test_image_path_list = self.adapter_config.test_img_path
|
||||
# divide up images so they are evenly distributed across prompts
|
||||
for i in range(len(sample_config.prompts)):
|
||||
test_image_paths.append(test_image_path_list[i % len(test_image_path_list)])
|
||||
|
||||
@@ -165,7 +165,13 @@ class AdapterConfig:
|
||||
self.downscale_factor: int = kwargs.get('downscale_factor', 8)
|
||||
self.adapter_type: str = kwargs.get('adapter_type', 'full_adapter')
|
||||
self.image_dir: str = kwargs.get('image_dir', None)
|
||||
self.test_img_path: str = kwargs.get('test_img_path', None)
|
||||
self.test_img_path: List[str] = kwargs.get('test_img_path', None)
|
||||
if self.test_img_path is not None:
|
||||
if isinstance(self.test_img_path, str):
|
||||
self.test_img_path = self.test_img_path.split(',')
|
||||
self.test_img_path = [p.strip() for p in self.test_img_path]
|
||||
self.test_img_path = [p for p in self.test_img_path if p != '']
|
||||
|
||||
self.train: str = kwargs.get('train', False)
|
||||
self.image_encoder_path: str = kwargs.get('image_encoder_path', None)
|
||||
self.name_or_path = kwargs.get('name_or_path', None)
|
||||
@@ -244,6 +250,7 @@ class AdapterConfig:
|
||||
self.num_control_images: int = kwargs.get('num_control_images', 1)
|
||||
# decimal for how often the control is dropped out and replaced with noise 1.0 is 100%
|
||||
self.control_image_dropout: float = kwargs.get('control_image_dropout', 0.0)
|
||||
self.has_inpainting_input: bool = kwargs.get('has_inpainting_input', False)
|
||||
|
||||
|
||||
class EmbeddingConfig:
|
||||
@@ -714,6 +721,9 @@ class DatasetConfig:
|
||||
self.flip_y: bool = kwargs.get('flip_y', False)
|
||||
self.augments: List[str] = kwargs.get('augments', [])
|
||||
self.control_path: Union[str,List[str]] = kwargs.get('control_path', None) # depth maps, etc
|
||||
# inpaint images should be webp/png images with alpha channel. The alpha 0 (invisible) section will
|
||||
# be the part conditioned to be inpainted. The alpha 1 (visible) section will be the part that is ignored
|
||||
self.inpaint_path: Union[str,List[str]] = kwargs.get('inpaint_path', None)
|
||||
# instead of cropping ot match image, it will serve the full size control image (clip images ie for ip adapters)
|
||||
self.full_size_control_images: bool = kwargs.get('full_size_control_images', False)
|
||||
self.alpha_mask: bool = kwargs.get('alpha_mask', False) # if true, will use alpha channel as mask
|
||||
|
||||
@@ -569,13 +569,56 @@ class CustomAdapter(torch.nn.Module):
|
||||
def condition_noisy_latents(self, latents: torch.Tensor, batch:DataLoaderBatchDTO):
|
||||
with torch.no_grad():
|
||||
if self.adapter_type in ['control_lora']:
|
||||
# inpainting input is 0-1 (bs, 4, h, w) on batch.inpaint_tensor
|
||||
# 4th channel is the mask with 1 being keep area and 0 being area to inpaint.
|
||||
sd: StableDiffusion = self.sd_ref()
|
||||
control_tensor = batch.control_tensor
|
||||
inpainting_latent = None
|
||||
if self.config.has_inpainting_input:
|
||||
do_dropout = random.random() < self.config.control_image_dropout
|
||||
if batch.inpaint_tensor is not None and not do_dropout:
|
||||
# currently 0-1, we need rgb to be -1 to 1 before encoding with the vae
|
||||
inpainting_tensor_rgba = batch.inpaint_tensor.to(latents.device, dtype=latents.dtype)
|
||||
inpainting_tensor_mask = inpainting_tensor_rgba[:, 3:4, :, :]
|
||||
inpainting_tensor_rgb = inpainting_tensor_rgba[:, :3, :, :]
|
||||
# we need to make sure the inpaint area is black multiply the rgb channels by the mask
|
||||
inpainting_tensor_rgb = inpainting_tensor_rgb * inpainting_tensor_mask
|
||||
|
||||
# if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it
|
||||
if inpainting_tensor_rgb.shape[2] != batch.tensor.shape[2] or inpainting_tensor_rgb.shape[3] != batch.tensor.shape[3]:
|
||||
inpainting_tensor_rgb = F.interpolate(inpainting_tensor_rgb, size=(batch.tensor.shape[2], batch.tensor.shape[3]), mode='bilinear')
|
||||
|
||||
# scale to -1 to 1
|
||||
inpainting_tensor_rgb = inpainting_tensor_rgb * 2 - 1
|
||||
|
||||
# encode it
|
||||
inpainting_latent = sd.encode_images(inpainting_tensor_rgb).to(latents.device, latents.dtype)
|
||||
|
||||
# resize the mask to match the new encoded size
|
||||
inpainting_tensor_mask = F.interpolate(inpainting_tensor_mask, size=(inpainting_latent.shape[2], inpainting_latent.shape[3]), mode='bilinear')
|
||||
inpainting_tensor_mask = inpainting_tensor_mask.to(latents.device, latents.dtype)
|
||||
# mask needs to be 1 for inpaint area and 0 for area to leave alone. So flip it.
|
||||
inpainting_tensor_mask = 1 - inpainting_tensor_mask
|
||||
# leave the mask as 0-1 and concat on channel of latents
|
||||
inpainting_latent = torch.cat((inpainting_latent, inpainting_tensor_mask), dim=1)
|
||||
else:
|
||||
# we have iinpainting but didnt get a control. or we are doing a dropout
|
||||
# the input needs to be all zeros for the latents and all 1s for the mask
|
||||
inpainting_latent = torch.zeros_like(latents)
|
||||
# add ones for the mask since we are technically inpainting everything
|
||||
inpainting_latent = torch.cat((inpainting_latent, torch.ones_like(inpainting_latent[:, :1, :, :])), dim=1)
|
||||
|
||||
if self.config.num_control_images == 1:
|
||||
# this is our only control
|
||||
control_latent = inpainting_latent.to(latents.device, latents.dtype)
|
||||
latents = torch.cat((latents, control_latent), dim=1)
|
||||
return latents.detach()
|
||||
|
||||
control_tensor = batch.control_tensor.to(latents.device, dtype=latents.dtype)
|
||||
if control_tensor is None:
|
||||
# concat random normal noise onto the latents
|
||||
# check dimension, this is before they are rearranged
|
||||
# it is latent_model_input = torch.cat([latents, control_image], dim=2) after rearranging
|
||||
ctrl = torch.randn(
|
||||
ctrl = torch.zeros(
|
||||
latents.shape[0], # bs
|
||||
latents.shape[1] * self.num_control_images, # ch
|
||||
latents.shape[2],
|
||||
@@ -583,6 +626,9 @@ class CustomAdapter(torch.nn.Module):
|
||||
device=latents.device,
|
||||
dtype=latents.dtype
|
||||
)
|
||||
if inpainting_latent is not None:
|
||||
# inpainting always comes first
|
||||
ctrl = torch.cat((inpainting_latent, ctrl), dim=1)
|
||||
latents = torch.cat((latents, ctrl), dim=1)
|
||||
return latents.detach()
|
||||
# if we have multiple control tensors, they come in like [bs, num_control_images, ch, h, w]
|
||||
@@ -622,6 +668,9 @@ class CustomAdapter(torch.nn.Module):
|
||||
control_latent_list.append(control_latent)
|
||||
# stack them on the channel dimension
|
||||
control_latent = torch.cat(control_latent_list, dim=1)
|
||||
if inpainting_latent is not None:
|
||||
# inpainting always comes first
|
||||
control_latent = torch.cat((inpainting_latent, control_latent), dim=1)
|
||||
# concat it onto the latents
|
||||
latents = torch.cat((latents, control_latent), dim=1)
|
||||
return latents.detach()
|
||||
|
||||
@@ -12,7 +12,7 @@ from PIL.ImageOps import exif_transpose
|
||||
from toolkit import image_utils
|
||||
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
|
||||
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \
|
||||
UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin
|
||||
UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin, InpaintControlFileItemDTOMixin
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -34,6 +34,7 @@ class FileItemDTO(
|
||||
CaptionProcessingDTOMixin,
|
||||
ImageProcessingDTOMixin,
|
||||
ControlFileItemDTOMixin,
|
||||
InpaintControlFileItemDTOMixin,
|
||||
ClipImageFileItemDTOMixin,
|
||||
MaskFileItemDTOMixin,
|
||||
AugmentationFileItemDTOMixin,
|
||||
@@ -108,6 +109,7 @@ class FileItemDTO(
|
||||
self.tensor = None
|
||||
self.cleanup_latent()
|
||||
self.cleanup_control()
|
||||
self.cleanup_inpaint()
|
||||
self.cleanup_clip_image()
|
||||
self.cleanup_mask()
|
||||
self.cleanup_unconditional()
|
||||
@@ -154,6 +156,22 @@ class DataLoaderBatchDTO:
|
||||
else:
|
||||
control_tensors.append(x.control_tensor)
|
||||
self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])
|
||||
|
||||
self.inpaint_tensor: Union[torch.Tensor, None] = None
|
||||
if any([x.inpaint_tensor is not None for x in self.file_items]):
|
||||
# find one to use as a base
|
||||
base_inpaint_tensor = None
|
||||
for x in self.file_items:
|
||||
if x.inpaint_tensor is not None:
|
||||
base_inpaint_tensor = x.inpaint_tensor
|
||||
break
|
||||
inpaint_tensors = []
|
||||
for x in self.file_items:
|
||||
if x.inpaint_tensor is None:
|
||||
inpaint_tensors.append(torch.zeros_like(base_inpaint_tensor))
|
||||
else:
|
||||
inpaint_tensors.append(x.inpaint_tensor)
|
||||
self.inpaint_tensor = torch.cat([x.unsqueeze(0) for x in inpaint_tensors])
|
||||
|
||||
self.loss_multiplier_list: List[float] = [x.loss_multiplier for x in self.file_items]
|
||||
|
||||
|
||||
@@ -635,6 +635,8 @@ class ImageProcessingDTOMixin:
|
||||
self.get_latent()
|
||||
if self.has_control_image:
|
||||
self.load_control_image()
|
||||
if self.has_inpaint_image:
|
||||
self.load_inpaint_image()
|
||||
if self.has_clip_image:
|
||||
self.load_clip_image()
|
||||
if self.has_mask_image:
|
||||
@@ -730,6 +732,8 @@ class ImageProcessingDTOMixin:
|
||||
if not only_load_latents:
|
||||
if self.has_control_image:
|
||||
self.load_control_image()
|
||||
if self.has_inpaint_image:
|
||||
self.load_inpaint_image()
|
||||
if self.has_clip_image:
|
||||
self.load_clip_image()
|
||||
if self.has_mask_image:
|
||||
@@ -738,6 +742,89 @@ class ImageProcessingDTOMixin:
|
||||
self.load_unconditional_image()
|
||||
|
||||
|
||||
class InpaintControlFileItemDTOMixin:
|
||||
def __init__(self: 'FileItemDTO', *args, **kwargs):
|
||||
if hasattr(super(), '__init__'):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.has_inpaint_image = False
|
||||
self.inpaint_path: Union[str, None] = None
|
||||
self.inpaint_tensor: Union[torch.Tensor, None] = None
|
||||
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
|
||||
if dataset_config.inpaint_path is not None:
|
||||
# find the control image path
|
||||
inpaint_path = dataset_config.inpaint_path
|
||||
# we are using control images
|
||||
img_path = kwargs.get('path', None)
|
||||
img_ext_list = ['.png', '.webp']
|
||||
file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
|
||||
|
||||
for ext in img_ext_list:
|
||||
p = os.path.join(inpaint_path, file_name_no_ext + ext)
|
||||
if os.path.exists(p):
|
||||
self.inpaint_path = p
|
||||
self.has_inpaint_image = True
|
||||
break
|
||||
|
||||
def load_inpaint_image(self: 'FileItemDTO'):
|
||||
try:
|
||||
# image must have alpha channel for inpaint
|
||||
img = Image.open(self.inpaint_path)
|
||||
# make sure has aplha
|
||||
if img.mode != 'RGBA':
|
||||
raise ValueError(f"Image must have alpha channel for inpaint: {self.inpaint_path}")
|
||||
img = exif_transpose(img)
|
||||
|
||||
w, h = img.size
|
||||
if w > h and self.scale_to_width < self.scale_to_height:
|
||||
# throw error, they should match
|
||||
raise ValueError(
|
||||
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
|
||||
elif h > w and self.scale_to_height < self.scale_to_width:
|
||||
# throw error, they should match
|
||||
raise ValueError(
|
||||
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
|
||||
|
||||
if self.flip_x:
|
||||
# do a flip
|
||||
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
if self.flip_y:
|
||||
# do a flip
|
||||
img = img.transpose(Image.FLIP_TOP_BOTTOM)
|
||||
|
||||
if self.dataset_config.buckets:
|
||||
# scale and crop based on file item
|
||||
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
|
||||
# img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
|
||||
# crop
|
||||
img = img.crop((
|
||||
self.crop_x,
|
||||
self.crop_y,
|
||||
self.crop_x + self.crop_width,
|
||||
self.crop_y + self.crop_height
|
||||
))
|
||||
else:
|
||||
raise Exception("Inpaint images not supported for non-bucket datasets")
|
||||
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
if self.aug_replay_spatial_transforms:
|
||||
tensor = self.augment_spatial_control(img, transform=transform)
|
||||
else:
|
||||
tensor = transform(img)
|
||||
|
||||
# is 0 to 1 with alpha
|
||||
self.inpaint_tensor = tensor
|
||||
|
||||
except Exception as e:
|
||||
print_acc(f"Error: {e}")
|
||||
print_acc(f"Error loading image: {self.inpaint_path}")
|
||||
|
||||
|
||||
def cleanup_inpaint(self: 'FileItemDTO'):
|
||||
self.inpaint_tensor = None
|
||||
|
||||
|
||||
class ControlFileItemDTOMixin:
|
||||
def __init__(self: 'FileItemDTO', *args, **kwargs):
|
||||
if hasattr(super(), '__init__'):
|
||||
@@ -786,7 +873,7 @@ class ControlFileItemDTOMixin:
|
||||
print_acc(f"Error: {e}")
|
||||
print_acc(f"Error loading image: {control_path}")
|
||||
|
||||
if self.full_size_control_images:
|
||||
if not self.full_size_control_images:
|
||||
# we just scale them to 512x512:
|
||||
w, h = img.size
|
||||
img = img.resize((512, 512), Image.BICUBIC)
|
||||
|
||||
@@ -46,15 +46,24 @@ class ImgEmbedder(torch.nn.Module):
|
||||
cls,
|
||||
model: FluxTransformer2DModel,
|
||||
adapter: 'ControlLoraAdapter',
|
||||
num_control_images=1
|
||||
num_control_images=1,
|
||||
has_inpainting_input=False
|
||||
):
|
||||
if model.__class__.__name__ == 'FluxTransformer2DModel':
|
||||
if model.__class__.__name__ == 'FluxTransformer2DModel':
|
||||
num_adapter_in_channels = model.x_embedder.in_features * num_control_images
|
||||
|
||||
if has_inpainting_input:
|
||||
# inpainting has the mask before packing latents. it is normally 16 ch + 1ch mask
|
||||
# packed it is 64ch + 4ch mask
|
||||
# so we need to add 4 to the input channels
|
||||
num_adapter_in_channels += 4
|
||||
|
||||
x_embedder: torch.nn.Linear = model.x_embedder
|
||||
img_embedder = cls(
|
||||
adapter,
|
||||
orig_layer=x_embedder,
|
||||
in_channels=x_embedder.in_features * num_control_images,
|
||||
out_channels=x_embedder.out_features,
|
||||
in_channels=num_adapter_in_channels,
|
||||
out_channels=x_embedder.out_features,
|
||||
)
|
||||
|
||||
# hijack the forward method
|
||||
@@ -181,7 +190,8 @@ class ControlLoraAdapter(torch.nn.Module):
|
||||
self.x_embedder = ImgEmbedder.from_model(
|
||||
sd.unet,
|
||||
self,
|
||||
num_control_images=config.num_control_images
|
||||
num_control_images=config.num_control_images,
|
||||
has_inpainting_input=config.has_inpainting_input
|
||||
)
|
||||
self.x_embedder.to(self.device_torch)
|
||||
|
||||
|
||||
@@ -16,6 +16,10 @@ from k_diffusion.external import CompVisVDenoiser, CompVisDenoiser
|
||||
from k_diffusion.sampling import get_sigmas_karras, BrownianTreeNoiseSampler
|
||||
from toolkit.models.flux import bypass_flux_guidance, restore_flux_guidance
|
||||
from diffusers.image_processor import PipelineImageInput
|
||||
from PIL import Image
|
||||
import torch.nn.functional as F
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@@ -1428,6 +1432,22 @@ class FluxWithCFGPipeline(FluxPipeline):
|
||||
|
||||
|
||||
class FluxAdvancedControlPipeline(FluxControlPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
scheduler,
|
||||
vae,
|
||||
text_encoder,
|
||||
tokenizer,
|
||||
text_encoder_2,
|
||||
tokenizer_2,
|
||||
transformer,
|
||||
do_inpainting=False,
|
||||
num_controls=1,
|
||||
):
|
||||
self.do_inpainting = do_inpainting
|
||||
self.num_controls = num_controls
|
||||
super().__init__(scheduler, vae, text_encoder, tokenizer, text_encoder_2, tokenizer_2, transformer)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
@@ -1581,6 +1601,17 @@ class FluxAdvancedControlPipeline(FluxControlPipeline):
|
||||
# 4. Prepare latent variables
|
||||
# num_channels_latents = self.transformer.config.in_channels // 8
|
||||
num_channels_latents = 128 // 8
|
||||
|
||||
# pull mask off control image if there is one it is a pil image
|
||||
mask = None
|
||||
if control_image is not None and self.do_inpainting and control_image.mode == "RGBA":
|
||||
control_img_array = np.array(control_image)
|
||||
mask = control_img_array[:, :, 3:4]
|
||||
# scale it to 0 - 1
|
||||
mask = mask / 255.0
|
||||
# multiply rgb by mask
|
||||
control_img_array = control_img_array[:, :, :3] * mask
|
||||
control_image = Image.fromarray(control_img_array.astype(np.uint8))
|
||||
|
||||
control_image = self.prepare_image(
|
||||
image=control_image,
|
||||
@@ -1593,14 +1624,28 @@ class FluxAdvancedControlPipeline(FluxControlPipeline):
|
||||
)
|
||||
|
||||
if control_image.ndim == 4:
|
||||
num_control_channels = num_channels_latents
|
||||
control_image = self.vae.encode(control_image).latent_dist.sample(generator=generator)
|
||||
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
|
||||
if mask is not None:
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
mask = transform(mask).to(device, dtype=control_image.dtype).unsqueeze(0)
|
||||
# resize mask to match control image
|
||||
mask = F.interpolate(mask, size=(control_image.shape[2], control_image.shape[3]), mode="bilinear", align_corners=False)
|
||||
mask = mask.to(device)
|
||||
# invert mask
|
||||
mask = 1 - mask
|
||||
control_image = torch.cat([control_image, mask], dim=1)
|
||||
num_control_channels += 1
|
||||
|
||||
height_control_image, width_control_image = control_image.shape[2:]
|
||||
control_image = self._pack_latents(
|
||||
control_image,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
num_control_channels,
|
||||
height_control_image,
|
||||
width_control_image,
|
||||
)
|
||||
@@ -1642,9 +1687,6 @@ class FluxAdvancedControlPipeline(FluxControlPipeline):
|
||||
guidance = guidance.expand(latents.shape[0])
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
# flux has 64 input channels.
|
||||
total_controls = (self.transformer.config.in_channels // 64) - 1
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
@@ -1652,7 +1694,16 @@ class FluxAdvancedControlPipeline(FluxControlPipeline):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
control_image_list = [torch.zeros_like(latents) for _ in range(total_controls)]
|
||||
control_image_list = []
|
||||
for idx in range(self.num_controls):
|
||||
if idx == 0 and self.do_inpainting:
|
||||
ctrl = torch.zeros_like(latents)
|
||||
# do ones for mask and zeros for image
|
||||
ctrl = torch.cat([ctrl, torch.ones_like(ctrl[:, :, :4])], dim=2)
|
||||
control_image_list.append(ctrl)
|
||||
else:
|
||||
control_image_list.append(torch.zeros_like(latents))
|
||||
|
||||
control_image_list[control_image_idx] = control_image
|
||||
|
||||
latent_model_input = torch.cat([latents] + control_image_list, dim=2)
|
||||
|
||||
@@ -1246,6 +1246,8 @@ class StableDiffusion:
|
||||
# see if it is a control lora
|
||||
if self.adapter.control_lora is not None:
|
||||
Pipe = FluxAdvancedControlPipeline
|
||||
extra_args['do_inpainting'] = self.adapter.config.has_inpainting_input
|
||||
extra_args['num_controls'] = self.adapter.config.num_control_images
|
||||
|
||||
pipeline = Pipe(
|
||||
vae=self.vae,
|
||||
@@ -1257,6 +1259,7 @@ class StableDiffusion:
|
||||
scheduler=noise_scheduler,
|
||||
**extra_args
|
||||
)
|
||||
|
||||
pipeline.watermark = None
|
||||
elif self.is_lumina2:
|
||||
pipeline = Lumina2Text2ImgPipeline(
|
||||
@@ -1355,7 +1358,14 @@ class StableDiffusion:
|
||||
extra = {}
|
||||
validation_image = None
|
||||
if self.adapter is not None and gen_config.adapter_image_path is not None:
|
||||
validation_image = Image.open(gen_config.adapter_image_path).convert("RGB")
|
||||
validation_image = Image.open(gen_config.adapter_image_path)
|
||||
# if the name doesnt have .inpainting. in it, make sure it is rgb
|
||||
if ".inpaint." not in gen_config.adapter_image_path:
|
||||
validation_image = validation_image.convert("RGB")
|
||||
else:
|
||||
# make sure it has an alpha
|
||||
if validation_image.mode != "RGBA":
|
||||
raise ValueError("Inpainting images must have an alpha channel")
|
||||
if isinstance(self.adapter, T2IAdapter):
|
||||
# not sure why this is double??
|
||||
validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2))
|
||||
|
||||
Reference in New Issue
Block a user