mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Add a config flag to trigger fast image size db builder. Add config flag to set unconditional prompt for guidance loss
This commit is contained in:
@@ -147,7 +147,7 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
# cache unconditional embeds (blank prompt)
|
# cache unconditional embeds (blank prompt)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.unconditional_embeds = self.sd.encode_prompt(
|
self.unconditional_embeds = self.sd.encode_prompt(
|
||||||
[''],
|
[self.train_config.unconditional_prompt],
|
||||||
long_prompts=self.do_long_prompts
|
long_prompts=self.do_long_prompts
|
||||||
).to(
|
).to(
|
||||||
self.device_torch,
|
self.device_torch,
|
||||||
|
|||||||
@@ -471,6 +471,7 @@ class TrainConfig:
|
|||||||
# contrastive loss
|
# contrastive loss
|
||||||
self.do_guidance_loss = kwargs.get('do_guidance_loss', False)
|
self.do_guidance_loss = kwargs.get('do_guidance_loss', False)
|
||||||
self.guidance_loss_target: Union[int, List[int, int]] = kwargs.get('guidance_loss_target', 3.0)
|
self.guidance_loss_target: Union[int, List[int, int]] = kwargs.get('guidance_loss_target', 3.0)
|
||||||
|
self.unconditional_prompt: str = kwargs.get('unconditional_prompt', '')
|
||||||
if isinstance(self.guidance_loss_target, tuple):
|
if isinstance(self.guidance_loss_target, tuple):
|
||||||
self.guidance_loss_target = list(self.guidance_loss_target)
|
self.guidance_loss_target = list(self.guidance_loss_target)
|
||||||
|
|
||||||
@@ -837,6 +838,9 @@ class DatasetConfig:
|
|||||||
self.controls = [self.controls]
|
self.controls = [self.controls]
|
||||||
# remove empty strings
|
# remove empty strings
|
||||||
self.controls = [control for control in self.controls if control.strip() != '']
|
self.controls = [control for control in self.controls if control.strip() != '']
|
||||||
|
|
||||||
|
# if true, will use a fask method to get image sizes. This can result in errors. Do not use unless you know what you are doing
|
||||||
|
self.fast_image_size: bool = kwargs.get('fast_image_size', False)
|
||||||
|
|
||||||
|
|
||||||
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
|
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
|
||||||
|
|||||||
@@ -84,15 +84,16 @@ class FileItemDTO(
|
|||||||
video.release()
|
video.release()
|
||||||
size_database[file_key] = (width, height, file_signature)
|
size_database[file_key] = (width, height, file_signature)
|
||||||
else:
|
else:
|
||||||
# original method is significantly faster, but some images are read sideways. Not sure why. Do slow method for now.
|
if self.dataset_config.fast_image_size:
|
||||||
# process width and height
|
# original method is significantly faster, but some images are read sideways. Not sure why. Do slow method by default.
|
||||||
# try:
|
try:
|
||||||
# w, h = image_utils.get_image_size(self.path)
|
w, h = image_utils.get_image_size(self.path)
|
||||||
# except image_utils.UnknownImageFormat:
|
except image_utils.UnknownImageFormat:
|
||||||
# print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \
|
print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \
|
||||||
# f'This process is faster for png, jpeg')
|
f'This process is faster for png, jpeg')
|
||||||
img = exif_transpose(Image.open(self.path))
|
else:
|
||||||
w, h = img.size
|
img = exif_transpose(Image.open(self.path))
|
||||||
|
w, h = img.size
|
||||||
size_database[file_key] = (w, h, file_signature)
|
size_database[file_key] = (w, h, file_signature)
|
||||||
self.width: int = w
|
self.width: int = w
|
||||||
self.height: int = h
|
self.height: int = h
|
||||||
|
|||||||
Reference in New Issue
Block a user