mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
WIP on SAFE encoder. Work on fp16 training improvements. Various other tweaks and improvements
@@ -31,6 +31,7 @@ from jobs.process import BaseSDTrainProcess
+from torchvision import transforms


def flush():
    torch.cuda.empty_cache()
    gc.collect()
@@ -55,6 +56,17 @@ class SDTrainer(BaseSDTrainProcess):
        self.negative_prompt_pool: Union[List[str], None] = None
        self.batch_negative_prompt: Union[List[str], None] = None

+        self.scaler = torch.cuda.amp.GradScaler()
+
+        # patch the scaler to allow fp16 training
+        org_unscale_grads = self.scaler._unscale_grads_
+
+        def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
+            return org_unscale_grads(optimizer, inv_scale, found_inf, True)
+
+        self.scaler._unscale_grads_ = _unscale_grads_replacer
+
+        self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"

    def before_model_load(self):
        pass
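Note on the patch above: PyTorch's GradScaler refuses to unscale gradients that are themselves fp16; forcing allow_fp16=True in the private _unscale_grads_ method bypasses that guard so a model held in fp16 can still use loss scaling. A minimal standalone sketch of the same idea (not part of the commit; it relies on the private _unscale_grads_ attribute of torch.cuda.amp.GradScaler, which may change between PyTorch versions, and uses a toy model purely for illustration):

import torch

# hypothetical toy model kept entirely in fp16 to illustrate the patch
model = torch.nn.Linear(8, 8).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

scaler = torch.cuda.amp.GradScaler()

# GradScaler._unscale_grads_(optimizer, inv_scale, found_inf, allow_fp16) raises on
# fp16 gradients unless allow_fp16 is True; wrap it so it always allows them.
_org_unscale_grads = scaler._unscale_grads_

def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
    return _org_unscale_grads(optimizer, inv_scale, found_inf, True)

scaler._unscale_grads_ = _unscale_grads_replacer

x = torch.randn(4, 8, device="cuda", dtype=torch.float16)
loss = model(x).float().pow(2).mean()

scaler.scale(loss).backward()
scaler.unscale_(optimizer)   # would normally raise "Attempting to unscale FP16 gradients" without the patch
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)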
@@ -1401,6 +1413,7 @@ class SDTrainer(BaseSDTrainProcess):
                print("loss is nan")
                loss = torch.zeros_like(loss).requires_grad_(True)

        with self.timer('backward'):
            # todo we have multiplier seperated. works for now as res are not in same batch, but need to change
            loss = loss * loss_multiplier.mean()
@@ -1410,7 +1423,10 @@ class SDTrainer(BaseSDTrainProcess):
            # 0.0 for the backward pass and the gradients will be 0.0
            # I spent weeks on fighting this. DON'T DO IT
            # with fsdp_overlap_step_with_backward():
-            loss.backward()
+            if self.is_bfloat:
+                loss.backward()
+            else:
+                self.scaler.scale(loss).backward()
            # flush()

        if not self.is_grad_accumulation_step:
@@ -1423,8 +1439,13 @@ class SDTrainer(BaseSDTrainProcess):
                    torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
            # only step if we are not accumulating
            with self.timer('optimizer_step'):
-                # apply gradients
-                self.optimizer.step()
+                if self.is_bfloat:
+                    self.optimizer.step()
+                else:
+                    # apply gradients
+                    self.scaler.step(self.optimizer)
+                    self.scaler.update()
+                    # self.optimizer.step()
                self.optimizer.zero_grad(set_to_none=True)
        else:
            # gradient accumulation. Just a place for breakpoint
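Taken together, the two hunks above switch the trainer between a plain bf16 path (loss.backward() then optimizer.step()) and a GradScaler-driven fp16 path. A condensed sketch of that control flow, with placeholder names (optimizer, scaler, params, is_bfloat, max_grad_norm) standing in for the trainer's attributes; note that the standard PyTorch AMP recipe also calls scaler.unscale_(optimizer) before gradient clipping, which is included here for completeness rather than taken from the commit:

import torch

def train_step(optimizer, scaler, loss, params, is_bfloat, max_grad_norm=None):
    if is_bfloat:
        # bf16 has enough dynamic range to skip loss scaling entirely
        loss.backward()
    else:
        # fp16 path: scale the loss so small gradients survive the fp16 range
        scaler.scale(loss).backward()

    if max_grad_norm is not None:
        if not is_bfloat:
            # unscale first so the clip threshold applies to true gradient magnitudes
            scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(params, max_grad_norm)

    if is_bfloat:
        optimizer.step()
    else:
        # scaler.step skips the update automatically if inf/nan gradients were found
        scaler.step(optimizer)
        scaler.update()
    optimizer.zero_grad(set_to_none=True)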
@@ -48,28 +48,55 @@ dataset_config = DatasetConfig(
    buckets=True,
    bucket_tolerance=bucket_tolerance,
    # poi='person',
    shuffle_augmentations=True,
    # augmentations=[
    #     {
    #         'method': 'RandomBrightnessContrast',
    #         'brightness_limit': (-0.3, 0.3),
    #         'contrast_limit': (-0.3, 0.3),
    #         'brightness_by_max': False,
    #         'p': 1.0
    #         'method': 'GaussianBlur',
    #         'blur_limit': (1, 16),
    #         'sigma_limit': (0, 8),
    #         'p': 0.8
    #     },
    #     {
    #         'method': 'HueSaturationValue',
    #         'hue_shift_limit': (-0, 0),
    #         'sat_shift_limit': (-40, 40),
    #         'val_shift_limit': (-40, 40),
    #         'p': 1.0
    #         'method': 'ImageCompression',
    #         'quality_lower': 10,
    #         'quality_upper': 100,
    #         'compression_type': 0,
    #         'p': 0.8
    #     },
    #     {
    #         'method': 'ImageCompression',
    #         'quality_lower': 20,
    #         'quality_upper': 100,
    #         'compression_type': 1,
    #         'p': 0.8
    #     },
    #     {
    #         'method': 'RingingOvershoot',
    #         'blur_limit': (3, 35),
    #         'cutoff': (0.7, 1.96),
    #         'p': 0.8
    #     },
    #     {
    #         'method': 'GaussNoise',
    #         'var_limit': (0, 300),
    #         'per_channel': True,
    #         'mean': 0.0,
    #         'p': 0.8
    #     },
    #     {
    #         'method': 'GlassBlur',
    #         'sigma': 0.6,
    #         'max_delta': 7,
    #         'iterations': 2,
    #         'mode': 'fast',
    #         'p': 0.8
    #     },
    #     {
    #         'method': 'Downscale',
    #         'scale_max': 0.5,
    #         'interpolation': 'cv2.INTER_CUBIC',
    #         'p': 0.8
    #     },
    #     {
    #         'method': 'RGBShift',
    #         'r_shift_limit': (-20, 20),
    #         'g_shift_limit': (-20, 20),
    #         'b_shift_limit': (-20, 20),
    #         'p': 1.0
    #     },
    # ]
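The commented-out augmentation entries use albumentations transform names (RandomBrightnessContrast, GaussianBlur, ImageCompression, and so on). A rough sketch of how such a list of {'method': ..., **params} dicts could be turned into an albumentations pipeline; the helper name and the exact config handling are illustrative, not taken from ai-toolkit, and some keyword names (e.g. quality_lower/quality_upper) depend on the albumentations version:

import random
import albumentations as A

def build_augmentations(aug_configs, shuffle=False):
    # each entry names an albumentations class plus its keyword arguments
    transforms = []
    for cfg in aug_configs:
        cfg = dict(cfg)                      # don't mutate the caller's config
        method = cfg.pop('method')
        transform_cls = getattr(A, method)   # e.g. A.GaussianBlur, A.ImageCompression
        transforms.append(transform_cls(**cfg))
    if shuffle:
        random.shuffle(transforms)
    return A.Compose(transforms)

# example: a blur followed by JPEG-style compression, each applied with probability 0.8
pipeline = build_augmentations([
    {'method': 'GaussianBlur', 'blur_limit': (3, 7), 'sigma_limit': (0, 8), 'p': 0.8},
    {'method': 'ImageCompression', 'quality_lower': 20, 'quality_upper': 100, 'p': 0.8},
])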
@@ -100,7 +127,7 @@ for epoch in range(args.epochs):
        show_img(img)

-        # time.sleep(0.1)
+        # time.sleep(1.0)
    # if not last epoch
    if epoch < args.epochs - 1:
        trigger_dataloader_setup_epoch(dataloader)
@@ -529,6 +529,7 @@ class DatasetConfig:
        self.num_workers: int = kwargs.get('num_workers', 4)
        self.prefetch_factor: int = kwargs.get('prefetch_factor', 2)
        self.extra_values: List[float] = kwargs.get('extra_values', [])
+        self.square_crop: bool = kwargs.get('square_crop', False)


def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
@@ -388,7 +388,10 @@ class CustomAdapter(torch.nn.Module):
            self.fuse_module.load_state_dict(state_dict['fuse_module'], strict=strict)

        if 'ilora' in state_dict:
-            self.ilora_module.load_state_dict(state_dict['ilora'], strict=strict)
+            try:
+                self.ilora_module.load_state_dict(state_dict['ilora'], strict=strict)
+            except Exception as e:
+                print(e)

        pass
@@ -203,7 +203,22 @@ class BucketsMixin:
            if file_item.has_point_of_interest:
                # Attempt to process the poi if we can. It wont process if the image is smaller than the resolution
                did_process_poi = file_item.setup_poi_bucket()
-            if not did_process_poi:
+            if self.dataset_config.square_crop:
+                # we scale first so smallest size matches resolution
+                scale_factor_x = resolution / width
+                scale_factor_y = resolution / height
+                scale_factor = max(scale_factor_x, scale_factor_y)
+                file_item.scale_to_width = math.ceil(width * scale_factor)
+                file_item.scale_to_height = math.ceil(height * scale_factor)
+                file_item.crop_width = resolution
+                file_item.crop_height = resolution
+                if width > height:
+                    file_item.crop_x = int(file_item.scale_to_width / 2 - resolution / 2)
+                    file_item.crop_y = 0
+                else:
+                    file_item.crop_x = 0
+                    file_item.crop_y = int(file_item.scale_to_height / 2 - resolution / 2)
+            elif not did_process_poi:
                bucket_resolution = get_bucket_for_image_size(
                    width, height,
                    resolution=resolution,
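The square_crop branch scales the image so its shorter side lands exactly on the target resolution, then center-crops along the longer axis. A small standalone sketch of that arithmetic (function and variable names are illustrative, not the toolkit's API):

import math

def square_crop_params(width, height, resolution):
    # scale so the smaller side matches `resolution`
    scale = max(resolution / width, resolution / height)
    scale_w = math.ceil(width * scale)
    scale_h = math.ceil(height * scale)
    # center-crop a resolution x resolution window out of the scaled image
    if width > height:
        crop_x, crop_y = int(scale_w / 2 - resolution / 2), 0
    else:
        crop_x, crop_y = 0, int(scale_h / 2 - resolution / 2)
    return scale_w, scale_h, crop_x, crop_y

# e.g. a 1920x1080 image at resolution 512 scales to 911x512 and crops from x=199
print(square_crop_params(1920, 1080, 512))   # (911, 512, 199, 0)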
@@ -365,14 +365,15 @@ class IPAdapter(torch.nn.Module):
                input_size=preprocessor_input_size,
                clip_input_size=self.image_encoder.config.image_size,
            )
-        if 'height' in self.clip_image_processor.size:
-            self.input_size = self.clip_image_processor.size['height']
-        elif hasattr(self.clip_image_processor, 'crop_size'):
-            self.input_size = self.clip_image_processor.crop_size['height']
-        elif 'shortest_edge' in self.clip_image_processor.size.keys():
-            self.input_size = self.clip_image_processor.size['shortest_edge']
-        else:
-            raise ValueError(f"unknown image processor size: {self.clip_image_processor.size}")
+        if not self.config.image_encoder_arch == 'safe':
+            if 'height' in self.clip_image_processor.size:
+                self.input_size = self.clip_image_processor.size['height']
+            elif hasattr(self.clip_image_processor, 'crop_size'):
+                self.input_size = self.clip_image_processor.crop_size['height']
+            elif 'shortest_edge' in self.clip_image_processor.size.keys():
+                self.input_size = self.clip_image_processor.size['shortest_edge']
+            else:
+                raise ValueError(f"unknown image processor size: {self.clip_image_processor.size}")
        self.current_scale = 1.0
        self.is_active = True
        is_pixart = sd.is_pixart
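The wrapped block resolves the encoder's expected input resolution from a HuggingFace image processor, which may expose it under size['height'], crop_size['height'], or size['shortest_edge'] depending on the processor class; the SAFE (size-agnostic) encoder skips the lookup entirely. A hedged sketch of the same lookup as a standalone helper (the function name is illustrative):

from transformers import CLIPImageProcessor

def resolve_input_size(image_processor) -> int:
    # HF image processors disagree on where the target resolution lives
    size = getattr(image_processor, 'size', None) or {}
    if 'height' in size:
        return size['height']
    if getattr(image_processor, 'crop_size', None):
        return image_processor.crop_size['height']
    if 'shortest_edge' in size:
        return size['shortest_edge']
    raise ValueError(f"unknown image processor size: {size}")

# e.g. the openai/clip-vit-large-patch14 processor resolves to 224 (via crop_size)
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
print(resolve_input_size(processor))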
@@ -21,19 +21,24 @@ class ILoRAProjModule(torch.nn.Module):
        self.num_modules = num_modules
        self.num_dim = dim
        self.norm = torch.nn.LayerNorm(embeddings_dim)

        self.proj = torch.nn.Sequential(
            torch.nn.LayerNorm(embeddings_dim),
            torch.nn.Linear(embeddings_dim, embeddings_dim * 2),
            torch.nn.GELU(),
-            torch.nn.Linear(embeddings_dim * 2, num_modules * dim),
+            torch.nn.Linear(embeddings_dim * 2, embeddings_dim * 2),
+            torch.nn.LayerNorm(embeddings_dim * 2),
+            torch.nn.Linear(embeddings_dim * 2, embeddings_dim * 4),
+            torch.nn.GELU(),
+            torch.nn.Linear(embeddings_dim * 4, num_modules * dim),
+            torch.nn.LayerNorm(num_modules * dim),
        )
        # Initialize the last linear layer weights near zero
-        torch.nn.init.uniform_(self.proj[2].weight, a=-0.01, b=0.01)
-        torch.nn.init.zeros_(self.proj[2].bias)
+        torch.nn.init.uniform_(self.proj[-2].weight, a=-0.01, b=0.01)
+        torch.nn.init.zeros_(self.proj[-2].bias)

    def forward(self, x):
        x = self.norm(x)
        x = self.proj(x)
        x = x.reshape(-1, self.num_modules, self.num_dim)
        return x
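The projection MLP maps one image embedding to a dim-sized vector per LoRA module, and the final Linear is initialized near zero so the instant-LoRA contribution starts out close to a no-op. A self-contained sketch of that pattern (hyperparameter values are made up for illustration):

import torch

class ProjSketch(torch.nn.Module):
    def __init__(self, embeddings_dim=768, num_modules=16, dim=4):
        super().__init__()
        self.num_modules, self.dim = num_modules, dim
        self.proj = torch.nn.Sequential(
            torch.nn.LayerNorm(embeddings_dim),
            torch.nn.Linear(embeddings_dim, embeddings_dim * 2),
            torch.nn.GELU(),
            torch.nn.Linear(embeddings_dim * 2, num_modules * dim),
        )
        # near-zero init of the last Linear keeps the initial output ~0,
        # so downstream modules start from (almost) their unmodified behavior
        torch.nn.init.uniform_(self.proj[-1].weight, a=-0.01, b=0.01)
        torch.nn.init.zeros_(self.proj[-1].bias)

    def forward(self, x):   # x: (batch, embeddings_dim)
        return self.proj(x).reshape(-1, self.num_modules, self.dim)

out = ProjSketch()(torch.randn(2, 768))
print(out.shape)   # torch.Size([2, 16, 4])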
@@ -71,6 +76,8 @@ class InstantLoRAMidModule(torch.nn.Module):
            # reshape if needed
            if len(x.shape) == 3:
                scaler = scaler.unsqueeze(1)
+            if len(x.shape) == 4:
+                scaler = scaler.unsqueeze(-1).unsqueeze(-1)
        except Exception as e:
            print(e)
            print(x.shape)
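The added branch handles convolutional activations: a (batch, channels, h, w) tensor needs trailing singleton dims on the per-sample scaler to broadcast, while a (batch, tokens, features) tensor needs one in the middle. A quick illustration of the shapes involved:

import torch

scaler = torch.rand(2, 8)                    # one scale vector per sample

x3 = torch.randn(2, 77, 8)                   # (batch, tokens, features)
print((x3 * scaler.unsqueeze(1)).shape)      # torch.Size([2, 77, 8])

x4 = torch.randn(2, 8, 16, 16)               # (batch, channels, h, w)
print((x4 * scaler.unsqueeze(-1).unsqueeze(-1)).shape)   # torch.Size([2, 8, 16, 16])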
@@ -20,11 +20,11 @@ class SAFEReducerBlock(nn.Module):
        self.reducer = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            activation(),
            nn.BatchNorm2d(channels),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            activation(),
            nn.BatchNorm2d(channels),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )
        self.residual_shrink = nn.AvgPool2d(kernel_size=2, stride=2)
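This hunk moves the BatchNorm2d placement relative to the activation inside the reducer stack; the block itself halves spatial resolution with a small conv stack while an average-pooled shortcut provides a residual path at the same reduced size. A hedged, self-contained sketch of that structure (the residual addition in forward() is an assumption about how the two branches combine, since the forward pass is not part of this hunk):

import torch
import torch.nn as nn

class ReducerBlockSketch(nn.Module):
    def __init__(self, channels, activation=nn.GELU):
        super().__init__()
        self.reducer = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            activation(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            activation(),
            nn.AvgPool2d(kernel_size=2, stride=2),   # halves H and W
        )
        self.residual_shrink = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        # assumed residual combination: conv path plus pooled shortcut
        return self.reducer(x) + self.residual_shrink(x)

y = ReducerBlockSketch(32)(torch.randn(1, 32, 64, 64))
print(y.shape)   # torch.Size([1, 32, 32, 32])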
@@ -227,6 +227,7 @@ class SAFEVMConfig:
        self.reducer_channels = reducer_channels
        self.channels = channels
        self.downscale_factor = downscale_factor
+        self.image_size = 224

        self.hidden_size = num_vectors
        self.projection_dim = num_vectors
@@ -242,7 +243,9 @@ class SAFEVMReturn:
class SAFEVisionModel(SizeAgnosticFeatureEncoder):
    def __init__(self, **kwargs):
        self.config = SAFEVMConfig(**kwargs)
-        super().__init__(**kwargs)
+        self.image_size = None
+        # super().__init__(**kwargs)
+        super(SAFEVisionModel, self).__init__(**kwargs)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
@@ -20,12 +20,12 @@ def get_optimizer(
        # dadaptation uses different lr that is values of 0.1 to 1.0. default to 1.0
        use_lr = 1.0
        if lower_type.endswith('lion'):
-            optimizer = dadaptation.DAdaptLion(params, lr=use_lr, **optimizer_params)
+            optimizer = dadaptation.DAdaptLion(params, eps=1e-6, lr=use_lr, **optimizer_params)
        elif lower_type.endswith('adam'):
-            optimizer = dadaptation.DAdaptLion(params, lr=use_lr, **optimizer_params)
+            optimizer = dadaptation.DAdaptLion(params, eps=1e-6, lr=use_lr, **optimizer_params)
        elif lower_type == 'dadaptation':
            # backwards compatibility
-            optimizer = dadaptation.DAdaptAdam(params, lr=use_lr, **optimizer_params)
+            optimizer = dadaptation.DAdaptAdam(params, eps=1e-6, lr=use_lr, **optimizer_params)
            # warn user that dadaptation is deprecated
            print("WARNING: Dadaptation optimizer type has been changed to DadaptationAdam. Please update your config.")
        elif lower_type.startswith("prodigy"):
@@ -40,22 +40,22 @@ def get_optimizer(
        print(f"Using lr {use_lr}")
        # let net be the neural network you want to train
        # you can choose weight decay value based on your problem, 0 by default
-        optimizer = Prodigy(params, lr=use_lr, **optimizer_params)
+        optimizer = Prodigy(params, lr=use_lr, eps=1e-6, **optimizer_params)
    elif lower_type.endswith("8bit"):
        import bitsandbytes

        if lower_type == "adam8bit":
-            return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, **optimizer_params)
+            return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
        elif lower_type == "adamw8bit":
-            return bitsandbytes.optim.AdamW8bit(params, lr=learning_rate, **optimizer_params)
+            return bitsandbytes.optim.AdamW8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
        elif lower_type == "lion8bit":
            return bitsandbytes.optim.Lion8bit(params, lr=learning_rate, **optimizer_params)
        else:
            raise ValueError(f'Unknown optimizer type {optimizer_type}')
    elif lower_type == 'adam':
-        optimizer = torch.optim.Adam(params, lr=float(learning_rate), **optimizer_params)
+        optimizer = torch.optim.Adam(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
    elif lower_type == 'adamw':
-        optimizer = torch.optim.AdamW(params, lr=float(learning_rate), **optimizer_params)
+        optimizer = torch.optim.AdamW(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
    elif lower_type == 'lion':
        try:
            from lion_pytorch import Lion
@@ -63,7 +63,7 @@ def get_optimizer(
        except ImportError:
            raise ImportError("Please install lion_pytorch to use Lion optimizer -> pip install lion-pytorch")
    elif lower_type == 'adagrad':
-        optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), **optimizer_params)
+        optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
    elif lower_type == 'adafactor':
        # hack in stochastic rounding
        if 'relative_step' not in optimizer_params:
@@ -72,7 +72,7 @@ def get_optimizer(
            optimizer_params['scale_parameter'] = True
        if 'warmup_init' not in optimizer_params:
            optimizer_params['warmup_init'] = False
-        optimizer = Adafactor(params, lr=float(learning_rate), **optimizer_params)
+        optimizer = Adafactor(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
        from toolkit.util.adafactor_stochastic_rounding import step_adafactor
        optimizer.step = step_adafactor.__get__(optimizer, Adafactor)
    else:
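The Adam-style constructors in these hunks all gain eps=1e-6 in place of the library defaults (typically 1e-8). One plausible motivation, consistent with the commit's fp16 focus rather than stated in the diff: 1e-8 rounds to zero when cast to float16, so the denominator guard in the Adam update effectively disappears, while 1e-6 survives the cast. A quick check of that numeric claim:

import torch

# Adam divides by (sqrt(v) + eps); if eps underflows to zero in half precision,
# tiny second-moment estimates can blow the step up.
print(torch.tensor(1e-8, dtype=torch.float16))   # tensor(0., dtype=torch.float16)
print(torch.tensor(1e-6, dtype=torch.float16))   # small but nonzero (~1.0e-06)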