Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-27 00:49:47 +00:00
Bugfixes. Added small augmentations to the dataloader (will switch to albumentations soon, though). Added the ability to adjust the step count on start to override what is in the file.
@@ -392,7 +392,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
     def load_training_state_from_metadata(self, path):
         meta = load_metadata_from_safetensors(path)
         # if 'training_info' in OrderedDict keys
-        if 'training_info' in meta and 'step' in meta['training_info']:
+        if 'training_info' in meta and 'step' in meta['training_info'] and self.train_config.start_step is None:
             self.step_num = meta['training_info']['step']
             self.start_step = self.step_num
             print(f"Found step {self.step_num} in metadata, starting from there")
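
For context, safetensors files carry a flat string-to-string metadata header, which is where the saved step lives. A minimal sketch of what a loader like load_metadata_from_safetensors could look like, assuming nested values such as training_info are stored as JSON strings in that header; the body below is an illustration, not the toolkit's actual implementation:

import json
from safetensors import safe_open

def load_metadata_from_safetensors_sketch(path: str) -> dict:
    # the safetensors header exposes metadata as a str -> str mapping
    with safe_open(path, framework="pt") as f:
        raw = f.metadata() or {}
    meta = {}
    for key, value in raw.items():
        try:
            meta[key] = json.loads(value)  # decode JSON values like 'training_info'
        except (json.JSONDecodeError, TypeError):
            meta[key] = value              # plain strings stay as-is
    return meta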
@@ -796,6 +796,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
             else:
                 self.params.append(param)

+        if self.train_config.start_step is not None:
+            self.step_num = self.train_config.start_step
+            self.start_step = self.step_num
+
         optimizer_type = self.train_config.optimizer.lower()
         optimizer = get_optimizer(self.params, optimizer_type, learning_rate=self.train_config.lr,
                                   optimizer_params=self.train_config.optimizer_params)
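
Taken together with the metadata hunk above, the resume step is now resolved with config precedence. A small sketch of the effective ordering (the function and names are illustrative, not from the commit):

from typing import Optional

def resolve_start_step(config_start_step: Optional[int],
                       metadata_step: Optional[int]) -> int:
    # an explicit start_step in the train config wins over the checkpoint
    if config_start_step is not None:
        return config_start_step
    # otherwise fall back to the step saved in the safetensors metadata
    if metadata_step is not None:
        return metadata_step
    return 0  # fresh run

assert resolve_start_step(500, 1200) == 500    # config override
assert resolve_start_step(None, 1200) == 1200  # resume from checkpoint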
@@ -38,6 +38,8 @@ dataset_config = DatasetConfig(
     default_caption='default',
     buckets=True,
     bucket_tolerance=bucket_tolerance,
+    augments=['ColorJitter', 'RandomEqualize'],
+
 )

 dataloader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size)
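
The two strings passed to augments name torchvision transforms (the mapping is added further down in this commit). A self-contained sketch of the equivalent pixel-space pipeline, using a hypothetical input file:

from PIL import Image
from torchvision import transforms

# parameters mirror the transforms_dict entries added later in this commit
augment_pipeline = transforms.Compose([
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.01),
    transforms.RandomEqualize(p=0.2),
])

img = Image.open('example.jpg')  # hypothetical input image
augmented = augment_pipeline(img)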
@@ -48,22 +50,22 @@ for batch in dataloader:
     batch: 'DataLoaderBatchDTO'
     img_batch = batch.tensor

-    # chunks = torch.chunk(img_batch, batch_size, dim=0)
-    # # put them so they are side by side
-    # big_img = torch.cat(chunks, dim=3)
-    # big_img = big_img.squeeze(0)
-    #
-    # min_val = big_img.min()
-    # max_val = big_img.max()
-    #
-    # big_img = (big_img / 2 + 0.5).clamp(0, 1)
-    #
-    # # convert to image
-    # img = transforms.ToPILImage()(big_img)
-    #
-    # show_img(img)
-    #
-    # time.sleep(1.0)
+    chunks = torch.chunk(img_batch, batch_size, dim=0)
+    # put them so they are side by side
+    big_img = torch.cat(chunks, dim=3)
+    big_img = big_img.squeeze(0)
+
+    min_val = big_img.min()
+    max_val = big_img.max()
+
+    big_img = (big_img / 2 + 0.5).clamp(0, 1)
+
+    # convert to image
+    img = transforms.ToPILImage()(big_img)
+
+    show_img(img)
+
+    time.sleep(1.0)

 cv2.destroyAllWindows()
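
The now-uncommented preview loop tiles the batch into one wide image: chunking on dim=0 yields one (1, C, H, W) tensor per sample, and concatenating on dim=3 lays the samples out along the width axis. A standalone sketch of just the tensor math, with random data standing in for batch.tensor:

import torch

batch_size = 4
img_batch = torch.randn(batch_size, 3, 64, 64)      # stand-in for batch.tensor
chunks = torch.chunk(img_batch, batch_size, dim=0)  # 4 tensors of (1, 3, 64, 64)
big_img = torch.cat(chunks, dim=3).squeeze(0)       # (3, 64, 256): side by side
# undo the [-1, 1] normalization before display
big_img = (big_img / 2 + 0.5).clamp(0, 1)
assert big_img.shape == (3, 64, 256)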
@@ -112,6 +112,7 @@ class TrainConfig:
         self.weight_jitter = kwargs.get('weight_jitter', 0.0)
         self.merge_network_on_save = kwargs.get('merge_network_on_save', False)
         self.max_grad_norm = kwargs.get('max_grad_norm', 1.0)
+        self.start_step = kwargs.get('start_step', None)


 class ModelConfig:
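
start_step follows the same kwargs.get pattern as the rest of TrainConfig: missing keys fall back to a default, and the None default means "defer to the checkpoint's saved step". A minimal illustration of the pattern (ExampleConfig is a stand-in, not the toolkit's class):

class ExampleConfig:
    def __init__(self, **kwargs):
        # missing keys fall back to defaults; start_step defaults to None,
        # which leaves the checkpoint's saved step in charge
        self.max_grad_norm = kwargs.get('max_grad_norm', 1.0)
        self.start_step = kwargs.get('start_step', None)

cfg = ExampleConfig(start_step=500)
assert cfg.start_step == 500 and cfg.max_grad_norm == 1.0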
@@ -221,12 +222,18 @@ class DatasetConfig:
         self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0))
         self.flip_x: bool = kwargs.get('flip_x', False)
         self.flip_y: bool = kwargs.get('flip_y', False)
+        self.augments: List[str] = kwargs.get('augments', [])

         # cache latents will store them in memory
         self.cache_latents: bool = kwargs.get('cache_latents', False)
         # cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory
         self.cache_latents_to_disk: bool = kwargs.get('cache_latents_to_disk', False)

+        if len(self.augments) > 0 and (self.cache_latents or self.cache_latents_to_disk):
+            print(f"WARNING: Augments are not supported with caching latents. Setting cache_latents to False")
+            self.cache_latents = False
+            self.cache_latents_to_disk = False
+
         # legacy compatibility
         legacy_caption_type = kwargs.get('caption_type', None)
         if legacy_caption_type:
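
The guard exists because cached latents are encoded once from the raw image, so a pixel-space augment applied afterwards would never reach the model; every epoch would replay the same frozen sample. A toy, runnable illustration of that failure mode (augment/encode here are stand-ins, not toolkit functions):

import random

def augment(x: float) -> float:
    return x + random.random()   # stand-in for a random pixel-space augment

def encode(x: float) -> float:
    return x * 2.0               # stand-in for a VAE encode

cache = {}

def cached_encode(key: str, x: float) -> float:
    if key not in cache:
        cache[key] = encode(augment(x))
    return cache[key]            # every epoch sees the same augmented latent

assert cached_encode('img', 1.0) == cached_encode('img', 1.0)  # variance lost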
@@ -49,6 +49,7 @@ class FileItemDTO(LatentCachingFileItemDTOMixin, CaptionProcessingDTOMixin, Imag
         self.crop_height: int = kwargs.get('crop_height', self.scale_to_height)
         self.flip_x: bool = kwargs.get('flip_x', False)
         self.flip_y: bool = kwargs.get('flip_y', False)
+        self.augments: List[str] = self.dataset_config.augments

         self.network_weight: float = self.dataset_config.network_weight
         self.is_reg = self.dataset_config.is_reg
@@ -25,10 +25,15 @@ if TYPE_CHECKING:
     from toolkit.data_loader import AiToolkitDataset
     from toolkit.data_transfer_object.data_loader import FileItemDTO


 # def get_associated_caption_from_img_path(img_path):

+transforms_dict = {
+    'ColorJitter': transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.01),
+    'RandomEqualize': transforms.RandomEqualize(p=0.2),
+}
+
+
 class CaptionMixin:
     def get_caption_item(self: 'AiToolkitDataset', index):
         if not hasattr(self, 'caption_type'):
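
Because the augment names in a dataset config are plain dictionary keys, supporting another torchvision transform is a one-entry change. A hedged sketch; the RandomAutocontrast entry is an example, not part of this commit:

from torchvision import transforms

transforms_dict = {
    'ColorJitter': transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.01),
    'RandomEqualize': transforms.RandomEqualize(p=0.2),
    # adding a new augment is one entry (illustrative, not in this commit):
    'RandomAutocontrast': transforms.RandomAutocontrast(p=0.2),
}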
@@ -287,6 +292,12 @@ class ImageProcessingDTOMixin:
             img = transforms.CenterCrop(min_img_size)(img)
             img = img.resize((self.dataset_config.resolution, self.dataset_config.resolution), Image.BICUBIC)

+        if self.augments is not None and len(self.augments) > 0:
+            # do augmentations
+            for augment in self.augments:
+                if augment in transforms_dict:
+                    img = transforms_dict[augment](img)
+
         if transform:
             img = transform(img)
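
The lookup-and-apply loop silently skips names that are not registered in transforms_dict. A self-contained sketch of that behavior outside the DTO:

from PIL import Image
from torchvision import transforms

transforms_dict = {
    'ColorJitter': transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.01),
    'RandomEqualize': transforms.RandomEqualize(p=0.2),
}

img = Image.new('RGB', (64, 64), 'gray')  # stand-in image
for augment in ['ColorJitter', 'RandomEqualize', 'NotRegistered']:
    if augment in transforms_dict:        # unknown names are skipped silently
        img = transforms_dict[augment](img)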