Mirror of https://github.com/ostris/ai-toolkit.git
Added cogview4. Loss still needs work.
@@ -168,11 +168,17 @@ class BaseModel:
         self.invert_assistant_lora = False
         self._after_sample_img_hooks = []
         self._status_update_hooks = []
+        self.is_transformer = False

     # properties for old arch for backwards compatibility
     @property
     def unet(self):
         return self.model
+
+    # set unet to model
+    @unet.setter
+    def unet(self, value):
+        self.model = value

     @property
     def unet_unwrapped(self):
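For reference, the getter/setter pair above just aliases self.model, so older code that reads or assigns sd.unet keeps working on transformer-style models. A minimal standalone demo of the pattern (a sketch, not code from this repo):

    class M:
        def __init__(self):
            self.model = None

        @property
        def unet(self):
            return self.model

        @unet.setter
        def unet(self, value):
            # assignments to .unet route into .model
            self.model = value

    m = M()
    m.unet = "transformer"
    assert m.model == "transformer" and m.unet is m.model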
@@ -235,6 +241,7 @@ class BaseModel:

     def generate_single_image(
         self,
+        pipeline,
         gen_config: GenerateImageConfig,
         conditional_embeds: PromptEmbeds,
         unconditional_embeds: PromptEmbeds,
@@ -257,6 +264,25 @@ class BaseModel:
     def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
         raise NotImplementedError(
             "get_prompt_embeds must be implemented in child classes")
+
+    def get_model_has_grad(self):
+        raise NotImplementedError(
+            "get_model_has_grad must be implemented in child classes")
+
+    def get_te_has_grad(self):
+        raise NotImplementedError(
+            "get_te_has_grad must be implemented in child classes")
+
+    def save_model(self, output_path, meta, save_dtype):
+        # todo handle dtype without overloading anything (vram, cpu, etc)
+        unwrap_model(self.pipeline).save_pretrained(
+            save_directory=output_path,
+            safe_serialization=True,
+        )
+        # save out meta config
+        meta_path = os.path.join(output_path, 'aitk_meta.yaml')
+        with open(meta_path, 'w') as f:
+            yaml.dump(meta, f)
     # end must be implemented in child classes

     def te_train(self):
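The two *_has_grad hooks added here are what save_device_state switches over to further down, replacing its per-architecture isinstance chains. A sketch of how a child class might satisfy them; the class name and attribute paths are illustrative assumptions, not code from this commit:

    class ExampleTransformerModel(BaseModel):
        def get_model_has_grad(self):
            # sample one representative weight instead of walking every param
            return self.model.proj_out.weight.requires_grad

        def get_te_has_grad(self):
            return self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad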
@@ -512,6 +538,7 @@ class BaseModel:
             self.device_torch, dtype=self.unet.dtype)

         img = self.generate_single_image(
+            pipeline,
             gen_config,
             conditional_embeds,
             unconditional_embeds,
@@ -603,7 +630,8 @@ class BaseModel:
         self,
         original_samples: torch.FloatTensor,
         noise: torch.FloatTensor,
-        timesteps: torch.IntTensor
+        timesteps: torch.IntTensor,
+        **kwargs,
     ) -> torch.FloatTensor:
         original_samples_chunks = torch.chunk(
             original_samples, original_samples.shape[0], dim=0)
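add_noise now also accepts **kwargs, so callers can pass model-specific arguments through. For reference, the chunking on the last two lines splits the batch into single samples so each one can be noised with its own timestep; in miniature (shapes and the noise formula below are illustrative stand-ins, not the real schedule):

    import torch

    samples = torch.randn(4, 16, 32, 32)
    noise = torch.randn_like(samples)
    timesteps = torch.randint(0, 1000, (4,))

    # one chunk of shape (1, 16, 32, 32) per sample, paired with one timestep
    sample_chunks = torch.chunk(samples, samples.shape[0], dim=0)
    noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
    t_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)

    noisy = torch.cat([
        s + n * (t.float() / 1000.0)  # stand-in for the real noise schedule
        for s, n, t in zip(sample_chunks, noise_chunks, t_chunks)
    ], dim=0)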
@@ -1071,7 +1099,7 @@ class BaseModel:
             for name, param in self.text_encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}"):
                 named_params[name] = param
         if unet:
-            if self.is_flux or self.is_lumina2:
+            if self.is_flux or self.is_lumina2 or self.is_transformer:
                 for name, param in self.unet.named_parameters(recurse=True, prefix="transformer"):
                     named_params[name] = param
             else:
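The only change here is routing models with is_transformer set through the prefix="transformer" branch. That prefix simply namespaces the keys named_parameters yields; a tiny standalone demo (toy module, not repo code):

    import torch.nn as nn

    model = nn.Sequential(nn.Linear(4, 4))
    for name, _ in model.named_parameters(recurse=True, prefix="transformer"):
        print(name)  # "transformer.0.weight", then "transformer.0.bias"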
@@ -1105,59 +1133,11 @@ class BaseModel:
         return named_params

     def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
-        version_string = '1'
-        if self.is_v2:
-            version_string = '2'
-        if self.is_xl:
-            version_string = 'sdxl'
-        if self.is_ssd:
-            # overwrite sdxl because both wil be true here
-            version_string = 'ssd'
-        if self.is_ssd and self.is_vega:
-            version_string = 'vega'
-        # if output file does not end in .safetensors, then it is a directory and we are
-        # saving in diffusers format
-        if not output_file.endswith('.safetensors'):
-            # diffusers
-            if self.is_flux:
-                # only save the unet
-                transformer: FluxTransformer2DModel = unwrap_model(self.unet)
-                transformer.save_pretrained(
-                    save_directory=os.path.join(output_file, 'transformer'),
-                    safe_serialization=True,
-                )
-            elif self.is_lumina2:
-                # only save the unet
-                transformer: Lumina2Transformer2DModel = unwrap_model(
-                    self.unet)
-                transformer.save_pretrained(
-                    save_directory=os.path.join(output_file, 'transformer'),
-                    safe_serialization=True,
-                )
-
-            else:
-
-                self.pipeline.save_pretrained(
-                    save_directory=output_file,
-                    safe_serialization=True,
-                )
-            # save out meta config
-            meta_path = os.path.join(output_file, 'aitk_meta.yaml')
-            with open(meta_path, 'w') as f:
-                yaml.dump(meta, f)
-
-        else:
-            save_ldm_model_from_diffusers(
-                sd=self,
-                output_file=output_file,
-                meta=meta,
-                save_dtype=save_dtype,
-                sd_version=version_string,
-            )
-            if self.config_file is not None:
-                output_path_no_ext = os.path.splitext(output_file)[0]
-                output_config_path = f"{output_path_no_ext}.yaml"
-                shutil.copyfile(self.config_file, output_config_path)
+        self.save_model(
+            output_path=output_file,
+            meta=meta,
+            save_dtype=save_dtype
+        )

     def prepare_optimizer_params(
         self,
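With the version-string and format branching gone, per-model saving now lives in save_model, which child classes can override. Continuing the hypothetical ExampleTransformerModel sketch from earlier, a transformer-only model could mirror the removed is_flux branch roughly like this (an assumption, not code from this commit):

    import os
    import yaml

    class ExampleTransformerModel(BaseModel):
        def save_model(self, output_path, meta, save_dtype):
            # save only the transformer weights, as the old is_flux branch did
            transformer = unwrap_model(self.unet)
            transformer.save_pretrained(
                save_directory=os.path.join(output_path, 'transformer'),
                safe_serialization=True,
            )
            # keep writing the aitk_meta.yaml sidecar
            with open(os.path.join(output_path, 'aitk_meta.yaml'), 'w') as f:
                yaml.dump(meta, f)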
@@ -1240,12 +1220,7 @@ class BaseModel:
     def save_device_state(self):
         # saves the current device state for all modules
         # this is useful for when we want to alter the state and restore it
-        if self.is_lumina2:
-            unet_has_grad = self.unet.x_embedder.weight.requires_grad
-        elif self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux:
-            unet_has_grad = self.unet.proj_out.weight.requires_grad
-        else:
-            unet_has_grad = self.unet.conv_in.weight.requires_grad
+        unet_has_grad = self.get_model_has_grad()

         self.device_state = {
             **empty_preset,
@@ -1262,13 +1237,7 @@ class BaseModel:
         if isinstance(self.text_encoder, list):
             self.device_state['text_encoder']: List[dict] = []
             for encoder in self.text_encoder:
-                if isinstance(encoder, LlamaModel):
-                    te_has_grad = encoder.layers[0].mlp.gate_proj.weight.requires_grad
-                else:
-                    try:
-                        te_has_grad = encoder.text_model.final_layer_norm.weight.requires_grad
-                    except:
-                        te_has_grad = encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
+                te_has_grad = self.get_te_has_grad()
                 self.device_state['text_encoder'].append({
                     'training': encoder.training,
                     'device': encoder.device,
@@ -1276,17 +1245,7 @@ class BaseModel:
                     'requires_grad': te_has_grad
                 })
         else:
-            if isinstance(self.text_encoder, T5EncoderModel) or isinstance(self.text_encoder, UMT5EncoderModel):
-                te_has_grad = self.text_encoder.encoder.block[
-                    0].layer[0].SelfAttention.q.weight.requires_grad
-            elif isinstance(self.text_encoder, Gemma2Model):
-                te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad
-            elif isinstance(self.text_encoder, Qwen2Model):
-                te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad
-            elif isinstance(self.text_encoder, LlamaModel):
-                te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad
-            else:
-                te_has_grad = self.text_encoder.text_model.final_layer_norm.weight.requires_grad
+            te_has_grad = self.get_te_has_grad()

             self.device_state['text_encoder'] = {
                 'training': self.text_encoder.training,