WIP on SAFE encoder. Work on fp16 training improvements. Various other tweaks and improvements

This commit is contained in:
Jaret Burkett
2024-05-27 10:50:24 -06:00
parent 68b7e159bc
commit 833c833f28
9 changed files with 127 additions and 49 deletions

View File

@@ -31,6 +31,7 @@ from jobs.process import BaseSDTrainProcess
from torchvision import transforms
def flush():
torch.cuda.empty_cache()
gc.collect()
@@ -55,6 +56,17 @@ class SDTrainer(BaseSDTrainProcess):
self.negative_prompt_pool: Union[List[str], None] = None
self.batch_negative_prompt: Union[List[str], None] = None
self.scaler = torch.cuda.amp.GradScaler()
# patch the scaler to allow fp16 training
org_unscale_grads = self.scaler._unscale_grads_
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
return org_unscale_grads(optimizer, inv_scale, found_inf, True)
self.scaler._unscale_grads_ = _unscale_grads_replacer
self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"
def before_model_load(self):
pass
@@ -1401,6 +1413,7 @@ class SDTrainer(BaseSDTrainProcess):
print("loss is nan")
loss = torch.zeros_like(loss).requires_grad_(True)
with self.timer('backward'):
# todo: we have the multiplier separated. works for now as res are not in the same batch, but this needs to change
loss = loss * loss_multiplier.mean()
@@ -1410,7 +1423,10 @@ class SDTrainer(BaseSDTrainProcess):
# 0.0 for the backward pass and the gradients will be 0.0
# I spent weeks fighting this. DON'T DO IT
# with fsdp_overlap_step_with_backward():
-loss.backward()
+if self.is_bfloat:
+    loss.backward()
+else:
+    self.scaler.scale(loss).backward()
# flush()
if not self.is_grad_accumulation_step:
@@ -1423,8 +1439,13 @@ class SDTrainer(BaseSDTrainProcess):
torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
# only step if we are not accumulating
with self.timer('optimizer_step'):
-# apply gradients
-self.optimizer.step()
+if self.is_bfloat:
+    self.optimizer.step()
+else:
+    # apply gradients
+    self.scaler.step(self.optimizer)
+    self.scaler.update()
+# self.optimizer.step()
self.optimizer.zero_grad(set_to_none=True)
else:
# gradient accumulation. Just a place for a breakpoint
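For context on the fp16 changes above, here is a minimal sketch of the full pattern, assuming a standard torch.cuda.amp setup (model, batch, and loss_fn are placeholders). GradScaler refuses to unscale fp16 gradients by default; the patch forces allow_fp16=True so full-fp16 weights can still train with loss scaling, while bf16 skips the scaler entirely since its exponent range matches fp32.

import torch

scaler = torch.cuda.amp.GradScaler()

# same trick as the hunk above: force allow_fp16=True in the private unscale hook
org_unscale_grads = scaler._unscale_grads_
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
    return org_unscale_grads(optimizer, inv_scale, found_inf, True)
scaler._unscale_grads_ = _unscale_grads_replacer

def train_step(model, batch, optimizer, loss_fn, is_bfloat):
    loss = loss_fn(model(batch))
    if is_bfloat:
        loss.backward()                # bf16: no scaling needed
        optimizer.step()
    else:
        scaler.scale(loss).backward()  # fp16: scale to avoid gradient underflow
        scaler.step(optimizer)         # unscales, skips the step if inf/nan found
        scaler.update()
    optimizer.zero_grad(set_to_none=True)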

View File

@@ -48,28 +48,55 @@ dataset_config = DatasetConfig(
buckets=True,
bucket_tolerance=bucket_tolerance,
# poi='person',
+shuffle_augmentations=True,
# augmentations=[
# {
-# 'method': 'RandomBrightnessContrast',
-# 'brightness_limit': (-0.3, 0.3),
-# 'contrast_limit': (-0.3, 0.3),
-# 'brightness_by_max': False,
-# 'p': 1.0
+# 'method': 'GaussianBlur',
+# 'blur_limit': (1, 16),
+# 'sigma_limit': (0, 8),
+# 'p': 0.8
# },
# {
-# 'method': 'HueSaturationValue',
-# 'hue_shift_limit': (-0, 0),
-# 'sat_shift_limit': (-40, 40),
-# 'val_shift_limit': (-40, 40),
-# 'p': 1.0
+# 'method': 'ImageCompression',
+# 'quality_lower': 10,
+# 'quality_upper': 100,
+# 'compression_type': 0,
+# 'p': 0.8
# },
+# {
+# 'method': 'ImageCompression',
+# 'quality_lower': 20,
+# 'quality_upper': 100,
+# 'compression_type': 1,
+# 'p': 0.8
+# },
+# {
+# 'method': 'RingingOvershoot',
+# 'blur_limit': (3, 35),
+# 'cutoff': (0.7, 1.96),
+# 'p': 0.8
+# },
+# {
+# 'method': 'GaussNoise',
+# 'var_limit': (0, 300),
+# 'per_channel': True,
+# 'mean': 0.0,
+# 'p': 0.8
+# },
+# {
+# 'method': 'GlassBlur',
+# 'sigma': 0.6,
+# 'max_delta': 7,
+# 'iterations': 2,
+# 'mode': 'fast',
+# 'p': 0.8
+# },
+# {
+# 'method': 'Downscale',
+# 'scale_max': 0.5,
+# 'interpolation': 'cv2.INTER_CUBIC',
+# 'p': 0.8
+# },
-# {
-# 'method': 'RGBShift',
-# 'r_shift_limit': (-20, 20),
-# 'g_shift_limit': (-20, 20),
-# 'b_shift_limit': (-20, 20),
-# 'p': 1.0
-# },
# ]
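Presumably each commented-out entry maps 'method' to an albumentations transform of the same name, with the remaining keys passed through as kwargs. A minimal sketch of that mapping (an assumption about the toolkit's internals, not code from this repo):

import albumentations as A

def build_transform(aug: dict):
    cfg = dict(aug)
    method = cfg.pop('method')
    # look the transform class up by name and pass the rest as kwargs
    return getattr(A, method)(**cfg)

# e.g. build_transform({'method': 'GaussNoise', 'var_limit': (0, 300),
#                       'per_channel': True, 'mean': 0.0, 'p': 0.8})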
@@ -100,7 +127,7 @@ for epoch in range(args.epochs):
show_img(img)
# time.sleep(0.1)
# time.sleep(1.0)
# if not last epoch
if epoch < args.epochs - 1:
trigger_dataloader_setup_epoch(dataloader)

View File

@@ -529,6 +529,7 @@ class DatasetConfig:
self.num_workers: int = kwargs.get('num_workers', 4)
self.prefetch_factor: int = kwargs.get('prefetch_factor', 2)
self.extra_values: List[float] = kwargs.get('extra_values', [])
self.square_crop: bool = kwargs.get('square_crop', False)
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
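A usage sketch for the new flag; folder_path and resolution are assumed field names here, while buckets and square_crop come from the configs above:

dataset_config = DatasetConfig(
    folder_path='/path/to/images',  # assumed field name
    resolution=512,                 # assumed field name
    buckets=True,
    square_crop=True,  # scale the short side to resolution, then center-crop
)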

View File

@@ -388,7 +388,10 @@ class CustomAdapter(torch.nn.Module):
self.fuse_module.load_state_dict(state_dict['fuse_module'], strict=strict)
if 'ilora' in state_dict:
-self.ilora_module.load_state_dict(state_dict['ilora'], strict=strict)
+try:
+    self.ilora_module.load_state_dict(state_dict['ilora'], strict=strict)
+except Exception as e:
+    print(e)
+    pass

View File

@@ -203,7 +203,22 @@ class BucketsMixin:
if file_item.has_point_of_interest:
# Attempt to process the poi if we can. It won't process if the image is smaller than the resolution
did_process_poi = file_item.setup_poi_bucket()
-if not did_process_poi:
+if self.dataset_config.square_crop:
+    # scale first so the smaller side matches the resolution
+    scale_factor_x = resolution / width
+    scale_factor_y = resolution / height
+    scale_factor = max(scale_factor_x, scale_factor_y)
+    file_item.scale_to_width = math.ceil(width * scale_factor)
+    file_item.scale_to_height = math.ceil(height * scale_factor)
+    file_item.crop_width = resolution
+    file_item.crop_height = resolution
+    if width > height:
+        file_item.crop_x = int(file_item.scale_to_width / 2 - resolution / 2)
+        file_item.crop_y = 0
+    else:
+        file_item.crop_x = 0
+        file_item.crop_y = int(file_item.scale_to_height / 2 - resolution / 2)
+elif not did_process_poi:
bucket_resolution = get_bucket_for_image_size(
width, height,
resolution=resolution,
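The square_crop branch above reduces to: upscale so the shorter side lands exactly on resolution, then center-crop the longer side. A standalone sketch of the same math with a worked example (the function name is illustrative):

import math

def square_crop_params(width, height, resolution):
    # scale so the smaller side lands exactly on `resolution`
    scale = max(resolution / width, resolution / height)
    scale_w = math.ceil(width * scale)
    scale_h = math.ceil(height * scale)
    # center-crop the longer side down to `resolution`
    crop_x = int(scale_w / 2 - resolution / 2) if width > height else 0
    crop_y = 0 if width > height else int(scale_h / 2 - resolution / 2)
    return scale_w, scale_h, crop_x, crop_y

# a 1024x768 image at resolution 512: scale = max(0.5, 0.667),
# scaled to 683x512, then crop_x = 85, crop_y = 0
print(square_crop_params(1024, 768, 512))  # (683, 512, 85, 0)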

View File

@@ -365,14 +365,15 @@ class IPAdapter(torch.nn.Module):
input_size=preprocessor_input_size,
clip_input_size=self.image_encoder.config.image_size,
)
-if 'height' in self.clip_image_processor.size:
-    self.input_size = self.clip_image_processor.size['height']
-elif hasattr(self.clip_image_processor, 'crop_size'):
-    self.input_size = self.clip_image_processor.crop_size['height']
-elif 'shortest_edge' in self.clip_image_processor.size.keys():
-    self.input_size = self.clip_image_processor.size['shortest_edge']
-else:
-    raise ValueError(f"unknown image processor size: {self.clip_image_processor.size}")
+if not self.config.image_encoder_arch == 'safe':
+    if 'height' in self.clip_image_processor.size:
+        self.input_size = self.clip_image_processor.size['height']
+    elif hasattr(self.clip_image_processor, 'crop_size'):
+        self.input_size = self.clip_image_processor.crop_size['height']
+    elif 'shortest_edge' in self.clip_image_processor.size.keys():
+        self.input_size = self.clip_image_processor.size['shortest_edge']
+    else:
+        raise ValueError(f"unknown image processor size: {self.clip_image_processor.size}")
self.current_scale = 1.0
self.is_active = True
is_pixart = sd.is_pixart
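The branch above exists because HF image processors expose their size in different shapes. A standalone sketch of the same lookup, assuming processor is a CLIPImageProcessor-like object:

def resolve_input_size(processor):
    size = getattr(processor, 'size', None) or {}
    if 'height' in size:
        return size['height']
    if getattr(processor, 'crop_size', None):
        return processor.crop_size['height']
    if 'shortest_edge' in size:
        return size['shortest_edge']
    raise ValueError(f"unknown image processor size: {size}")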

View File

@@ -21,19 +21,24 @@ class ILoRAProjModule(torch.nn.Module):
self.num_modules = num_modules
self.num_dim = dim
self.norm = torch.nn.LayerNorm(embeddings_dim)
self.proj = torch.nn.Sequential(
+    torch.nn.LayerNorm(embeddings_dim),
    torch.nn.Linear(embeddings_dim, embeddings_dim * 2),
    torch.nn.GELU(),
-    torch.nn.Linear(embeddings_dim * 2, num_modules * dim),
+    torch.nn.Linear(embeddings_dim * 2, embeddings_dim * 2),
+    torch.nn.LayerNorm(embeddings_dim * 2),
+    torch.nn.Linear(embeddings_dim * 2, embeddings_dim * 4),
+    torch.nn.GELU(),
+    torch.nn.Linear(embeddings_dim * 4, num_modules * dim),
+    torch.nn.LayerNorm(num_modules * dim),
)
# Initialize the last linear layer weights near zero
-torch.nn.init.uniform_(self.proj[2].weight, a=-0.01, b=0.01)
-torch.nn.init.zeros_(self.proj[2].bias)
+torch.nn.init.uniform_(self.proj[-2].weight, a=-0.01, b=0.01)
+torch.nn.init.zeros_(self.proj[-2].bias)
def forward(self, x):
x = self.norm(x)
x = self.proj(x)
x = x.reshape(-1, self.num_modules, self.num_dim)
return x
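A quick shape check for the deepened projection head: it maps one embedding vector to one dim-sized vector per LoRA module. The Sequential below mirrors the new stack outside the class; the sizes are arbitrary:

import torch

embeddings_dim, num_modules, dim = 768, 16, 4
proj = torch.nn.Sequential(
    torch.nn.LayerNorm(embeddings_dim),
    torch.nn.Linear(embeddings_dim, embeddings_dim * 2),
    torch.nn.GELU(),
    torch.nn.Linear(embeddings_dim * 2, embeddings_dim * 2),
    torch.nn.LayerNorm(embeddings_dim * 2),
    torch.nn.Linear(embeddings_dim * 2, embeddings_dim * 4),
    torch.nn.GELU(),
    torch.nn.Linear(embeddings_dim * 4, num_modules * dim),
    torch.nn.LayerNorm(num_modules * dim),
)
x = torch.randn(2, embeddings_dim)           # (batch, embeddings_dim)
out = proj(x).reshape(-1, num_modules, dim)  # (batch, num_modules, dim)
print(out.shape)                             # torch.Size([2, 16, 4])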
@@ -71,6 +76,8 @@ class InstantLoRAMidModule(torch.nn.Module):
# reshape if needed
if len(x.shape) == 3:
scaler = scaler.unsqueeze(1)
if len(x.shape) == 4:
scaler = scaler.unsqueeze(-1).unsqueeze(-1)
except Exception as e:
print(e)
print(x.shape)
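The new unsqueeze handles conv feature maps: a (batch, modules) scaler needs trailing singleton dims to broadcast over (B, C, H, W) activations, just as it needs a middle one for (B, T, D). Quick check:

import torch

scaler = torch.randn(2, 8)       # per-sample, per-module scale
x3 = torch.randn(2, 5, 8)        # transformer-style activations
x4 = torch.randn(2, 8, 16, 16)   # conv-style activations
print((x3 * scaler.unsqueeze(1)).shape)                 # torch.Size([2, 5, 8])
print((x4 * scaler.unsqueeze(-1).unsqueeze(-1)).shape)  # torch.Size([2, 8, 16, 16])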

View File

@@ -20,11 +20,11 @@ class SAFEReducerBlock(nn.Module):
self.reducer = nn.Sequential(
nn.Conv2d(channels, channels, kernel_size=3, padding=1),
-nn.BatchNorm2d(channels),
activation(),
+nn.BatchNorm2d(channels),
nn.Conv2d(channels, channels, kernel_size=3, padding=1),
-nn.BatchNorm2d(channels),
activation(),
+nn.BatchNorm2d(channels),
nn.AvgPool2d(kernel_size=2, stride=2),
)
self.residual_shrink = nn.AvgPool2d(kernel_size=2, stride=2)
@@ -227,6 +227,7 @@ class SAFEVMConfig:
self.reducer_channels = reducer_channels
self.channels = channels
self.downscale_factor = downscale_factor
self.image_size = 224
self.hidden_size = num_vectors
self.projection_dim = num_vectors
@@ -242,7 +243,9 @@ class SAFEVMReturn:
class SAFEVisionModel(SizeAgnosticFeatureEncoder):
def __init__(self, **kwargs):
self.config = SAFEVMConfig(**kwargs)
-super().__init__(**kwargs)
+self.image_size = None
+# super().__init__(**kwargs)
+super(SAFEVisionModel, self).__init__(**kwargs)
@classmethod
def from_pretrained(cls, *args, **kwargs):
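For reference, a sketch of the reordered reducer block in full. The diff does not show forward, so the residual add is an assumption inferred from the residual_shrink name:

import torch.nn as nn

class ReducerBlockSketch(nn.Module):
    # conv -> activation -> BatchNorm, twice, then a 2x average-pool downsample;
    # a pooled skip path keeps a residual connection across the downsample
    def __init__(self, channels, activation=nn.GELU):
        super().__init__()
        self.reducer = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            activation(),
            nn.BatchNorm2d(channels),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            activation(),
            nn.BatchNorm2d(channels),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )
        self.residual_shrink = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        return self.reducer(x) + self.residual_shrink(x)  # assumed residual add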

View File

@@ -20,12 +20,12 @@ def get_optimizer(
# dadaptation uses a different lr scale, with values from 0.1 to 1.0; default to 1.0
use_lr = 1.0
if lower_type.endswith('lion'):
-optimizer = dadaptation.DAdaptLion(params, lr=use_lr, **optimizer_params)
+optimizer = dadaptation.DAdaptLion(params, eps=1e-6, lr=use_lr, **optimizer_params)
elif lower_type.endswith('adam'):
-optimizer = dadaptation.DAdaptLion(params, lr=use_lr, **optimizer_params)
+optimizer = dadaptation.DAdaptLion(params, eps=1e-6, lr=use_lr, **optimizer_params)
elif lower_type == 'dadaptation':
# backwards compatibility
-optimizer = dadaptation.DAdaptAdam(params, lr=use_lr, **optimizer_params)
+optimizer = dadaptation.DAdaptAdam(params, eps=1e-6, lr=use_lr, **optimizer_params)
# warn user that dadaptation is deprecated
print("WARNING: Dadaptation optimizer type has been changed to DadaptationAdam. Please update your config.")
elif lower_type.startswith("prodigy"):
@@ -40,22 +40,22 @@ def get_optimizer(
print(f"Using lr {use_lr}")
# let net be the neural network you want to train
# you can choose weight decay value based on your problem, 0 by default
-optimizer = Prodigy(params, lr=use_lr, **optimizer_params)
+optimizer = Prodigy(params, lr=use_lr, eps=1e-6, **optimizer_params)
elif lower_type.endswith("8bit"):
import bitsandbytes
if lower_type == "adam8bit":
-return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, **optimizer_params)
+return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
elif lower_type == "adamw8bit":
-return bitsandbytes.optim.AdamW8bit(params, lr=learning_rate, **optimizer_params)
+return bitsandbytes.optim.AdamW8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
elif lower_type == "lion8bit":
return bitsandbytes.optim.Lion8bit(params, lr=learning_rate, **optimizer_params)
else:
raise ValueError(f'Unknown optimizer type {optimizer_type}')
elif lower_type == 'adam':
-optimizer = torch.optim.Adam(params, lr=float(learning_rate), **optimizer_params)
+optimizer = torch.optim.Adam(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
elif lower_type == 'adamw':
-optimizer = torch.optim.AdamW(params, lr=float(learning_rate), **optimizer_params)
+optimizer = torch.optim.AdamW(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
elif lower_type == 'lion':
try:
from lion_pytorch import Lion
@@ -63,7 +63,7 @@ def get_optimizer(
except ImportError:
raise ImportError("Please install lion_pytorch to use Lion optimizer -> pip install lion-pytorch")
elif lower_type == 'adagrad':
-optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), **optimizer_params)
+optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
elif lower_type == 'adafactor':
# hack in stochastic rounding
if 'relative_step' not in optimizer_params:
@@ -72,7 +72,7 @@ def get_optimizer(
optimizer_params['scale_parameter'] = True
if 'warmup_init' not in optimizer_params:
optimizer_params['warmup_init'] = False
-optimizer = Adafactor(params, lr=float(learning_rate), **optimizer_params)
+optimizer = Adafactor(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
from toolkit.util.adafactor_stochastic_rounding import step_adafactor
optimizer.step = step_adafactor.__get__(optimizer, Adafactor)
else:
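One plausible motivation for the eps=1e-6 bump across these optimizers: Adam-family updates divide by sqrt(v) + eps, and the default eps=1e-8 is below fp16's smallest subnormal (~6e-8), so it flushes to zero if the optimizer math ever runs in half precision; 1e-6 survives the cast. Quick check:

import torch

print(torch.tensor(1e-8, dtype=torch.float16))  # tensor(0., dtype=torch.float16)
print(torch.tensor(1e-6, dtype=torch.float16))  # tensor(1.0133e-06, dtype=torch.float16)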