Added config to set the min value of a mask

This commit is contained in:
Jaret Burkett
2023-10-09 15:47:54 -06:00
parent bb1d3793e3
commit f4c90bb589
2 changed files with 4 additions and 1 deletions

View File

@@ -236,6 +236,7 @@ class DatasetConfig:
self.augments: List[str] = kwargs.get('augments', [])
self.control_path: str = kwargs.get('control_path', None) # depth maps, etc
self.mask_path: str = kwargs.get('mask_path', None) # focus mask (black and white. White has higher loss than black)
self.mask_min_value: float = kwargs.get('mask_min_value', 0.01) # min value for . 0 - 1
self.poi: Union[str, None] = kwargs.get('poi', None) # if one is set and in json data, will be used as auto crop scale point of interes
# cache latents will store them in memory

View File

@@ -11,7 +11,7 @@ import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from toolkit.basic import flush
from toolkit.basic import flush, value_map
from toolkit.buckets import get_bucket_for_image_size
from toolkit.metadata import get_meta_for_safetensors
from toolkit.prompt_utils import inject_trigger_into_prompt
@@ -442,6 +442,7 @@ class MaskFileItemDTOMixin:
self.mask_path: Union[str, None] = None
self.mask_tensor: Union[torch.Tensor, None] = None
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
self.mask_min_value = dataset_config.mask_min_value
if dataset_config.mask_path is not None:
# find the control image path
mask_path = dataset_config.mask_path
@@ -502,6 +503,7 @@ class MaskFileItemDTOMixin:
raise Exception("Mask images not supported for non-bucket datasets")
self.mask_tensor = transforms.ToTensor()(img)
self.mask_tensor = value_map(self.mask_tensor, 0, 1.0, self.mask_min_value, 1.0)
# convert to grayscale
def cleanup_mask(self: 'FileItemDTO'):