mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Added base for using guidance during training. Still not working right.
This commit is contained in:
@@ -1,6 +1,8 @@
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Union
|
from typing import Union
|
||||||
from diffusers import T2IAdapter
|
from diffusers import T2IAdapter
|
||||||
|
|
||||||
|
from toolkit import train_tools
|
||||||
from toolkit.basic import value_map
|
from toolkit.basic import value_map
|
||||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||||
from toolkit.ip_adapter import IPAdapter
|
from toolkit.ip_adapter import IPAdapter
|
||||||
@@ -30,6 +32,7 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
super().__init__(process_id, job, config, **kwargs)
|
super().__init__(process_id, job, config, **kwargs)
|
||||||
self.assistant_adapter: Union['T2IAdapter', None]
|
self.assistant_adapter: Union['T2IAdapter', None]
|
||||||
self.do_prior_prediction = False
|
self.do_prior_prediction = False
|
||||||
|
self.target_class = self.get_conf('target_class', '')
|
||||||
if self.train_config.inverted_mask_prior:
|
if self.train_config.inverted_mask_prior:
|
||||||
self.do_prior_prediction = True
|
self.do_prior_prediction = True
|
||||||
|
|
||||||
@@ -171,6 +174,99 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
|
def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
def get_guided_loss(
|
||||||
|
self,
|
||||||
|
noisy_latents: torch.Tensor,
|
||||||
|
conditional_embeds: PromptEmbeds,
|
||||||
|
match_adapter_assist: bool,
|
||||||
|
network_weight_list: list,
|
||||||
|
timesteps: torch.Tensor,
|
||||||
|
pred_kwargs: dict,
|
||||||
|
batch: 'DataLoaderBatchDTO',
|
||||||
|
noise: torch.Tensor,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
with torch.no_grad():
|
||||||
|
dtype = get_torch_dtype(self.train_config.dtype)
|
||||||
|
# target class is unconditional
|
||||||
|
target_class_embeds = self.sd.encode_prompt(self.target_class).detach()
|
||||||
|
|
||||||
|
if batch.unconditional_latents is not None:
|
||||||
|
# do the unconditional prediction here instead of a prior prediction
|
||||||
|
unconditional_noisy_latents = self.sd.noise_scheduler.add_noise(batch.unconditional_latents, noise,
|
||||||
|
timesteps)
|
||||||
|
|
||||||
|
was_network_active = self.network.is_active
|
||||||
|
self.network.is_active = False
|
||||||
|
self.sd.unet.eval()
|
||||||
|
|
||||||
|
guidance_scale = 1.0
|
||||||
|
|
||||||
|
def cfg(uncon, con):
|
||||||
|
return uncon + guidance_scale * (
|
||||||
|
con - uncon
|
||||||
|
)
|
||||||
|
|
||||||
|
target_conditional = self.sd.predict_noise(
|
||||||
|
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||||
|
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
|
||||||
|
timestep=timesteps,
|
||||||
|
guidance_scale=1.0,
|
||||||
|
**pred_kwargs # adapter residuals in here
|
||||||
|
).detach()
|
||||||
|
|
||||||
|
target_unconditional = self.sd.predict_noise(
|
||||||
|
latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||||
|
conditional_embeddings=target_class_embeds.to(self.device_torch, dtype=dtype).detach(),
|
||||||
|
timestep=timesteps,
|
||||||
|
guidance_scale=1.0,
|
||||||
|
**pred_kwargs # adapter residuals in here
|
||||||
|
).detach()
|
||||||
|
|
||||||
|
neutral_latents = (noisy_latents + unconditional_noisy_latents) / 2.0
|
||||||
|
|
||||||
|
target_noise = cfg(target_unconditional, target_conditional)
|
||||||
|
# latents = self.noise_scheduler.step(target_noise, timesteps, noisy_latents, return_dict=False)[0]
|
||||||
|
|
||||||
|
# target_pred = target_pred - noisy_latents + (unconditional_noisy_latents - noise)
|
||||||
|
|
||||||
|
# target_noise_res = noisy_latents - unconditional_noisy_latents
|
||||||
|
|
||||||
|
# target_pred = cfg(unconditional_noisy_latents, target_pred)
|
||||||
|
# target_pred = target_pred + target_noise_res
|
||||||
|
|
||||||
|
self.network.is_active = True
|
||||||
|
self.sd.unet.train()
|
||||||
|
|
||||||
|
prediction = self.sd.predict_noise(
|
||||||
|
latents=neutral_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||||
|
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
|
||||||
|
timestep=timesteps,
|
||||||
|
guidance_scale=1.0,
|
||||||
|
**pred_kwargs # adapter residuals in here
|
||||||
|
)
|
||||||
|
|
||||||
|
# prediction_res = target_pred - prediction
|
||||||
|
|
||||||
|
|
||||||
|
# prediction = cfg(prediction, target_pred)
|
||||||
|
|
||||||
|
loss = torch.nn.functional.mse_loss(prediction.float(), target_noise.float(), reduction="none")
|
||||||
|
loss = loss.mean([1, 2, 3])
|
||||||
|
|
||||||
|
if self.train_config.learnable_snr_gos:
|
||||||
|
# add snr_gamma
|
||||||
|
loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
|
||||||
|
elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001:
|
||||||
|
# add snr_gamma
|
||||||
|
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
|
||||||
|
elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
||||||
|
# add min_snr_gamma
|
||||||
|
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
||||||
|
|
||||||
|
loss = loss.mean()
|
||||||
|
return loss
|
||||||
|
|
||||||
def get_prior_prediction(
|
def get_prior_prediction(
|
||||||
self,
|
self,
|
||||||
noisy_latents: torch.Tensor,
|
noisy_latents: torch.Tensor,
|
||||||
@@ -369,8 +465,6 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
else:
|
else:
|
||||||
prompt_2_list = [prompts_2]
|
prompt_2_list = [prompts_2]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier, prompt_2 in zip(
|
for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier, prompt_2 in zip(
|
||||||
noisy_latents_list,
|
noisy_latents_list,
|
||||||
noise_list,
|
noise_list,
|
||||||
@@ -386,8 +480,9 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
with self.timer('encode_prompt'):
|
with self.timer('encode_prompt'):
|
||||||
if grad_on_text_encoder:
|
if grad_on_text_encoder:
|
||||||
with torch.set_grad_enabled(True):
|
with torch.set_grad_enabled(True):
|
||||||
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=True).to(
|
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2,
|
||||||
# conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
|
long_prompts=True).to(
|
||||||
|
# conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
|
||||||
self.device_torch,
|
self.device_torch,
|
||||||
dtype=dtype)
|
dtype=dtype)
|
||||||
else:
|
else:
|
||||||
@@ -398,8 +493,9 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
te.eval()
|
te.eval()
|
||||||
else:
|
else:
|
||||||
self.sd.text_encoder.eval()
|
self.sd.text_encoder.eval()
|
||||||
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=True).to(
|
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2,
|
||||||
# conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
|
long_prompts=True).to(
|
||||||
|
# conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
|
||||||
self.device_torch,
|
self.device_torch,
|
||||||
dtype=dtype)
|
dtype=dtype)
|
||||||
|
|
||||||
@@ -450,27 +546,42 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
|
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
|
||||||
|
|
||||||
self.before_unet_predict()
|
self.before_unet_predict()
|
||||||
with self.timer('predict_unet'):
|
# do a prior pred if we have an unconditional image, we will swap out the giadance later
|
||||||
noise_pred = self.sd.predict_noise(
|
if batch.unconditional_latents is not None:
|
||||||
latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
# do guided loss
|
||||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
|
loss = self.get_guided_loss(
|
||||||
timestep=timesteps,
|
|
||||||
guidance_scale=1.0,
|
|
||||||
**pred_kwargs
|
|
||||||
)
|
|
||||||
self.after_unet_predict()
|
|
||||||
|
|
||||||
with self.timer('calculate_loss'):
|
|
||||||
noise = noise.to(self.device_torch, dtype=dtype).detach()
|
|
||||||
loss = self.calculate_loss(
|
|
||||||
noise_pred=noise_pred,
|
|
||||||
noise=noise,
|
|
||||||
noisy_latents=noisy_latents,
|
noisy_latents=noisy_latents,
|
||||||
|
conditional_embeds=conditional_embeds,
|
||||||
|
match_adapter_assist=match_adapter_assist,
|
||||||
|
network_weight_list=network_weight_list,
|
||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
|
pred_kwargs=pred_kwargs,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
mask_multiplier=mask_multiplier,
|
noise=noise,
|
||||||
prior_pred=prior_pred,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
with self.timer('predict_unet'):
|
||||||
|
noise_pred = self.sd.predict_noise(
|
||||||
|
latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
||||||
|
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
|
||||||
|
timestep=timesteps,
|
||||||
|
guidance_scale=1.0,
|
||||||
|
**pred_kwargs
|
||||||
|
)
|
||||||
|
self.after_unet_predict()
|
||||||
|
|
||||||
|
with self.timer('calculate_loss'):
|
||||||
|
noise = noise.to(self.device_torch, dtype=dtype).detach()
|
||||||
|
loss = self.calculate_loss(
|
||||||
|
noise_pred=noise_pred,
|
||||||
|
noise=noise,
|
||||||
|
noisy_latents=noisy_latents,
|
||||||
|
timesteps=timesteps,
|
||||||
|
batch=batch,
|
||||||
|
mask_multiplier=mask_multiplier,
|
||||||
|
prior_pred=prior_pred,
|
||||||
|
)
|
||||||
# check if nan
|
# check if nan
|
||||||
if torch.isnan(loss):
|
if torch.isnan(loss):
|
||||||
raise ValueError("loss is nan")
|
raise ValueError("loss is nan")
|
||||||
|
|||||||
@@ -574,7 +574,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
else:
|
else:
|
||||||
latents = self.sd.encode_images(imgs)
|
latents = self.sd.encode_images(imgs)
|
||||||
batch.latents = latents
|
batch.latents = latents
|
||||||
# flush() # todo check performance removing this
|
|
||||||
|
if batch.unconditional_tensor is not None and batch.unconditional_latents is None:
|
||||||
|
unconditional_imgs = batch.unconditional_tensor
|
||||||
|
unconditional_imgs = unconditional_imgs.to(self.device_torch, dtype=dtype)
|
||||||
|
unconditional_latents = self.sd.encode_images(unconditional_imgs)
|
||||||
|
batch.unconditional_latents = unconditional_latents
|
||||||
|
|
||||||
unaugmented_latents = None
|
unaugmented_latents = None
|
||||||
if self.train_config.loss_target == 'differential_noise':
|
if self.train_config.loss_target == 'differential_noise':
|
||||||
@@ -655,6 +660,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
|
|
||||||
noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, timesteps)
|
noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, timesteps)
|
||||||
|
|
||||||
|
# determine scaled noise
|
||||||
|
# todo do we need to scale this or does it always predict full intensity
|
||||||
|
# noise = noisy_latents - latents
|
||||||
|
|
||||||
# https://github.com/huggingface/diffusers/blob/324d18fba23f6c9d7475b0ff7c777685f7128d40/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170C17-L1171C77
|
# https://github.com/huggingface/diffusers/blob/324d18fba23f6c9d7475b0ff7c777685f7128d40/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170C17-L1171C77
|
||||||
if self.train_config.loss_target == 'source' or self.train_config.loss_target == 'unaugmented':
|
if self.train_config.loss_target == 'source' or self.train_config.loss_target == 'unaugmented':
|
||||||
sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
|
sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
|
||||||
|
|||||||
@@ -351,6 +351,7 @@ class DatasetConfig:
|
|||||||
self.alpha_mask: bool = kwargs.get('alpha_mask', False) # if true, will use alpha channel as mask
|
self.alpha_mask: bool = kwargs.get('alpha_mask', False) # if true, will use alpha channel as mask
|
||||||
self.mask_path: str = kwargs.get('mask_path',
|
self.mask_path: str = kwargs.get('mask_path',
|
||||||
None) # focus mask (black and white. White has higher loss than black)
|
None) # focus mask (black and white. White has higher loss than black)
|
||||||
|
self.unconditional_path: str = kwargs.get('unconditional_path', None) # path where matching unconditional images are located
|
||||||
self.invert_mask: bool = kwargs.get('invert_mask', False) # invert mask
|
self.invert_mask: bool = kwargs.get('invert_mask', False) # invert mask
|
||||||
self.mask_min_value: float = kwargs.get('mask_min_value', 0.01) # min value for . 0 - 1
|
self.mask_min_value: float = kwargs.get('mask_min_value', 0.01) # min value for . 0 - 1
|
||||||
self.poi: Union[str, None] = kwargs.get('poi',
|
self.poi: Union[str, None] = kwargs.get('poi',
|
||||||
|
|||||||
@@ -7,7 +7,8 @@ from PIL.ImageOps import exif_transpose
|
|||||||
|
|
||||||
from toolkit import image_utils
|
from toolkit import image_utils
|
||||||
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
|
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
|
||||||
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin
|
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \
|
||||||
|
UnconditionalFileItemDTOMixin
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from toolkit.config_modules import DatasetConfig
|
from toolkit.config_modules import DatasetConfig
|
||||||
@@ -29,6 +30,7 @@ class FileItemDTO(
|
|||||||
ControlFileItemDTOMixin,
|
ControlFileItemDTOMixin,
|
||||||
MaskFileItemDTOMixin,
|
MaskFileItemDTOMixin,
|
||||||
AugmentationFileItemDTOMixin,
|
AugmentationFileItemDTOMixin,
|
||||||
|
UnconditionalFileItemDTOMixin,
|
||||||
PoiFileItemDTOMixin,
|
PoiFileItemDTOMixin,
|
||||||
ArgBreakMixin,
|
ArgBreakMixin,
|
||||||
):
|
):
|
||||||
@@ -70,6 +72,7 @@ class FileItemDTO(
|
|||||||
self.cleanup_latent()
|
self.cleanup_latent()
|
||||||
self.cleanup_control()
|
self.cleanup_control()
|
||||||
self.cleanup_mask()
|
self.cleanup_mask()
|
||||||
|
self.cleanup_unconditional()
|
||||||
|
|
||||||
|
|
||||||
class DataLoaderBatchDTO:
|
class DataLoaderBatchDTO:
|
||||||
@@ -82,6 +85,8 @@ class DataLoaderBatchDTO:
|
|||||||
self.control_tensor: Union[torch.Tensor, None] = None
|
self.control_tensor: Union[torch.Tensor, None] = None
|
||||||
self.mask_tensor: Union[torch.Tensor, None] = None
|
self.mask_tensor: Union[torch.Tensor, None] = None
|
||||||
self.unaugmented_tensor: Union[torch.Tensor, None] = None
|
self.unaugmented_tensor: Union[torch.Tensor, None] = None
|
||||||
|
self.unconditional_tensor: Union[torch.Tensor, None] = None
|
||||||
|
self.unconditional_latents: Union[torch.Tensor, None] = None
|
||||||
self.sigmas: Union[torch.Tensor, None] = None # can be added elseware and passed along training code
|
self.sigmas: Union[torch.Tensor, None] = None # can be added elseware and passed along training code
|
||||||
if not is_latents_cached:
|
if not is_latents_cached:
|
||||||
# only return a tensor if latents are not cached
|
# only return a tensor if latents are not cached
|
||||||
@@ -138,6 +143,22 @@ class DataLoaderBatchDTO:
|
|||||||
else:
|
else:
|
||||||
unaugmented_tensor.append(x.unaugmented_tensor)
|
unaugmented_tensor.append(x.unaugmented_tensor)
|
||||||
self.unaugmented_tensor = torch.cat([x.unsqueeze(0) for x in unaugmented_tensor])
|
self.unaugmented_tensor = torch.cat([x.unsqueeze(0) for x in unaugmented_tensor])
|
||||||
|
|
||||||
|
# add unconditional tensors
|
||||||
|
if any([x.unconditional_tensor is not None for x in self.file_items]):
|
||||||
|
# find one to use as a base
|
||||||
|
base_unconditional_tensor = None
|
||||||
|
for x in self.file_items:
|
||||||
|
if x.unaugmented_tensor is not None:
|
||||||
|
base_unconditional_tensor = x.unconditional_tensor
|
||||||
|
break
|
||||||
|
unconditional_tensor = []
|
||||||
|
for x in self.file_items:
|
||||||
|
if x.unconditional_tensor is None:
|
||||||
|
unconditional_tensor.append(torch.zeros_like(base_unconditional_tensor))
|
||||||
|
else:
|
||||||
|
unconditional_tensor.append(x.unconditional_tensor)
|
||||||
|
self.unconditional_tensor = torch.cat([x.unsqueeze(0) for x in unconditional_tensor])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
raise e
|
raise e
|
||||||
|
|||||||
@@ -351,6 +351,8 @@ class ImageProcessingDTOMixin:
|
|||||||
self.load_control_image()
|
self.load_control_image()
|
||||||
if self.has_mask_image:
|
if self.has_mask_image:
|
||||||
self.load_mask_image()
|
self.load_mask_image()
|
||||||
|
if self.has_unconditional:
|
||||||
|
self.load_unconditional_image()
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
img = Image.open(self.path)
|
img = Image.open(self.path)
|
||||||
@@ -442,6 +444,8 @@ class ImageProcessingDTOMixin:
|
|||||||
self.load_control_image()
|
self.load_control_image()
|
||||||
if self.has_mask_image:
|
if self.has_mask_image:
|
||||||
self.load_mask_image()
|
self.load_mask_image()
|
||||||
|
if self.has_unconditional:
|
||||||
|
self.load_unconditional_image()
|
||||||
|
|
||||||
|
|
||||||
class ControlFileItemDTOMixin:
|
class ControlFileItemDTOMixin:
|
||||||
@@ -661,6 +665,80 @@ class MaskFileItemDTOMixin:
|
|||||||
self.mask_tensor = None
|
self.mask_tensor = None
|
||||||
|
|
||||||
|
|
||||||
|
class UnconditionalFileItemDTOMixin:
|
||||||
|
def __init__(self: 'FileItemDTO', *args, **kwargs):
|
||||||
|
if hasattr(super(), '__init__'):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.has_unconditional = False
|
||||||
|
self.unconditional_path: Union[str, None] = None
|
||||||
|
self.unconditional_tensor: Union[torch.Tensor, None] = None
|
||||||
|
self.unconditional_latent: Union[torch.Tensor, None] = None
|
||||||
|
self.unconditional_transforms = transforms.Compose(
|
||||||
|
[
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize([0.5], [0.5]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
|
||||||
|
|
||||||
|
if dataset_config.unconditional_path is not None:
|
||||||
|
# we are using control images
|
||||||
|
img_path = kwargs.get('path', None)
|
||||||
|
img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
|
||||||
|
file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
|
||||||
|
for ext in img_ext_list:
|
||||||
|
if os.path.exists(os.path.join(dataset_config.unconditional_path, file_name_no_ext + ext)):
|
||||||
|
self.unconditional_path = os.path.join(dataset_config.unconditional_path, file_name_no_ext + ext)
|
||||||
|
self.has_unconditional = True
|
||||||
|
break
|
||||||
|
|
||||||
|
def load_unconditional_image(self: 'FileItemDTO'):
|
||||||
|
try:
|
||||||
|
img = Image.open(self.unconditional_path)
|
||||||
|
img = exif_transpose(img)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
print(f"Error loading image: {self.mask_path}")
|
||||||
|
|
||||||
|
img = img.convert('RGB')
|
||||||
|
w, h = img.size
|
||||||
|
if w > h and self.scale_to_width < self.scale_to_height:
|
||||||
|
# throw error, they should match
|
||||||
|
raise ValueError(
|
||||||
|
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
|
||||||
|
elif h > w and self.scale_to_height < self.scale_to_width:
|
||||||
|
# throw error, they should match
|
||||||
|
raise ValueError(
|
||||||
|
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
|
||||||
|
|
||||||
|
if self.flip_x:
|
||||||
|
# do a flip
|
||||||
|
img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||||
|
if self.flip_y:
|
||||||
|
# do a flip
|
||||||
|
img.transpose(Image.FLIP_TOP_BOTTOM)
|
||||||
|
|
||||||
|
if self.dataset_config.buckets:
|
||||||
|
# scale and crop based on file item
|
||||||
|
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
|
||||||
|
# img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
|
||||||
|
# crop
|
||||||
|
img = img.crop((
|
||||||
|
self.crop_x,
|
||||||
|
self.crop_y,
|
||||||
|
self.crop_x + self.crop_width,
|
||||||
|
self.crop_y + self.crop_height
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
raise Exception("Unconditional images are not supported for non-bucket datasets")
|
||||||
|
|
||||||
|
self.unconditional_tensor = self.unconditional_transforms(img)
|
||||||
|
|
||||||
|
def cleanup_unconditional(self: 'FileItemDTO'):
|
||||||
|
self.unconditional_tensor = None
|
||||||
|
self.unconditional_latent = None
|
||||||
|
|
||||||
|
|
||||||
class PoiFileItemDTOMixin:
|
class PoiFileItemDTOMixin:
|
||||||
# Point of interest bounding box. Allows for dynamic cropping without cropping out the main subject
|
# Point of interest bounding box. Allows for dynamic cropping without cropping out the main subject
|
||||||
# items in the poi will always be inside the image when random cropping
|
# items in the poi will always be inside the image when random cropping
|
||||||
|
|||||||
Reference in New Issue
Block a user