mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 10:41:28 +00:00
WIP on SAFE encoder. Work on fp16 training improvements. Various other tweaks and improvements
This commit is contained in:
@@ -31,6 +31,7 @@ from jobs.process import BaseSDTrainProcess
|
|||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def flush():
|
def flush():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
gc.collect()
|
gc.collect()
|
||||||
@@ -55,6 +56,17 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
self.negative_prompt_pool: Union[List[str], None] = None
|
self.negative_prompt_pool: Union[List[str], None] = None
|
||||||
self.batch_negative_prompt: Union[List[str], None] = None
|
self.batch_negative_prompt: Union[List[str], None] = None
|
||||||
|
|
||||||
|
self.scaler = torch.cuda.amp.GradScaler()
|
||||||
|
|
||||||
|
# patch the scaler to allow fp16 training
|
||||||
|
org_unscale_grads = self.scaler._unscale_grads_
|
||||||
|
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
|
||||||
|
return org_unscale_grads(optimizer, inv_scale, found_inf, True)
|
||||||
|
self.scaler._unscale_grads_ = _unscale_grads_replacer
|
||||||
|
|
||||||
|
self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"
|
||||||
|
|
||||||
|
|
||||||
def before_model_load(self):
|
def before_model_load(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -1401,6 +1413,7 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
print("loss is nan")
|
print("loss is nan")
|
||||||
loss = torch.zeros_like(loss).requires_grad_(True)
|
loss = torch.zeros_like(loss).requires_grad_(True)
|
||||||
|
|
||||||
|
|
||||||
with self.timer('backward'):
|
with self.timer('backward'):
|
||||||
# todo we have multiplier separated. works for now as res are not in same batch, but need to change
|
# todo we have multiplier separated. works for now as res are not in same batch, but need to change
|
||||||
loss = loss * loss_multiplier.mean()
|
loss = loss * loss_multiplier.mean()
|
||||||
@@ -1410,7 +1423,10 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
# 0.0 for the backward pass and the gradients will be 0.0
|
# 0.0 for the backward pass and the gradients will be 0.0
|
||||||
# I spent weeks on fighting this. DON'T DO IT
|
# I spent weeks on fighting this. DON'T DO IT
|
||||||
# with fsdp_overlap_step_with_backward():
|
# with fsdp_overlap_step_with_backward():
|
||||||
loss.backward()
|
if self.is_bfloat:
|
||||||
|
loss.backward()
|
||||||
|
else:
|
||||||
|
self.scaler.scale(loss).backward()
|
||||||
# flush()
|
# flush()
|
||||||
|
|
||||||
if not self.is_grad_accumulation_step:
|
if not self.is_grad_accumulation_step:
|
||||||
@@ -1423,8 +1439,13 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
|
torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
|
||||||
# only step if we are not accumulating
|
# only step if we are not accumulating
|
||||||
with self.timer('optimizer_step'):
|
with self.timer('optimizer_step'):
|
||||||
# apply gradients
|
if self.is_bfloat:
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
|
else:
|
||||||
|
# apply gradients
|
||||||
|
self.scaler.step(self.optimizer)
|
||||||
|
self.scaler.update()
|
||||||
|
# self.optimizer.step()
|
||||||
self.optimizer.zero_grad(set_to_none=True)
|
self.optimizer.zero_grad(set_to_none=True)
|
||||||
else:
|
else:
|
||||||
# gradient accumulation. Just a place for breakpoint
|
# gradient accumulation. Just a place for breakpoint
|
||||||
|
|||||||
@@ -48,28 +48,55 @@ dataset_config = DatasetConfig(
|
|||||||
buckets=True,
|
buckets=True,
|
||||||
bucket_tolerance=bucket_tolerance,
|
bucket_tolerance=bucket_tolerance,
|
||||||
# poi='person',
|
# poi='person',
|
||||||
|
shuffle_augmentations=True,
|
||||||
# augmentations=[
|
# augmentations=[
|
||||||
# {
|
# {
|
||||||
# 'method': 'RandomBrightnessContrast',
|
# 'method': 'GaussianBlur',
|
||||||
# 'brightness_limit': (-0.3, 0.3),
|
# 'blur_limit': (1, 16),
|
||||||
# 'contrast_limit': (-0.3, 0.3),
|
# 'sigma_limit': (0, 8),
|
||||||
# 'brightness_by_max': False,
|
# 'p': 0.8
|
||||||
# 'p': 1.0
|
|
||||||
# },
|
# },
|
||||||
# {
|
# {
|
||||||
# 'method': 'HueSaturationValue',
|
# 'method': 'ImageCompression',
|
||||||
# 'hue_shift_limit': (-0, 0),
|
# 'quality_lower': 10,
|
||||||
# 'sat_shift_limit': (-40, 40),
|
# 'quality_upper': 100,
|
||||||
# 'val_shift_limit': (-40, 40),
|
# 'compression_type': 0,
|
||||||
# 'p': 1.0
|
# 'p': 0.8
|
||||||
|
# },
|
||||||
|
# {
|
||||||
|
# 'method': 'ImageCompression',
|
||||||
|
# 'quality_lower': 20,
|
||||||
|
# 'quality_upper': 100,
|
||||||
|
# 'compression_type': 1,
|
||||||
|
# 'p': 0.8
|
||||||
|
# },
|
||||||
|
# {
|
||||||
|
# 'method': 'RingingOvershoot',
|
||||||
|
# 'blur_limit': (3, 35),
|
||||||
|
# 'cutoff': (0.7, 1.96),
|
||||||
|
# 'p': 0.8
|
||||||
|
# },
|
||||||
|
# {
|
||||||
|
# 'method': 'GaussNoise',
|
||||||
|
# 'var_limit': (0, 300),
|
||||||
|
# 'per_channel': True,
|
||||||
|
# 'mean': 0.0,
|
||||||
|
# 'p': 0.8
|
||||||
|
# },
|
||||||
|
# {
|
||||||
|
# 'method': 'GlassBlur',
|
||||||
|
# 'sigma': 0.6,
|
||||||
|
# 'max_delta': 7,
|
||||||
|
# 'iterations': 2,
|
||||||
|
# 'mode': 'fast',
|
||||||
|
# 'p': 0.8
|
||||||
|
# },
|
||||||
|
# {
|
||||||
|
# 'method': 'Downscale',
|
||||||
|
# 'scale_max': 0.5,
|
||||||
|
# 'interpolation': 'cv2.INTER_CUBIC',
|
||||||
|
# 'p': 0.8
|
||||||
# },
|
# },
|
||||||
# {
|
|
||||||
# 'method': 'RGBShift',
|
|
||||||
# 'r_shift_limit': (-20, 20),
|
|
||||||
# 'g_shift_limit': (-20, 20),
|
|
||||||
# 'b_shift_limit': (-20, 20),
|
|
||||||
# 'p': 1.0
|
|
||||||
# },
|
|
||||||
# ]
|
# ]
|
||||||
|
|
||||||
|
|
||||||
@@ -100,7 +127,7 @@ for epoch in range(args.epochs):
|
|||||||
|
|
||||||
show_img(img)
|
show_img(img)
|
||||||
|
|
||||||
# time.sleep(0.1)
|
# time.sleep(1.0)
|
||||||
# if not last epoch
|
# if not last epoch
|
||||||
if epoch < args.epochs - 1:
|
if epoch < args.epochs - 1:
|
||||||
trigger_dataloader_setup_epoch(dataloader)
|
trigger_dataloader_setup_epoch(dataloader)
|
||||||
|
|||||||
@@ -529,6 +529,7 @@ class DatasetConfig:
|
|||||||
self.num_workers: int = kwargs.get('num_workers', 4)
|
self.num_workers: int = kwargs.get('num_workers', 4)
|
||||||
self.prefetch_factor: int = kwargs.get('prefetch_factor', 2)
|
self.prefetch_factor: int = kwargs.get('prefetch_factor', 2)
|
||||||
self.extra_values: List[float] = kwargs.get('extra_values', [])
|
self.extra_values: List[float] = kwargs.get('extra_values', [])
|
||||||
|
self.square_crop: bool = kwargs.get('square_crop', False)
|
||||||
|
|
||||||
|
|
||||||
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
|
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
|
||||||
|
|||||||
@@ -388,7 +388,10 @@ class CustomAdapter(torch.nn.Module):
|
|||||||
self.fuse_module.load_state_dict(state_dict['fuse_module'], strict=strict)
|
self.fuse_module.load_state_dict(state_dict['fuse_module'], strict=strict)
|
||||||
|
|
||||||
if 'ilora' in state_dict:
|
if 'ilora' in state_dict:
|
||||||
self.ilora_module.load_state_dict(state_dict['ilora'], strict=strict)
|
try:
|
||||||
|
self.ilora_module.load_state_dict(state_dict['ilora'], strict=strict)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -203,7 +203,22 @@ class BucketsMixin:
|
|||||||
if file_item.has_point_of_interest:
|
if file_item.has_point_of_interest:
|
||||||
# Attempt to process the poi if we can. It won't process if the image is smaller than the resolution
|
# Attempt to process the poi if we can. It won't process if the image is smaller than the resolution
|
||||||
did_process_poi = file_item.setup_poi_bucket()
|
did_process_poi = file_item.setup_poi_bucket()
|
||||||
if not did_process_poi:
|
if self.dataset_config.square_crop:
|
||||||
|
# we scale first so smallest size matches resolution
|
||||||
|
scale_factor_x = resolution / width
|
||||||
|
scale_factor_y = resolution / height
|
||||||
|
scale_factor = max(scale_factor_x, scale_factor_y)
|
||||||
|
file_item.scale_to_width = math.ceil(width * scale_factor)
|
||||||
|
file_item.scale_to_height = math.ceil(height * scale_factor)
|
||||||
|
file_item.crop_width = resolution
|
||||||
|
file_item.crop_height = resolution
|
||||||
|
if width > height:
|
||||||
|
file_item.crop_x = int(file_item.scale_to_width / 2 - resolution / 2)
|
||||||
|
file_item.crop_y = 0
|
||||||
|
else:
|
||||||
|
file_item.crop_x = 0
|
||||||
|
file_item.crop_y = int(file_item.scale_to_height / 2 - resolution / 2)
|
||||||
|
elif not did_process_poi:
|
||||||
bucket_resolution = get_bucket_for_image_size(
|
bucket_resolution = get_bucket_for_image_size(
|
||||||
width, height,
|
width, height,
|
||||||
resolution=resolution,
|
resolution=resolution,
|
||||||
|
|||||||
@@ -365,14 +365,15 @@ class IPAdapter(torch.nn.Module):
|
|||||||
input_size=preprocessor_input_size,
|
input_size=preprocessor_input_size,
|
||||||
clip_input_size=self.image_encoder.config.image_size,
|
clip_input_size=self.image_encoder.config.image_size,
|
||||||
)
|
)
|
||||||
if 'height' in self.clip_image_processor.size:
|
if not self.config.image_encoder_arch == 'safe':
|
||||||
self.input_size = self.clip_image_processor.size['height']
|
if 'height' in self.clip_image_processor.size:
|
||||||
elif hasattr(self.clip_image_processor, 'crop_size'):
|
self.input_size = self.clip_image_processor.size['height']
|
||||||
self.input_size = self.clip_image_processor.crop_size['height']
|
elif hasattr(self.clip_image_processor, 'crop_size'):
|
||||||
elif 'shortest_edge' in self.clip_image_processor.size.keys():
|
self.input_size = self.clip_image_processor.crop_size['height']
|
||||||
self.input_size = self.clip_image_processor.size['shortest_edge']
|
elif 'shortest_edge' in self.clip_image_processor.size.keys():
|
||||||
else:
|
self.input_size = self.clip_image_processor.size['shortest_edge']
|
||||||
raise ValueError(f"unknown image processor size: {self.clip_image_processor.size}")
|
else:
|
||||||
|
raise ValueError(f"unknown image processor size: {self.clip_image_processor.size}")
|
||||||
self.current_scale = 1.0
|
self.current_scale = 1.0
|
||||||
self.is_active = True
|
self.is_active = True
|
||||||
is_pixart = sd.is_pixart
|
is_pixart = sd.is_pixart
|
||||||
|
|||||||
@@ -21,19 +21,24 @@ class ILoRAProjModule(torch.nn.Module):
|
|||||||
|
|
||||||
self.num_modules = num_modules
|
self.num_modules = num_modules
|
||||||
self.num_dim = dim
|
self.num_dim = dim
|
||||||
self.norm = torch.nn.LayerNorm(embeddings_dim)
|
|
||||||
|
|
||||||
self.proj = torch.nn.Sequential(
|
self.proj = torch.nn.Sequential(
|
||||||
|
torch.nn.LayerNorm(embeddings_dim),
|
||||||
torch.nn.Linear(embeddings_dim, embeddings_dim * 2),
|
torch.nn.Linear(embeddings_dim, embeddings_dim * 2),
|
||||||
torch.nn.GELU(),
|
torch.nn.GELU(),
|
||||||
torch.nn.Linear(embeddings_dim * 2, num_modules * dim),
|
torch.nn.Linear(embeddings_dim * 2, embeddings_dim * 2),
|
||||||
|
torch.nn.LayerNorm(embeddings_dim * 2),
|
||||||
|
|
||||||
|
torch.nn.Linear(embeddings_dim * 2, embeddings_dim * 4),
|
||||||
|
torch.nn.GELU(),
|
||||||
|
torch.nn.Linear(embeddings_dim * 4, num_modules * dim),
|
||||||
|
torch.nn.LayerNorm(num_modules * dim),
|
||||||
)
|
)
|
||||||
# Initialize the last linear layer weights near zero
|
# Initialize the last linear layer weights near zero
|
||||||
torch.nn.init.uniform_(self.proj[2].weight, a=-0.01, b=0.01)
|
torch.nn.init.uniform_(self.proj[-2].weight, a=-0.01, b=0.01)
|
||||||
torch.nn.init.zeros_(self.proj[2].bias)
|
torch.nn.init.zeros_(self.proj[-2].bias)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.norm(x)
|
|
||||||
x = self.proj(x)
|
x = self.proj(x)
|
||||||
x = x.reshape(-1, self.num_modules, self.num_dim)
|
x = x.reshape(-1, self.num_modules, self.num_dim)
|
||||||
return x
|
return x
|
||||||
@@ -71,6 +76,8 @@ class InstantLoRAMidModule(torch.nn.Module):
|
|||||||
# reshape if needed
|
# reshape if needed
|
||||||
if len(x.shape) == 3:
|
if len(x.shape) == 3:
|
||||||
scaler = scaler.unsqueeze(1)
|
scaler = scaler.unsqueeze(1)
|
||||||
|
if len(x.shape) == 4:
|
||||||
|
scaler = scaler.unsqueeze(-1).unsqueeze(-1)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
print(x.shape)
|
print(x.shape)
|
||||||
|
|||||||
@@ -20,11 +20,11 @@ class SAFEReducerBlock(nn.Module):
|
|||||||
|
|
||||||
self.reducer = nn.Sequential(
|
self.reducer = nn.Sequential(
|
||||||
nn.Conv2d(channels, channels, kernel_size=3, padding=1),
|
nn.Conv2d(channels, channels, kernel_size=3, padding=1),
|
||||||
nn.BatchNorm2d(channels),
|
|
||||||
activation(),
|
activation(),
|
||||||
|
nn.BatchNorm2d(channels),
|
||||||
nn.Conv2d(channels, channels, kernel_size=3, padding=1),
|
nn.Conv2d(channels, channels, kernel_size=3, padding=1),
|
||||||
nn.BatchNorm2d(channels),
|
|
||||||
activation(),
|
activation(),
|
||||||
|
nn.BatchNorm2d(channels),
|
||||||
nn.AvgPool2d(kernel_size=2, stride=2),
|
nn.AvgPool2d(kernel_size=2, stride=2),
|
||||||
)
|
)
|
||||||
self.residual_shrink = nn.AvgPool2d(kernel_size=2, stride=2)
|
self.residual_shrink = nn.AvgPool2d(kernel_size=2, stride=2)
|
||||||
@@ -227,6 +227,7 @@ class SAFEVMConfig:
|
|||||||
self.reducer_channels = reducer_channels
|
self.reducer_channels = reducer_channels
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.downscale_factor = downscale_factor
|
self.downscale_factor = downscale_factor
|
||||||
|
self.image_size = 224
|
||||||
|
|
||||||
self.hidden_size = num_vectors
|
self.hidden_size = num_vectors
|
||||||
self.projection_dim = num_vectors
|
self.projection_dim = num_vectors
|
||||||
@@ -242,7 +243,9 @@ class SAFEVMReturn:
|
|||||||
class SAFEVisionModel(SizeAgnosticFeatureEncoder):
|
class SAFEVisionModel(SizeAgnosticFeatureEncoder):
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.config = SAFEVMConfig(**kwargs)
|
self.config = SAFEVMConfig(**kwargs)
|
||||||
super().__init__(**kwargs)
|
self.image_size = None
|
||||||
|
# super().__init__(**kwargs)
|
||||||
|
super(SAFEVisionModel, self).__init__(**kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
|||||||
@@ -20,12 +20,12 @@ def get_optimizer(
|
|||||||
# dadaptation uses different lr that is values of 0.1 to 1.0. default to 1.0
|
# dadaptation uses different lr that is values of 0.1 to 1.0. default to 1.0
|
||||||
use_lr = 1.0
|
use_lr = 1.0
|
||||||
if lower_type.endswith('lion'):
|
if lower_type.endswith('lion'):
|
||||||
optimizer = dadaptation.DAdaptLion(params, lr=use_lr, **optimizer_params)
|
optimizer = dadaptation.DAdaptLion(params, eps=1e-6, lr=use_lr, **optimizer_params)
|
||||||
elif lower_type.endswith('adam'):
|
elif lower_type.endswith('adam'):
|
||||||
optimizer = dadaptation.DAdaptLion(params, lr=use_lr, **optimizer_params)
|
optimizer = dadaptation.DAdaptLion(params, eps=1e-6, lr=use_lr, **optimizer_params)
|
||||||
elif lower_type == 'dadaptation':
|
elif lower_type == 'dadaptation':
|
||||||
# backwards compatibility
|
# backwards compatibility
|
||||||
optimizer = dadaptation.DAdaptAdam(params, lr=use_lr, **optimizer_params)
|
optimizer = dadaptation.DAdaptAdam(params, eps=1e-6, lr=use_lr, **optimizer_params)
|
||||||
# warn user that dadaptation is deprecated
|
# warn user that dadaptation is deprecated
|
||||||
print("WARNING: Dadaptation optimizer type has been changed to DadaptationAdam. Please update your config.")
|
print("WARNING: Dadaptation optimizer type has been changed to DadaptationAdam. Please update your config.")
|
||||||
elif lower_type.startswith("prodigy"):
|
elif lower_type.startswith("prodigy"):
|
||||||
@@ -40,22 +40,22 @@ def get_optimizer(
|
|||||||
print(f"Using lr {use_lr}")
|
print(f"Using lr {use_lr}")
|
||||||
# let net be the neural network you want to train
|
# let net be the neural network you want to train
|
||||||
# you can choose weight decay value based on your problem, 0 by default
|
# you can choose weight decay value based on your problem, 0 by default
|
||||||
optimizer = Prodigy(params, lr=use_lr, **optimizer_params)
|
optimizer = Prodigy(params, lr=use_lr, eps=1e-6, **optimizer_params)
|
||||||
elif lower_type.endswith("8bit"):
|
elif lower_type.endswith("8bit"):
|
||||||
import bitsandbytes
|
import bitsandbytes
|
||||||
|
|
||||||
if lower_type == "adam8bit":
|
if lower_type == "adam8bit":
|
||||||
return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, **optimizer_params)
|
return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
|
||||||
elif lower_type == "adamw8bit":
|
elif lower_type == "adamw8bit":
|
||||||
return bitsandbytes.optim.AdamW8bit(params, lr=learning_rate, **optimizer_params)
|
return bitsandbytes.optim.AdamW8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
|
||||||
elif lower_type == "lion8bit":
|
elif lower_type == "lion8bit":
|
||||||
return bitsandbytes.optim.Lion8bit(params, lr=learning_rate, **optimizer_params)
|
return bitsandbytes.optim.Lion8bit(params, lr=learning_rate, **optimizer_params)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Unknown optimizer type {optimizer_type}')
|
raise ValueError(f'Unknown optimizer type {optimizer_type}')
|
||||||
elif lower_type == 'adam':
|
elif lower_type == 'adam':
|
||||||
optimizer = torch.optim.Adam(params, lr=float(learning_rate), **optimizer_params)
|
optimizer = torch.optim.Adam(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
|
||||||
elif lower_type == 'adamw':
|
elif lower_type == 'adamw':
|
||||||
optimizer = torch.optim.AdamW(params, lr=float(learning_rate), **optimizer_params)
|
optimizer = torch.optim.AdamW(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
|
||||||
elif lower_type == 'lion':
|
elif lower_type == 'lion':
|
||||||
try:
|
try:
|
||||||
from lion_pytorch import Lion
|
from lion_pytorch import Lion
|
||||||
@@ -63,7 +63,7 @@ def get_optimizer(
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError("Please install lion_pytorch to use Lion optimizer -> pip install lion-pytorch")
|
raise ImportError("Please install lion_pytorch to use Lion optimizer -> pip install lion-pytorch")
|
||||||
elif lower_type == 'adagrad':
|
elif lower_type == 'adagrad':
|
||||||
optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), **optimizer_params)
|
optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
|
||||||
elif lower_type == 'adafactor':
|
elif lower_type == 'adafactor':
|
||||||
# hack in stochastic rounding
|
# hack in stochastic rounding
|
||||||
if 'relative_step' not in optimizer_params:
|
if 'relative_step' not in optimizer_params:
|
||||||
@@ -72,7 +72,7 @@ def get_optimizer(
|
|||||||
optimizer_params['scale_parameter'] = True
|
optimizer_params['scale_parameter'] = True
|
||||||
if 'warmup_init' not in optimizer_params:
|
if 'warmup_init' not in optimizer_params:
|
||||||
optimizer_params['warmup_init'] = False
|
optimizer_params['warmup_init'] = False
|
||||||
optimizer = Adafactor(params, lr=float(learning_rate), **optimizer_params)
|
optimizer = Adafactor(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
|
||||||
from toolkit.util.adafactor_stochastic_rounding import step_adafactor
|
from toolkit.util.adafactor_stochastic_rounding import step_adafactor
|
||||||
optimizer.step = step_adafactor.__get__(optimizer, Adafactor)
|
optimizer.step = step_adafactor.__get__(optimizer, Adafactor)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user