Bugfixes. Added small augmentations to the dataloader (will switch to albumentations soon). Added the ability to set the step count on start, overriding the value stored in the checkpoint file.
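In practice the two additions surface as config options. A rough usage sketch (field names match the kwargs added below; the optimizer/lr values and the folder path are hypothetical):

train_config = TrainConfig(
    optimizer='adamw',                             # hypothetical value
    lr=1e-4,                                       # hypothetical value
    start_step=500,   # new: overrides the step count stored in checkpoint metadata
)
dataset_config = DatasetConfig(
    folder_path='/path/to/images',                 # hypothetical path
    augments=['ColorJitter', 'RandomEqualize'],    # new: per-dataset augmentation list
)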

This commit is contained in:
Jaret Burkett
2023-09-20 05:30:10 -06:00
parent 0f105690cc
commit 19255cdc7c
5 changed files with 43 additions and 18 deletions

View File

@@ -392,7 +392,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
     def load_training_state_from_metadata(self, path):
         meta = load_metadata_from_safetensors(path)
         # if 'training_info' in OrderedDict keys
-        if 'training_info' in meta and 'step' in meta['training_info']:
+        if 'training_info' in meta and 'step' in meta['training_info'] and self.train_config.start_step is None:
             self.step_num = meta['training_info']['step']
             self.start_step = self.step_num
             print(f"Found step {self.step_num} in metadata, starting from there")
@@ -796,6 +796,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
             else:
                 self.params.append(param)
+
+        if self.train_config.start_step is not None:
+            self.step_num = self.train_config.start_step
+            self.start_step = self.step_num
 
         optimizer_type = self.train_config.optimizer.lower()
         optimizer = get_optimizer(self.params, optimizer_type, learning_rate=self.train_config.lr,
                                   optimizer_params=self.train_config.optimizer_params)
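Taken together with the first hunk, resuming now follows a simple precedence: an explicit start_step beats whatever step is recorded in checkpoint metadata. A condensed, illustrative view (not the literal code):

# illustrative only — names mirror the diff above
if train_config.start_step is not None:
    step_num = train_config.start_step            # explicit override wins
elif 'step' in metadata.get('training_info', {}):
    step_num = metadata['training_info']['step']  # otherwise resume from metadata
start_step = step_num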

View File

@@ -38,6 +38,8 @@ dataset_config = DatasetConfig(
     default_caption='default',
     buckets=True,
     bucket_tolerance=bucket_tolerance,
+    augments=['ColorJitter', 'RandomEqualize'],
 )
 
 dataloader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size)
@@ -48,22 +50,22 @@ for batch in dataloader:
     batch: 'DataLoaderBatchDTO'
     img_batch = batch.tensor
 
-    # chunks = torch.chunk(img_batch, batch_size, dim=0)
-    # # put them so they are side by side
-    # big_img = torch.cat(chunks, dim=3)
-    # big_img = big_img.squeeze(0)
-    #
-    # min_val = big_img.min()
-    # max_val = big_img.max()
-    #
-    # big_img = (big_img / 2 + 0.5).clamp(0, 1)
-    #
-    # # convert to image
-    # img = transforms.ToPILImage()(big_img)
-    #
-    # show_img(img)
-    #
-    # time.sleep(1.0)
+    chunks = torch.chunk(img_batch, batch_size, dim=0)
+    # put them so they are side by side
+    big_img = torch.cat(chunks, dim=3)
+    big_img = big_img.squeeze(0)
+
+    min_val = big_img.min()
+    max_val = big_img.max()
+
+    big_img = (big_img / 2 + 0.5).clamp(0, 1)
+
+    # convert to image
+    img = transforms.ToPILImage()(big_img)
+
+    show_img(img)
+
+    time.sleep(1.0)
 
 cv2.destroyAllWindows()
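For reference, the now-active preview loop assumes batch.tensor is (B, C, H, W) normalized to [-1, 1]; chunking on dim 0 and concatenating on dim 3 lays the batch out as one wide strip, and (x / 2 + 0.5) maps it back to [0, 1] for display. A minimal standalone equivalent with a fake batch:

import torch
from torchvision import transforms

batch = torch.rand(4, 3, 64, 64) * 2 - 1     # fake batch in [-1, 1]
chunks = torch.chunk(batch, 4, dim=0)        # four (1, 3, 64, 64) tensors
strip = torch.cat(chunks, dim=3).squeeze(0)  # (3, 64, 256) side-by-side strip
strip = (strip / 2 + 0.5).clamp(0, 1)        # back to [0, 1] for display
preview = transforms.ToPILImage()(strip)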

View File

@@ -112,6 +112,7 @@ class TrainConfig:
         self.weight_jitter = kwargs.get('weight_jitter', 0.0)
         self.merge_network_on_save = kwargs.get('merge_network_on_save', False)
         self.max_grad_norm = kwargs.get('max_grad_norm', 1.0)
+        self.start_step = kwargs.get('start_step', None)
 
 
 class ModelConfig:
@@ -221,12 +222,18 @@ class DatasetConfig:
         self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0))
         self.flip_x: bool = kwargs.get('flip_x', False)
         self.flip_y: bool = kwargs.get('flip_y', False)
+        self.augments: List[str] = kwargs.get('augments', [])
         # cache latents will store them in memory
         self.cache_latents: bool = kwargs.get('cache_latents', False)
         # cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory
         self.cache_latents_to_disk: bool = kwargs.get('cache_latents_to_disk', False)
+        if len(self.augments) > 0 and (self.cache_latents or self.cache_latents_to_disk):
+            print("WARNING: augments are not supported with latent caching. Disabling latent caching.")
+            self.cache_latents = False
+            self.cache_latents_to_disk = False
+
         # legacy compatibility
         legacy_caption_type = kwargs.get('caption_type', None)
         if legacy_caption_type:
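The new guard makes augments and latent caching mutually exclusive, which follows from how caching works: augmentations change pixels on every pass, so a latent cached once would immediately go stale. A sketch of the resulting behavior, assuming DatasetConfig is constructed directly:

config = DatasetConfig(
    folder_path='/path/to/images',  # hypothetical path
    augments=['ColorJitter'],
    cache_latents=True,
)
# prints the warning; config.cache_latents is now False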

View File

@@ -49,6 +49,7 @@ class FileItemDTO(LatentCachingFileItemDTOMixin, CaptionProcessingDTOMixin, Imag
         self.crop_height: int = kwargs.get('crop_height', self.scale_to_height)
         self.flip_x: bool = kwargs.get('flip_x', False)
         self.flip_y: bool = kwargs.get('flip_y', False)
+        self.augments: List[str] = self.dataset_config.augments
         self.network_weight: float = self.dataset_config.network_weight
         self.is_reg = self.dataset_config.is_reg

View File

@@ -25,10 +25,15 @@ if TYPE_CHECKING:
     from toolkit.data_loader import AiToolkitDataset
     from toolkit.data_transfer_object.data_loader import FileItemDTO
 
 # def get_associated_caption_from_img_path(img_path):
 
+transforms_dict = {
+    'ColorJitter': transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.01),
+    'RandomEqualize': transforms.RandomEqualize(p=0.2),
+}
+
 
 class CaptionMixin:
     def get_caption_item(self: 'AiToolkitDataset', index):
         if not hasattr(self, 'caption_type'):
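Because lookup is by string key, supporting another augmentation is just one more entry in transforms_dict; any torchvision transform that maps PIL to PIL should drop in. For example (an illustration, not part of this commit):

transforms_dict['RandomAutocontrast'] = transforms.RandomAutocontrast(p=0.2)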
@@ -287,6 +292,12 @@ class ImageProcessingDTOMixin:
                 img = transforms.CenterCrop(min_img_size)(img)
                 img = img.resize((self.dataset_config.resolution, self.dataset_config.resolution), Image.BICUBIC)
 
+        if self.augments is not None and len(self.augments) > 0:
+            # do augmentations
+            for augment in self.augments:
+                if augment in transforms_dict:
+                    img = transforms_dict[augment](img)
+
         if transform:
             img = transform(img)
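One caveat worth noting: augmentations run in list order, after crop/resize and before the final tensor transform, and names missing from transforms_dict are skipped silently. So a config like the following (hypothetical) raises no error, it just never blurs:

dataset_config = DatasetConfig(augments=['ColorJitter', 'Blur'])  # 'Blur' is silently ignored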