Added CogView4. Loss still needs work.

Jaret Burkett
2025-03-04 18:43:52 -07:00
parent c57434ad7b
commit 6f6fb90812
12 changed files with 661 additions and 152 deletions
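For context (not part of the commit): CogView4 is the DiT-style text-to-image model that diffusers exposes as CogView4Pipeline. A minimal loading sketch, assuming the stock diffusers API; the model id and dtype are illustrative, and the commit's actual cogview4 integration file is not shown in this excerpt:

import torch
from diffusers import CogView4Pipeline

# Illustrative defaults, not taken from the commit.
pipe = CogView4Pipeline.from_pretrained(
    "THUDM/CogView4-6B", torch_dtype=torch.bfloat16
)
pipe.to("cuda")
image = pipe(prompt="a photo of a red fox", num_inference_steps=50).images[0]
image.save("fox.png")

The diff below generalizes BaseModel so that transformer-based models like this one can plug in without further is_flux/is_lumina2-style special cases.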


@@ -168,11 +168,17 @@ class BaseModel:
         self.invert_assistant_lora = False
         self._after_sample_img_hooks = []
         self._status_update_hooks = []
+        self.is_transformer = False

     # properties for old arch for backwards compatibility
     @property
     def unet(self):
         return self.model

+    # set unet to model
+    @unet.setter
+    def unet(self, value):
+        self.model = value

     @property
     def unet_unwrapped(self):
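The unet property pair above is a plain backwards-compatibility alias: legacy call sites keep reading and writing self.unet while the canonical attribute is now self.model. A standalone sketch of the pattern:

class ModelHolder:
    def __init__(self):
        self.model = None  # canonical attribute going forward

    @property
    def unet(self):
        # legacy read path
        return self.model

    @unet.setter
    def unet(self, value):
        # legacy write path
        self.model = value

holder = ModelHolder()
holder.unet = "transformer-weights"
assert holder.model == "transformer-weights"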
@@ -235,6 +241,7 @@ class BaseModel:
     def generate_single_image(
         self,
         pipeline,
+        gen_config: GenerateImageConfig,
         conditional_embeds: PromptEmbeds,
         unconditional_embeds: PromptEmbeds,
@@ -257,6 +264,25 @@ class BaseModel:
     def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
         raise NotImplementedError(
             "get_prompt_embeds must be implemented in child classes")

+    def get_model_has_grad(self):
+        raise NotImplementedError(
+            "get_model_has_grad must be implemented in child classes")
+
+    def get_te_has_grad(self):
+        raise NotImplementedError(
+            "get_te_has_grad must be implemented in child classes")
+
+    def save_model(self, output_path, meta, save_dtype):
+        # todo handle dtype without overloading anything (vram, cpu, etc)
+        unwrap_model(self.pipeline).save_pretrained(
+            save_directory=output_path,
+            safe_serialization=True,
+        )
+        # save out meta config
+        meta_path = os.path.join(output_path, 'aitk_meta.yaml')
+        with open(meta_path, 'w') as f:
+            yaml.dump(meta, f)

     # end must be implemented in child classes

     def te_train(self):
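get_model_has_grad and get_te_has_grad replace the per-architecture attribute probes that later hunks delete from save_device_state. A hedged sketch of what a child class might implement; the class name and attribute paths are illustrative (patterned on the removed probes), not the commit's actual cogview4 code:

class CogView4Model(BaseModel):  # assumes this module's BaseModel
    def get_model_has_grad(self) -> bool:
        # probe a single weight that always exists on the transformer
        return self.model.proj_out.weight.requires_grad

    def get_te_has_grad(self) -> bool:
        # probe a single weight on the text encoder
        return self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad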
@@ -512,6 +538,7 @@ class BaseModel:
             self.device_torch, dtype=self.unet.dtype)
         img = self.generate_single_image(
             pipeline,
+            gen_config,
             conditional_embeds,
             unconditional_embeds,
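Threading gen_config through to generate_single_image lets each model map the sampling request onto its own pipeline call. A hypothetical override sketch; the GenerateImageConfig field names and the trailing parameters are guesses for illustration, not confirmed by this excerpt:

def generate_single_image(self, pipeline, gen_config, conditional_embeds,
                          unconditional_embeds, generator, extra):
    # map the request onto a diffusers-style pipeline call
    return pipeline(
        prompt_embeds=conditional_embeds.text_embeds,
        negative_prompt_embeds=unconditional_embeds.text_embeds,
        height=gen_config.height,
        width=gen_config.width,
        num_inference_steps=gen_config.num_inference_steps,
        guidance_scale=gen_config.guidance_scale,
        generator=generator,
    ).images[0]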
@@ -603,7 +630,8 @@ class BaseModel:
         self,
         original_samples: torch.FloatTensor,
         noise: torch.FloatTensor,
-        timesteps: torch.IntTensor
+        timesteps: torch.IntTensor,
+        **kwargs,
     ) -> torch.FloatTensor:
         original_samples_chunks = torch.chunk(
             original_samples, original_samples.shape[0], dim=0)
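Widening add_noise with **kwargs lets flow-matching schedulers receive extra per-call inputs through the same entry point. A standalone sketch of such a variant, assuming 1000 training timesteps; this illustrates why the signature changed and is not the commit's loss code (which the message says still needs work):

import torch

def add_noise_flow(original_samples: torch.Tensor, noise: torch.Tensor,
                   timesteps: torch.Tensor, **kwargs) -> torch.Tensor:
    # rectified-flow interpolation between clean latents and noise;
    # extra scheduler inputs (e.g. a per-sample shift) would arrive via kwargs
    t = (timesteps.float() / 1000.0).view(-1, 1, 1, 1).to(original_samples.device)
    return (1.0 - t) * original_samples + t * noise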
@@ -1071,7 +1099,7 @@ class BaseModel:
             for name, param in self.text_encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}"):
                 named_params[name] = param
         if unet:
-            if self.is_flux or self.is_lumina2:
+            if self.is_flux or self.is_lumina2 or self.is_transformer:
                 for name, param in self.unet.named_parameters(recurse=True, prefix="transformer"):
                     named_params[name] = param
             else:
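The new is_transformer flag routes DiT-style models into the branch that exports weights under a "transformer" name prefix, matching the diffusers key layout. The prefix argument is standard torch.nn.Module.named_parameters and only changes the reported names, never which parameters are yielded:

import torch.nn as nn

block = nn.Linear(4, 4)
print([name for name, _ in block.named_parameters(prefix="transformer")])
# -> ['transformer.weight', 'transformer.bias']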
@@ -1105,59 +1133,11 @@ class BaseModel:
         return named_params

     def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
-        version_string = '1'
-        if self.is_v2:
-            version_string = '2'
-        if self.is_xl:
-            version_string = 'sdxl'
-        if self.is_ssd:
-            # overwrite sdxl because both will be true here
-            version_string = 'ssd'
-        if self.is_ssd and self.is_vega:
-            version_string = 'vega'
-        # if output file does not end in .safetensors, then it is a directory
-        # and we are saving in diffusers format
-        if not output_file.endswith('.safetensors'):
-            # diffusers
-            if self.is_flux:
-                # only save the unet
-                transformer: FluxTransformer2DModel = unwrap_model(self.unet)
-                transformer.save_pretrained(
-                    save_directory=os.path.join(output_file, 'transformer'),
-                    safe_serialization=True,
-                )
-            elif self.is_lumina2:
-                # only save the unet
-                transformer: Lumina2Transformer2DModel = unwrap_model(
-                    self.unet)
-                transformer.save_pretrained(
-                    save_directory=os.path.join(output_file, 'transformer'),
-                    safe_serialization=True,
-                )
-            else:
-                self.pipeline.save_pretrained(
-                    save_directory=output_file,
-                    safe_serialization=True,
-                )
-            # save out meta config
-            meta_path = os.path.join(output_file, 'aitk_meta.yaml')
-            with open(meta_path, 'w') as f:
-                yaml.dump(meta, f)
-        else:
-            save_ldm_model_from_diffusers(
-                sd=self,
-                output_file=output_file,
-                meta=meta,
-                save_dtype=save_dtype,
-                sd_version=version_string,
-            )
-            if self.config_file is not None:
-                output_path_no_ext = os.path.splitext(output_file)[0]
-                output_config_path = f"{output_path_no_ext}.yaml"
-                shutil.copyfile(self.config_file, output_config_path)
+        self.save_model(
+            output_path=output_file,
+            meta=meta,
+            save_dtype=save_dtype
+        )

     def prepare_optimizer_params(
         self,
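save() is now a template method: it keeps only the shared entry point and defers all format-specific writing to save_model(), whose base implementation (added earlier in this diff) saves the pipeline in diffusers layout plus aitk_meta.yaml. Callers are unchanged; a usage sketch with illustrative paths and meta contents:

from collections import OrderedDict

meta = OrderedDict(name="my-finetune", step=1000)
sd.save(output_file="output/checkpoint", meta=meta)  # sd: a BaseModel subclass
# -> output/checkpoint/ in diffusers layout, plus output/checkpoint/aitk_meta.yaml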
@@ -1240,12 +1220,7 @@ class BaseModel:
     def save_device_state(self):
         # saves the current device state for all modules
         # this is useful for when we want to alter the state and restore it
-        if self.is_lumina2:
-            unet_has_grad = self.unet.x_embedder.weight.requires_grad
-        elif self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux:
-            unet_has_grad = self.unet.proj_out.weight.requires_grad
-        else:
-            unet_has_grad = self.unet.conv_in.weight.requires_grad
+        unet_has_grad = self.get_model_has_grad()

         self.device_state = {
             **empty_preset,
@@ -1262,13 +1237,7 @@ class BaseModel:
         if isinstance(self.text_encoder, list):
             self.device_state['text_encoder']: List[dict] = []
             for encoder in self.text_encoder:
-                if isinstance(encoder, LlamaModel):
-                    te_has_grad = encoder.layers[0].mlp.gate_proj.weight.requires_grad
-                else:
-                    try:
-                        te_has_grad = encoder.text_model.final_layer_norm.weight.requires_grad
-                    except:
-                        te_has_grad = encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
+                te_has_grad = self.get_te_has_grad()
                 self.device_state['text_encoder'].append({
                     'training': encoder.training,
                     'device': encoder.device,
@@ -1276,17 +1245,7 @@ class BaseModel:
                     'requires_grad': te_has_grad
                 })
         else:
-            if isinstance(self.text_encoder, T5EncoderModel) or isinstance(self.text_encoder, UMT5EncoderModel):
-                te_has_grad = self.text_encoder.encoder.block[
-                    0].layer[0].SelfAttention.q.weight.requires_grad
-            elif isinstance(self.text_encoder, Gemma2Model):
-                te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad
-            elif isinstance(self.text_encoder, Qwen2Model):
-                te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad
-            elif isinstance(self.text_encoder, LlamaModel):
-                te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad
-            else:
-                te_has_grad = self.text_encoder.text_model.final_layer_norm.weight.requires_grad
+            te_has_grad = self.get_te_has_grad()
             self.device_state['text_encoder'] = {
                 'training': self.text_encoder.training,