diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index eb3b057c..f98cab23 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -21,7 +21,8 @@ def flush():
     gc.collect()
 
 adapter_transforms = transforms.Compose([
-    transforms.PILToTensor(),
+    # transforms.PILToTensor(),
+    transforms.ToTensor(),
 ])
 
 
@@ -44,6 +45,13 @@ class SDTrainer(BaseSDTrainProcess):
         flush()
 
     def get_adapter_images(self, batch: 'DataLoaderBatchDTO'):
+        if self.adapter_config.image_dir is None:
+            # adapter needs 0 to 1 values, batch is -1 to 1
+            adapter_batch = batch.tensor.clone().to(
+                self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)
+            )
+            adapter_batch = (adapter_batch + 1) / 2
+            return adapter_batch
         img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
         adapter_folder_path = self.adapter_config.image_dir
         adapter_images = []
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 40dded91..bbcb466b 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -75,7 +75,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.has_first_sample_requested = False
         self.first_sample_config = self.sample_config
         self.logging_config = LogingConfig(**self.get_conf('logging', {}))
-        self.optimizer = None
+        self.optimizer: torch.optim.Optimizer = None
         self.lr_scheduler = None
         self.data_loader: Union[DataLoader, None] = None
         self.data_loader_reg: Union[DataLoader, None] = None
@@ -543,28 +543,39 @@ class BaseSDTrainProcess(BaseTrainProcess):
         return noisy_latents, noise, timesteps, conditioned_prompts, imgs
 
     def setup_adapter(self):
-        dtype = get_torch_dtype(self.train_config.dtype)
+        # t2i adapter
         is_t2i = self.adapter_config.type == 't2i'
+        suffix = 't2i' if is_t2i else 'ip'
+        adapter_name = self.name
+        if self.network_config is not None:
+            adapter_name = f"{adapter_name}_{suffix}"
+        latest_save_path = self.get_latest_save_path(adapter_name)
+
+        dtype = get_torch_dtype(self.train_config.dtype)
         if is_t2i:
-            self.adapter = T2IAdapter(
-                in_channels=self.adapter_config.in_channels,
-                channels=self.adapter_config.channels,
-                num_res_blocks=self.adapter_config.num_res_blocks,
-                downscale_factor=self.adapter_config.downscale_factor,
-                adapter_type=self.adapter_config.adapter_type,
-            )
+            # if we do not have a last save path and we have a name_or_path,
+            # load from that
+            if latest_save_path is None and self.adapter_config.name_or_path is not None:
+                self.adapter = T2IAdapter.from_pretrained(
+                    self.adapter_config.name_or_path,
+                    torch_dtype=get_torch_dtype(self.train_config.dtype),
+                    variant="fp16",
+                    # use_safetensors=True,
+                )
+            else:
+                self.adapter = T2IAdapter(
+                    in_channels=self.adapter_config.in_channels,
+                    channels=self.adapter_config.channels,
+                    num_res_blocks=self.adapter_config.num_res_blocks,
+                    downscale_factor=self.adapter_config.downscale_factor,
+                    adapter_type=self.adapter_config.adapter_type,
+                )
         else:
             self.adapter = IPAdapter(
                 sd=self.sd,
                 adapter_config=self.adapter_config,
             )
         self.adapter.to(self.device_torch, dtype=dtype)
-        # t2i adapter
-        suffix = 't2i' if is_t2i else 'ip'
-        adapter_name = self.name
-        if self.network_config is not None:
-            adapter_name = f"{adapter_name}_{suffix}"
-        latest_save_path = self.get_latest_save_path(adapter_name)
         if latest_save_path is not None:
             # load adapter from path
             print(f"Loading adapter from {latest_save_path}")
diff --git a/toolkit/buckets.py b/toolkit/buckets.py
index 3711ea7e..be69085b 100644
--- a/toolkit/buckets.py
+++ b/toolkit/buckets.py
@@ -83,7 +83,8 @@ def get_bucket_for_image_size(
     width: int,
     height: int,
     bucket_size_list: List[BucketResolution] = None,
-    resolution: Union[int, None] = None
+    resolution: Union[int, None] = None,
+    divisibility: int = 8
 ) -> BucketResolution:
 
     if bucket_size_list is None and resolution is None:
@@ -93,7 +94,7 @@ def get_bucket_for_image_size(
     # if real resolution is smaller, use that instead
     real_resolution = get_resolution(width, height)
     resolution = min(resolution, real_resolution)
-    bucket_size_list = get_bucket_sizes(resolution=resolution)
+    bucket_size_list = get_bucket_sizes(resolution=resolution, divisibility=divisibility)
 
     # Check for exact match first
     for bucket in bucket_size_list:
diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py
index 86214b1a..1af7e8ba 100644
--- a/toolkit/data_loader.py
+++ b/toolkit/data_loader.py
@@ -281,7 +281,8 @@ class PairedImageDataset(Dataset):
         bucket_resolution = get_bucket_for_image_size(
             width=img2.width,
             height=img2.height,
-            resolution=self.size
+            resolution=self.size,
+            # divisibility=self.
         )
 
         # images will be same base dimension, but may be trimmed. We need to shrink and then central crop
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index d39e7f03..4ea24e20 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -121,7 +121,7 @@ class BucketsMixin:
                 width = file_item.crop_width
                 height = file_item.crop_height
 
-                bucket_resolution = get_bucket_for_image_size(width, height, resolution=resolution, divisibility=bucket_tolerance)
+                bucket_resolution = get_bucket_for_image_size(width, height, resolution=resolution, divisibility=bucket_tolerance)
 
                 # set the scaling height and with to match smallest size, and keep aspect ratio
                 if width > height: