diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index b6ca5a73..ffc2573a 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -582,7 +582,8 @@ class BaseSDTrainProcess(BaseTrainProcess): params.append(param) self.ema = ExponentialMovingAverage( params, - self.train_config.ema_config.ema_decay + self.train_config.ema_config.ema_decay, + use_feedback=self.train_config.ema_config.use_feedback, ) def before_dataset_load(self): diff --git a/repositories/sd-scripts b/repositories/sd-scripts index 25f961bc..b78c0e2a 160000 --- a/repositories/sd-scripts +++ b/repositories/sd-scripts @@ -1 +1 @@ -Subproject commit 25f961bc779bc79aef440813e3e8e92244ac5739 +Subproject commit b78c0e2a69e52ce6c79abc6c8c82d1a9cabcf05c diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index e08e8195..398caa0a 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -418,6 +418,8 @@ class EMAConfig: def __init__(self, **kwargs): self.use_ema: bool = kwargs.get('use_ema', False) self.ema_decay: float = kwargs.get('ema_decay', 0.999) + # feeds back the decay difference into the parameter + self.use_feedback: bool = kwargs.get('use_feedback', False) class ReferenceDatasetConfig: diff --git a/toolkit/ema.py b/toolkit/ema.py index 43eb8c8f..6a5df2df 100644 --- a/toolkit/ema.py +++ b/toolkit/ema.py @@ -43,7 +43,9 @@ class ExponentialMovingAverage: self, parameters: Iterable[torch.nn.Parameter] = None, decay: float = 0.995, - use_num_updates: bool = True + use_num_updates: bool = True, + # feeds back the decay difference into the parameter + use_feedback: bool = False ): if parameters is None: raise ValueError("parameters must be provided") @@ -51,6 +53,7 @@ class ExponentialMovingAverage: raise ValueError('Decay must be between 0 and 1') self.decay = decay self.num_updates = 0 if use_num_updates else None + self.use_feedback = use_feedback parameters = list(parameters) self.shadow_params = [ 
p.clone().detach() @@ -123,6 +126,9 @@ class ExponentialMovingAverage: tmp.mul_(one_minus_decay) s_param.sub_(tmp) + if self.use_feedback: + param.add_(tmp) + def copy_to( self, parameters: Optional[Iterable[torch.nn.Parameter]] = None diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 4c0a12f4..04f08de8 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -118,6 +118,7 @@ class StableDiffusion: dtype='fp16', custom_pipeline=None, noise_scheduler=None, + quantize_device=None, ): self.custom_pipeline = custom_pipeline self.device = device @@ -171,6 +172,8 @@ class StableDiffusion: if self.is_flux or self.is_v3 or self.is_auraflow: self.is_flow_matching = True + self.quantize_device = quantize_device if quantize_device is not None else self.device + def load_model(self): if self.is_loaded: return @@ -454,10 +457,6 @@ class StableDiffusion: elif self.model_config.is_flux: print("Loading Flux model") base_model_path = "black-forest-labs/FLUX.1-schnell" - scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler") - print("Loading vae") - vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype) - flush() print("Loading transformer") subfolder = 'transformer' transformer_path = model_path @@ -472,19 +471,19 @@ class StableDiffusion: # low_cpu_mem_usage=False, # device_map=None ) - transformer.to(self.device_torch, dtype=dtype) + transformer.to(torch.device(self.quantize_device), dtype=dtype) flush() if self.model_config.lora_path is not None: # need the pipe to do this unfortunately for now # we have to fuse in the weights before quantizing pipe: FluxPipeline = FluxPipeline( - scheduler=scheduler, + scheduler=None, text_encoder=None, tokenizer=None, text_encoder_2=None, tokenizer_2=None, - vae=vae, + vae=None, transformer=transformer, ) pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1") @@ -496,6 +495,15 
@@ class StableDiffusion: print("Quantizing transformer") quantize(transformer, weights=qfloat8) freeze(transformer) + transformer.to(self.device_torch) + else: + transformer.to(self.device_torch, dtype=dtype) + + flush() + + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler") + print("Loading vae") + vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype) flush() print("Loading t5")