diff --git a/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py b/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py
index 127f56cc..dba3a57f 100644
--- a/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py
+++ b/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py
@@ -20,13 +20,18 @@ def flush():
     gc.collect()
 
 
+class DatasetConfig:
+    def __init__(self, **kwargs):
+        self.pair_folder: str = kwargs.get('pair_folder', None)
+        self.network_weight: float = kwargs.get('network_weight', 1.0)
+        self.target_class: str = kwargs.get('target_class', '')
+        self.size: int = kwargs.get('size', 512)
+
+
 class ReferenceSliderConfig:
     def __init__(self, **kwargs):
-        self.slider_pair_folder: str = kwargs.get('slider_pair_folder', None)
-        self.resolutions: List[int] = kwargs.get('resolutions', [512])
-        self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
-        self.target_class: int = kwargs.get('target_class', '')
         self.additional_losses: List[str] = kwargs.get('additional_losses', [])
+        self.datasets: List[DatasetConfig] = [DatasetConfig(**d) for d in kwargs.get('datasets', [])]
 
 
 class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
@@ -46,12 +51,13 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
         if self.data_loader is None:
             print(f"Loading datasets")
             datasets = []
-            for res in self.slider_config.resolutions:
-                print(f" - Dataset: {self.slider_config.slider_pair_folder}")
+            for dataset in self.slider_config.datasets:
+                print(f" - Dataset: {dataset.pair_folder}")
                 config = {
-                    'path': self.slider_config.slider_pair_folder,
-                    'size': res,
-                    'default_prompt': self.slider_config.target_class
+                    'path': dataset.pair_folder,
+                    'size': dataset.size,
+                    'default_prompt': dataset.target_class,
+                    'network_weight': dataset.network_weight,
                 }
                 image_dataset = PairedImageDataset(config)
                 datasets.append(image_dataset)
@@ -78,7 +84,7 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
         do_mirror_loss = 'mirror' in self.slider_config.additional_losses
 
         with torch.no_grad():
-            imgs, prompts = batch
+            imgs, prompts, base_network_weight = batch
            dtype = get_torch_dtype(self.train_config.dtype)
             imgs: torch.Tensor = imgs.to(self.device_torch, dtype=dtype)
             # split batched images in half so left is negative and right is positive
@@ -129,7 +135,7 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
             noisy_latents = torch.cat([noisy_positive_latents, noisy_negative_latents], dim=0)
             noise = torch.cat([noise_positive, noise_negative], dim=0)
             timesteps = torch.cat([timesteps, timesteps], dim=0)
-            network_multiplier = [1.0, -1.0]
+            network_multiplier = [base_network_weight * 1.0, base_network_weight * -1.0]
 
         flush()
 
@@ -180,7 +186,8 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
                 noise_pred_pos, noise_pred_neg = torch.chunk(noise_pred, 2, dim=0)
                 # mirror the negative
                 noise_pred_neg = torch.flip(noise_pred_neg.clone(), dims=[3])
-                loss_mirror = torch.nn.functional.mse_loss(noise_pred_pos.float(), noise_pred_neg.float(), reduction="none")
+                loss_mirror = torch.nn.functional.mse_loss(noise_pred_pos.float(), noise_pred_neg.float(),
+                                                           reduction="none")
                 loss_mirror = loss_mirror.mean([1, 2, 3])
                 loss_mirror = loss_mirror.mean()
                 loss_mirror_float = loss_mirror.item()
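The trainer changes above replace the single `slider_pair_folder` plus `resolutions` list with a list of per-dataset configs, and scale the paired `+1.0`/`-1.0` LoRA multipliers by each batch's `network_weight`. Below is a minimal, standalone sketch of that parsing and scaling; `DatasetConfig` mirrors the class added in the diff, but the raw config dict and the loop around it are illustrative only, not the toolkit's own code.

```python
# Minimal sketch of the new per-dataset config flow (illustrative, not the
# toolkit's own module). DatasetConfig mirrors the class added in the diff.

class DatasetConfig:
    def __init__(self, **kwargs):
        self.pair_folder: str = kwargs.get('pair_folder', None)
        self.network_weight: float = kwargs.get('network_weight', 1.0)
        self.target_class: str = kwargs.get('target_class', '')
        self.size: int = kwargs.get('size', 512)

# hypothetical raw config, e.g. as parsed from the YAML example below
raw = {'datasets': [
    {'pair_folder': '/data/pairs_subtle', 'network_weight': 2.0},
    {'pair_folder': '/data/pairs_strong', 'network_weight': 4.0, 'size': 768},
]}
datasets = [DatasetConfig(**d) for d in raw['datasets']]

for d in datasets:
    # the positive half gets +weight and the negative half gets -weight,
    # so higher-weight pairs push the slider network harder in both directions
    network_multiplier = [d.network_weight * 1.0, d.network_weight * -1.0]
    print(d.pair_folder, d.size, network_multiplier)
```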
diff --git a/extensions_built_in/image_reference_slider_trainer/config/train.example.yaml b/extensions_built_in/image_reference_slider_trainer/config/train.example.yaml
index 52e3d5d6..8b0f4734 100644
--- a/extensions_built_in/image_reference_slider_trainer/config/train.example.yaml
+++ b/extensions_built_in/image_reference_slider_trainer/config/train.example.yaml
@@ -1,7 +1,7 @@
 ---
 job: extension
 config:
-  name: subject_turner_v1
+  name: example_name
   process:
     - type: 'image_reference_slider_trainer'
       training_folder: "/mnt/Train/out/LoRA"
@@ -10,10 +10,8 @@ config:
       log_dir: "/home/jaret/Dev/.tensorboard"
       network:
         type: "lora"
-        linear: 64
-        linear_alpha: 32
-        conv: 32
-        conv_alpha: 16
+        linear: 8
+        linear_alpha: 8
       train:
         noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a"
         steps: 5000
@@ -30,11 +28,9 @@ config:
         dtype: bf16
         xformers: true
         skip_first_sample: true
-        noise_offset: 0.0 # I don't recommend using unless you are trying to make a darker lora. Then do 0.1 MAX
+        noise_offset: 0.0
       model:
-#        name_or_path: "/mnt/Models/stable-diffusion/models/stable-diffusion/sdxl/sd_xl_base_0.9.safetensors"
-        name_or_path: "/mnt/Models/stable-diffusion/models/stable-diffusion/Ostris/Ostris_Real_v1.safetensors"
-        # name_or_path: "/mnt/Models/stable-diffusion/models/stable-diffusion/sd_v1-5_vae.ckpt"
+        name_or_path: "/path/to/model.safetensors"
         is_v2: false # for v2 models
         is_xl: false # for SDXL models
         is_v_pred: false # for v-prediction models (most v2 models)
@@ -81,18 +77,31 @@ config:
         verbose: false
 
       slider:
-        resolutions:
-          - 512
-        slider_pair_folder: "/mnt/Datasets/stable-diffusion/slider_reference/subject_turner"
-        target_class: "photo of a person"
-#        additional_losses:
-#          - "mirror"
+        datasets:
+          - pair_folder: "/path/to/folder/side/by/side/images"
+            network_weight: 2.0
+            target_class: "" # only used as default if caption txt files are not present
+            size: 512
+          - pair_folder: "/path/to/folder/side/by/side/images"
+            network_weight: 4.0
+            target_class: "" # only used as default if caption txt files are not present
+            size: 512
 
 
+# you can put any information you want here, and it will be saved in the model
+# The below is an example. I recommend including trigger words at a minimum
+# in the metadata. The software will include this plus some other information.
       meta:
-        name: "[name]"
-        version: '1.0'
+        name: "[name]" # [name] gets replaced with the name above
+        description: A short description of your model
+        trigger_words:
+          - put
+          - trigger
+          - words
+          - here
+        version: '0.1'
         creator:
-          name: Ostris - Jaret Burkett
-          email: jaret@ostris.com
-          website: https://ostris.com
+          name: Your Name
+          email: your@email.com
+          website: https://yourwebsite.com
+          any: All metadata above is arbitrary; it can be whatever you want.
\ No newline at end of file
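For reference, here is how the `slider.datasets` section of the example YAML above could be read back in Python. This is a hypothetical loading snippet (the toolkit drives this through its own job runner, not this code), shown only to make the expected structure of the new config concrete; it assumes PyYAML is installed and `train.example.yaml` is the file above.

```python
# Hypothetical: read the example config and walk the new datasets list.
import yaml

with open('train.example.yaml') as f:
    cfg = yaml.safe_load(f)

process = cfg['config']['process'][0]
for ds in process['slider']['datasets']:
    # size falls back to 512, matching DatasetConfig's default in the diff
    print(ds['pair_folder'], ds['network_weight'], ds.get('size', 512))
```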
diff --git a/toolkit/basic.py b/toolkit/basic.py
new file mode 100644
index 00000000..248ffc4e
--- /dev/null
+++ b/toolkit/basic.py
@@ -0,0 +1,4 @@
+
+
+def value_map(inputs, min_in, max_in, min_out, max_out):
+    return (inputs - min_in) * (max_out - min_out) / (max_in - min_in) + min_out
diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py
index 1e2264ac..2a58ab43 100644
--- a/toolkit/data_loader.py
+++ b/toolkit/data_loader.py
@@ -149,6 +149,7 @@ class PairedImageDataset(Dataset):
         self.size = self.get_config('size', 512)
         self.path = self.get_config('path', required=True)
         self.default_prompt = self.get_config('default_prompt', '')
+        self.network_weight = self.get_config('network_weight', 1.0)
         self.file_list = [os.path.join(self.path, file) for file in os.listdir(self.path) if
                           file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))]
         print(f" - Found {len(self.file_list)} images")
@@ -200,5 +201,5 @@ class PairedImageDataset(Dataset):
             img = img.resize((width, height), Image.BICUBIC)
 
         img = self.transform(img)
-        return img, prompt
+        return img, prompt, self.network_weight
 
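One detail worth noting about `return img, prompt, self.network_weight`: when `PairedImageDataset` is wrapped in a PyTorch `DataLoader`, the default collate function turns the per-sample float into a tensor of shape `[batch_size]`, which is what the trainer unpacks as `base_network_weight`. A small standalone sketch of that behavior follows; `ToyPairedDataset` is a hypothetical stand-in, not the real dataset.

```python
# Sketch of how default collation batches the per-sample network weight.
# ToyPairedDataset is a hypothetical stand-in for PairedImageDataset.
import torch
from torch.utils.data import DataLoader, Dataset

class ToyPairedDataset(Dataset):
    def __init__(self, network_weight=2.0):
        self.network_weight = network_weight

    def __len__(self):
        return 4

    def __getitem__(self, idx):
        img = torch.zeros(3, 8, 16)  # stand-in for a side-by-side image pair
        prompt = 'photo of a person'
        return img, prompt, self.network_weight

loader = DataLoader(ToyPairedDataset(), batch_size=2)
imgs, prompts, base_network_weight = next(iter(loader))
print(base_network_weight)  # tensor([2., 2.], dtype=torch.float64)
```

Because `base_network_weight` arrives as a tensor of per-sample weights rather than a plain float, the `network_multiplier = [base_network_weight * 1.0, base_network_weight * -1.0]` list built in the trainer holds tensors as well.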