mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Adjustments to the loading of Flux. Added a feedback option to the EMA.
This commit is contained in:
@@ -582,7 +582,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
params.append(param)
|
params.append(param)
|
||||||
self.ema = ExponentialMovingAverage(
|
self.ema = ExponentialMovingAverage(
|
||||||
params,
|
params,
|
||||||
self.train_config.ema_config.ema_decay
|
self.train_config.ema_config.ema_decay,
|
||||||
|
use_feedback=self.train_config.ema_config.use_feedback,
|
||||||
)
|
)
|
||||||
|
|
||||||
def before_dataset_load(self):
|
def before_dataset_load(self):
|
||||||
|
|||||||
Submodule repositories/sd-scripts updated: 25f961bc77...b78c0e2a69
@@ -418,6 +418,8 @@ class EMAConfig:
|
|||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.use_ema: bool = kwargs.get('use_ema', False)
|
self.use_ema: bool = kwargs.get('use_ema', False)
|
||||||
self.ema_decay: float = kwargs.get('ema_decay', 0.999)
|
self.ema_decay: float = kwargs.get('ema_decay', 0.999)
|
||||||
|
# feeds back the decay difference into the parameter
|
||||||
|
self.use_feedback: bool = kwargs.get('use_feedback', False)
|
||||||
|
|
||||||
|
|
||||||
class ReferenceDatasetConfig:
|
class ReferenceDatasetConfig:
|
||||||
|
|||||||
@@ -43,7 +43,9 @@ class ExponentialMovingAverage:
|
|||||||
self,
|
self,
|
||||||
parameters: Iterable[torch.nn.Parameter] = None,
|
parameters: Iterable[torch.nn.Parameter] = None,
|
||||||
decay: float = 0.995,
|
decay: float = 0.995,
|
||||||
use_num_updates: bool = True
|
use_num_updates: bool = True,
|
||||||
|
# feeds back the decay difference into the parameter
|
||||||
|
use_feedback: bool = False
|
||||||
):
|
):
|
||||||
if parameters is None:
|
if parameters is None:
|
||||||
raise ValueError("parameters must be provided")
|
raise ValueError("parameters must be provided")
|
||||||
@@ -51,6 +53,7 @@ class ExponentialMovingAverage:
|
|||||||
raise ValueError('Decay must be between 0 and 1')
|
raise ValueError('Decay must be between 0 and 1')
|
||||||
self.decay = decay
|
self.decay = decay
|
||||||
self.num_updates = 0 if use_num_updates else None
|
self.num_updates = 0 if use_num_updates else None
|
||||||
|
self.use_feedback = use_feedback
|
||||||
parameters = list(parameters)
|
parameters = list(parameters)
|
||||||
self.shadow_params = [
|
self.shadow_params = [
|
||||||
p.clone().detach()
|
p.clone().detach()
|
||||||
@@ -123,6 +126,9 @@ class ExponentialMovingAverage:
|
|||||||
tmp.mul_(one_minus_decay)
|
tmp.mul_(one_minus_decay)
|
||||||
s_param.sub_(tmp)
|
s_param.sub_(tmp)
|
||||||
|
|
||||||
|
if self.use_feedback:
|
||||||
|
param.add_(tmp)
|
||||||
|
|
||||||
def copy_to(
|
def copy_to(
|
||||||
self,
|
self,
|
||||||
parameters: Optional[Iterable[torch.nn.Parameter]] = None
|
parameters: Optional[Iterable[torch.nn.Parameter]] = None
|
||||||
|
|||||||
@@ -118,6 +118,7 @@ class StableDiffusion:
|
|||||||
dtype='fp16',
|
dtype='fp16',
|
||||||
custom_pipeline=None,
|
custom_pipeline=None,
|
||||||
noise_scheduler=None,
|
noise_scheduler=None,
|
||||||
|
quantize_device=None,
|
||||||
):
|
):
|
||||||
self.custom_pipeline = custom_pipeline
|
self.custom_pipeline = custom_pipeline
|
||||||
self.device = device
|
self.device = device
|
||||||
@@ -171,6 +172,8 @@ class StableDiffusion:
|
|||||||
if self.is_flux or self.is_v3 or self.is_auraflow:
|
if self.is_flux or self.is_v3 or self.is_auraflow:
|
||||||
self.is_flow_matching = True
|
self.is_flow_matching = True
|
||||||
|
|
||||||
|
self.quantize_device = quantize_device if quantize_device is not None else self.device
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
if self.is_loaded:
|
if self.is_loaded:
|
||||||
return
|
return
|
||||||
@@ -454,10 +457,6 @@ class StableDiffusion:
|
|||||||
elif self.model_config.is_flux:
|
elif self.model_config.is_flux:
|
||||||
print("Loading Flux model")
|
print("Loading Flux model")
|
||||||
base_model_path = "black-forest-labs/FLUX.1-schnell"
|
base_model_path = "black-forest-labs/FLUX.1-schnell"
|
||||||
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
|
|
||||||
print("Loading vae")
|
|
||||||
vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
|
|
||||||
flush()
|
|
||||||
print("Loading transformer")
|
print("Loading transformer")
|
||||||
subfolder = 'transformer'
|
subfolder = 'transformer'
|
||||||
transformer_path = model_path
|
transformer_path = model_path
|
||||||
@@ -472,19 +471,19 @@ class StableDiffusion:
|
|||||||
# low_cpu_mem_usage=False,
|
# low_cpu_mem_usage=False,
|
||||||
# device_map=None
|
# device_map=None
|
||||||
)
|
)
|
||||||
transformer.to(self.device_torch, dtype=dtype)
|
transformer.to(torch.device(self.quantize_device), dtype=dtype)
|
||||||
flush()
|
flush()
|
||||||
|
|
||||||
if self.model_config.lora_path is not None:
|
if self.model_config.lora_path is not None:
|
||||||
# need the pipe to do this unfortunately for now
|
# need the pipe to do this unfortunately for now
|
||||||
# we have to fuse in the weights before quantizing
|
# we have to fuse in the weights before quantizing
|
||||||
pipe: FluxPipeline = FluxPipeline(
|
pipe: FluxPipeline = FluxPipeline(
|
||||||
scheduler=scheduler,
|
scheduler=None,
|
||||||
text_encoder=None,
|
text_encoder=None,
|
||||||
tokenizer=None,
|
tokenizer=None,
|
||||||
text_encoder_2=None,
|
text_encoder_2=None,
|
||||||
tokenizer_2=None,
|
tokenizer_2=None,
|
||||||
vae=vae,
|
vae=None,
|
||||||
transformer=transformer,
|
transformer=transformer,
|
||||||
)
|
)
|
||||||
pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
|
pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
|
||||||
@@ -496,6 +495,15 @@ class StableDiffusion:
|
|||||||
print("Quantizing transformer")
|
print("Quantizing transformer")
|
||||||
quantize(transformer, weights=qfloat8)
|
quantize(transformer, weights=qfloat8)
|
||||||
freeze(transformer)
|
freeze(transformer)
|
||||||
|
transformer.to(self.device_torch)
|
||||||
|
else:
|
||||||
|
transformer.to(self.device_torch, dtype=dtype)
|
||||||
|
|
||||||
|
flush()
|
||||||
|
|
||||||
|
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
|
||||||
|
print("Loading vae")
|
||||||
|
vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
|
||||||
flush()
|
flush()
|
||||||
|
|
||||||
print("Loading t5")
|
print("Loading t5")
|
||||||
|
|||||||
Reference in New Issue
Block a user