Added support for full fine-tuning of Flux with randomized parameter activation. Examples coming soon

Jaret Burkett
2024-11-21 13:05:32 -07:00
parent 894374b2e9
commit 96d418bb95
4 changed files with 194 additions and 8 deletions


@@ -3,6 +3,7 @@ from typing import List
 import torch
 from toolkit.optimizers.optimizer_utils import copy_stochastic, stochastic_grad_accummulation
 from optimum.quanto import QBytesTensor
+import random


 class Adafactor(torch.optim.Optimizer):
@@ -105,6 +106,8 @@ class Adafactor(torch.optim.Optimizer):
         scale_parameter=True,
         relative_step=True,
         warmup_init=False,
+        do_paramiter_swapping=False,
+        paramiter_swapping_factor=0.1,
     ):
         if lr is not None and relative_step:
             raise ValueError(
@@ -140,6 +143,49 @@ class Adafactor(torch.optim.Optimizer):
                     param.register_post_accumulate_grad_hook(
                         stochastic_grad_accummulation
                     )
+
+        self.do_paramiter_swapping = do_paramiter_swapping
+        self.paramiter_swapping_factor = paramiter_swapping_factor
+        self._total_paramiter_size = 0
+        # count total parameters
+        for group in self.param_groups:
+            for param in group['params']:
+                self._total_paramiter_size += torch.numel(param)
+        # pretty print the total parameter count with comma separators
+        print(f"Total training paramiters: {self._total_paramiter_size:,}")
+
+        # parameters have to be counted before swapping can be enabled
+        if self.do_paramiter_swapping:
+            self.enable_paramiter_swapping(self.paramiter_swapping_factor)
+
+    def enable_paramiter_swapping(self, paramiter_swapping_factor=0.1):
+        self.do_paramiter_swapping = True
+        self.paramiter_swapping_factor = paramiter_swapping_factor
+        # do an initial swap so the first step already trains a random subset
+        self.swap_paramiters()
+
+    def swap_paramiters(self):
+        all_params = []
+        # deactivate all parameters
+        for group in self.param_groups:
+            for param in group['params']:
+                param.requires_grad_(False)
+                # remove any grad
+                param.grad = None
+                all_params.append(param)
+        # shuffle all parameters
+        random.shuffle(all_params)
+
+        # keep activating parameters until we would go over the target count
+        target_paramiters = int(self._total_paramiter_size * self.paramiter_swapping_factor)
+        total_paramiters = 0
+        for param in all_params:
+            total_paramiters += torch.numel(param)
+            if total_paramiters >= target_paramiters:
+                break
+            else:
+                param.requires_grad_(True)

     @staticmethod
     def _get_lr(param_group, param_state):
@@ -209,7 +255,7 @@ class Adafactor(torch.optim.Optimizer):
         for group in self.param_groups:
             for p in group["params"]:
-                if p.grad is None:
+                if p.grad is None or not p.requires_grad:
                     continue
                 grad = p.grad
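
The step() change above simply skips any parameter whose requires_grad was turned off by swap_paramiters(). Below is a rough usage sketch of the new flags; the import path, build_flux_model(), train_loader, the forward call, and the every-10-steps re-shuffle cadence are illustrative assumptions, not part of this commit — only the do_paramiter_swapping / paramiter_swapping_factor kwargs and the swap_paramiters() method come from the diff.

# Rough sketch, under the assumptions stated above.
from toolkit.optimizers.adafactor import Adafactor  # assumed module path

model = build_flux_model()  # hypothetical helper returning the transformer
optimizer = Adafactor(
    model.parameters(),
    do_paramiter_swapping=True,     # enable randomized parameter activation
    paramiter_swapping_factor=0.1,  # ~10% of all parameters active at a time
)

for step, batch in enumerate(train_loader):  # train_loader is a placeholder
    loss = model(**batch)          # placeholder forward returning a scalar loss
    loss.backward()                # deactivated params receive no grad
    optimizer.step()               # step() skips params with requires_grad=False
    optimizer.zero_grad()
    if step % 10 == 0:
        # pick a new random ~10% slice of the weights to train
        optimizer.swap_paramiters()

Because __init__ already calls enable_paramiter_swapping() when the flag is set, the first swap happens at construction time; re-calling swap_paramiters() during training just rotates which subset of the weights receives updates, so a full fine-tune can touch all of the weights over time while only a fraction are active at each step.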