Adjustments to loading of flux. Added a feedback to ema

This commit is contained in:
Jaret Burkett
2024-08-07 13:17:26 -06:00
parent 653fe60f16
commit acafe9984f
5 changed files with 27 additions and 10 deletions

View File

@@ -582,7 +582,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
params.append(param) params.append(param)
self.ema = ExponentialMovingAverage( self.ema = ExponentialMovingAverage(
params, params,
self.train_config.ema_config.ema_decay self.train_config.ema_config.ema_decay,
use_feedback=self.train_config.ema_config.use_feedback,
) )
def before_dataset_load(self): def before_dataset_load(self):

View File

@@ -418,6 +418,8 @@ class EMAConfig:
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.use_ema: bool = kwargs.get('use_ema', False) self.use_ema: bool = kwargs.get('use_ema', False)
self.ema_decay: float = kwargs.get('ema_decay', 0.999) self.ema_decay: float = kwargs.get('ema_decay', 0.999)
# feeds back the decay difference into the parameter
self.use_feedback: bool = kwargs.get('use_feedback', False)
class ReferenceDatasetConfig: class ReferenceDatasetConfig:

View File

@@ -43,7 +43,9 @@ class ExponentialMovingAverage:
self, self,
parameters: Iterable[torch.nn.Parameter] = None, parameters: Iterable[torch.nn.Parameter] = None,
decay: float = 0.995, decay: float = 0.995,
use_num_updates: bool = True use_num_updates: bool = True,
# feeds back the decat to the parameter
use_feedback: bool = False
): ):
if parameters is None: if parameters is None:
raise ValueError("parameters must be provided") raise ValueError("parameters must be provided")
@@ -51,6 +53,7 @@ class ExponentialMovingAverage:
raise ValueError('Decay must be between 0 and 1') raise ValueError('Decay must be between 0 and 1')
self.decay = decay self.decay = decay
self.num_updates = 0 if use_num_updates else None self.num_updates = 0 if use_num_updates else None
self.use_feedback = use_feedback
parameters = list(parameters) parameters = list(parameters)
self.shadow_params = [ self.shadow_params = [
p.clone().detach() p.clone().detach()
@@ -123,6 +126,9 @@ class ExponentialMovingAverage:
tmp.mul_(one_minus_decay) tmp.mul_(one_minus_decay)
s_param.sub_(tmp) s_param.sub_(tmp)
if self.use_feedback:
param.add_(tmp)
def copy_to( def copy_to(
self, self,
parameters: Optional[Iterable[torch.nn.Parameter]] = None parameters: Optional[Iterable[torch.nn.Parameter]] = None

View File

@@ -118,6 +118,7 @@ class StableDiffusion:
dtype='fp16', dtype='fp16',
custom_pipeline=None, custom_pipeline=None,
noise_scheduler=None, noise_scheduler=None,
quantize_device=None,
): ):
self.custom_pipeline = custom_pipeline self.custom_pipeline = custom_pipeline
self.device = device self.device = device
@@ -171,6 +172,8 @@ class StableDiffusion:
if self.is_flux or self.is_v3 or self.is_auraflow: if self.is_flux or self.is_v3 or self.is_auraflow:
self.is_flow_matching = True self.is_flow_matching = True
self.quantize_device = quantize_device if quantize_device is not None else self.device
def load_model(self): def load_model(self):
if self.is_loaded: if self.is_loaded:
return return
@@ -454,10 +457,6 @@ class StableDiffusion:
elif self.model_config.is_flux: elif self.model_config.is_flux:
print("Loading Flux model") print("Loading Flux model")
base_model_path = "black-forest-labs/FLUX.1-schnell" base_model_path = "black-forest-labs/FLUX.1-schnell"
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
print("Loading vae")
vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
flush()
print("Loading transformer") print("Loading transformer")
subfolder = 'transformer' subfolder = 'transformer'
transformer_path = model_path transformer_path = model_path
@@ -472,19 +471,19 @@ class StableDiffusion:
# low_cpu_mem_usage=False, # low_cpu_mem_usage=False,
# device_map=None # device_map=None
) )
transformer.to(self.device_torch, dtype=dtype) transformer.to(torch.device(self.quantize_device), dtype=dtype)
flush() flush()
if self.model_config.lora_path is not None: if self.model_config.lora_path is not None:
# need the pipe to do this unfortunately for now # need the pipe to do this unfortunately for now
# we have to fuse in the weights before quantizing # we have to fuse in the weights before quantizing
pipe: FluxPipeline = FluxPipeline( pipe: FluxPipeline = FluxPipeline(
scheduler=scheduler, scheduler=None,
text_encoder=None, text_encoder=None,
tokenizer=None, tokenizer=None,
text_encoder_2=None, text_encoder_2=None,
tokenizer_2=None, tokenizer_2=None,
vae=vae, vae=None,
transformer=transformer, transformer=transformer,
) )
pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1") pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
@@ -496,6 +495,15 @@ class StableDiffusion:
print("Quantizing transformer") print("Quantizing transformer")
quantize(transformer, weights=qfloat8) quantize(transformer, weights=qfloat8)
freeze(transformer) freeze(transformer)
transformer.to(self.device_torch)
else:
transformer.to(self.device_torch, dtype=dtype)
flush()
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
print("Loading vae")
vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
flush() flush()
print("Loading t5") print("Loading t5")