mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-21 14:59:02 +00:00
Added config to set the min value of a mask
This commit is contained in:
@@ -236,6 +236,7 @@ class DatasetConfig:
|
||||
self.augments: List[str] = kwargs.get('augments', [])
|
||||
self.control_path: str = kwargs.get('control_path', None) # depth maps, etc
|
||||
self.mask_path: str = kwargs.get('mask_path', None) # focus mask (black and white. White has higher loss than black)
|
||||
self.mask_min_value: float = kwargs.get('mask_min_value', 0.01) # min value for . 0 - 1
|
||||
self.poi: Union[str, None] = kwargs.get('poi', None) # if one is set and in json data, will be used as auto crop scale point of interes
|
||||
|
||||
# cache latents will store them in memory
|
||||
|
||||
@@ -11,7 +11,7 @@ import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
|
||||
from toolkit.basic import flush
|
||||
from toolkit.basic import flush, value_map
|
||||
from toolkit.buckets import get_bucket_for_image_size
|
||||
from toolkit.metadata import get_meta_for_safetensors
|
||||
from toolkit.prompt_utils import inject_trigger_into_prompt
|
||||
@@ -442,6 +442,7 @@ class MaskFileItemDTOMixin:
|
||||
self.mask_path: Union[str, None] = None
|
||||
self.mask_tensor: Union[torch.Tensor, None] = None
|
||||
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
|
||||
self.mask_min_value = dataset_config.mask_min_value
|
||||
if dataset_config.mask_path is not None:
|
||||
# find the control image path
|
||||
mask_path = dataset_config.mask_path
|
||||
@@ -502,6 +503,7 @@ class MaskFileItemDTOMixin:
|
||||
raise Exception("Mask images not supported for non-bucket datasets")
|
||||
|
||||
self.mask_tensor = transforms.ToTensor()(img)
|
||||
self.mask_tensor = value_map(self.mask_tensor, 0, 1.0, self.mask_min_value, 1.0)
|
||||
# convert to grayscale
|
||||
|
||||
def cleanup_mask(self: 'FileItemDTO'):
|
||||
|
||||
Reference in New Issue
Block a user