WIP more work on cogview4

Jaret Burkett
2025-03-05 09:43:00 -07:00
parent 6f6fb90812
commit aa44828c0c
3 changed files with 51 additions and 64 deletions

View File: requirements.txt

@@ -1,7 +1,7 @@
torch==2.5.1
torchvision==0.20.1
safetensors
-git+https://github.com/huggingface/diffusers@97fda1b75c70705b245a462044fedb47abb17e56
+git+https://github.com/huggingface/diffusers@24c062aaa19f5626d03d058daf8afffa2dfd49f7
transformers==4.49.0
lycoris-lora==1.8.3
flatten_json

View File

@@ -20,7 +20,6 @@ from transformers import GlmModel, AutoTokenizer
from diffusers import FlowMatchEulerDiscreteScheduler
from typing import TYPE_CHECKING
from toolkit.accelerator import unwrap_model
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
if TYPE_CHECKING:
@@ -71,7 +70,7 @@ class CogView4(BaseModel):
self.is_flow_matching = True
self.is_transformer = True
self.target_lora_modules = ['CogView4Transformer2DModel']
# cache for holding noise
self.effective_noise = None
@@ -86,7 +85,6 @@ class CogView4(BaseModel):
base_model_path = "THUDM/CogView4-6B"
model_path = self.model_config.name_or_path
# pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
self.print_and_status_update("Loading CogView4 model")
# base_model_path = "black-forest-labs/FLUX.1-schnell"
base_model_path = self.model_config.name_or_path_original
@@ -213,19 +211,6 @@ class CogView4(BaseModel):
generator: torch.Generator,
extra: dict,
):
-# There is a bug in the length check in the diffusers code: it requires the conditional and
-# unconditional prompt embeds to have the same sequence length, but they are encoded in two
-# separate passes and the encoding code does not enforce this, so the check should not be
-# needed. For now, zero-pad the shorter one; this path is inference-only, so it is fine.
-if conditional_embeds.text_embeds.shape[1] < unconditional_embeds.text_embeds.shape[1]:
-    pad_len = unconditional_embeds.text_embeds.shape[1] - conditional_embeds.text_embeds.shape[1]
-    conditional_embeds.text_embeds = torch.cat([
-        conditional_embeds.text_embeds,
-        torch.zeros(conditional_embeds.text_embeds.shape[0], pad_len, conditional_embeds.text_embeds.shape[2],
-                    device=conditional_embeds.text_embeds.device, dtype=conditional_embeds.text_embeds.dtype)
-    ], dim=1)
-elif conditional_embeds.text_embeds.shape[1] > unconditional_embeds.text_embeds.shape[1]:
-    pad_len = conditional_embeds.text_embeds.shape[1] - unconditional_embeds.text_embeds.shape[1]
-    unconditional_embeds.text_embeds = torch.cat([
-        unconditional_embeds.text_embeds,
-        torch.zeros(unconditional_embeds.text_embeds.shape[0], pad_len, unconditional_embeds.text_embeds.shape[2],
-                    device=unconditional_embeds.text_embeds.device, dtype=unconditional_embeds.text_embeds.dtype)
-    ], dim=1)
img = pipeline(
prompt_embeds=conditional_embeds.text_embeds.to(
self.device_torch, dtype=self.torch_dtype),
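For reference, the removed workaround above reduces to zero-padding the shorter of the two embeds to the longer one's sequence length. A minimal standalone sketch of the same technique (pad_to_match is an illustrative name, not a helper from this repo):

import torch

def pad_to_match(a, b):
    # zero-pad the shorter of two (batch, seq, dim) tensors along dim=1
    if a.shape[1] == b.shape[1]:
        return a, b
    short, long = (a, b) if a.shape[1] < b.shape[1] else (b, a)
    pad = torch.zeros(
        short.shape[0], long.shape[1] - short.shape[1], short.shape[2],
        device=short.device, dtype=short.dtype)
    short = torch.cat([short, pad], dim=1)
    return (short, long) if a.shape[1] < b.shape[1] else (long, short)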
@@ -259,12 +244,12 @@ class CogView4(BaseModel):
[target_size], dtype=self.torch_dtype, device=self.device_torch)
target_size = original_size.clone()
noise_pred_cond = self.model(
-hidden_states=latent_model_input, # torch.Size([1, 16, 128, 128])
-encoder_hidden_states=text_embeddings.text_embeds, # torch.Size([1, 16, 4096])
+hidden_states=latent_model_input,
+encoder_hidden_states=text_embeddings.text_embeds,
timestep=timestep,
-original_size=original_size, # [[1024., 1024.]]
-target_size=target_size, # [[1024., 1024.]]
-crop_coords=crops_coords_top_left, # [[0., 0.]]
+original_size=original_size,
+target_size=target_size,
+crop_coords=crops_coords_top_left,
return_dict=False,
)[0]
return noise_pred_cond
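The original_size / target_size / crop_coords arguments are CogView4's micro-conditioning inputs. A small standalone sketch of building them, assuming uncropped 1024x1024 samples as in the shape comments on the removed lines:

import torch

# one (height, width) row per batch element
original_size = torch.tensor([[1024.0, 1024.0]])    # source image size
target_size = original_size.clone()                 # output resolution
crops_coords_top_left = torch.tensor([[0.0, 0.0]])  # no crop offset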
@@ -283,9 +268,9 @@ class CogView4(BaseModel):
def get_te_has_grad(self):
return self.text_encoder.layers[0].mlp.down_proj.weight.requires_grad
def save_model(self, output_path, meta, save_dtype):
# only save the transformer
transformer: CogView4Transformer2DModel = unwrap_model(self.model)
transformer.save_pretrained(
save_directory=os.path.join(output_path, 'transformer'),
@@ -295,7 +280,7 @@ class CogView4(BaseModel):
meta_path = os.path.join(output_path, 'aitk_meta.yaml')
with open(meta_path, 'w') as f:
yaml.dump(meta, f)
def get_loss_target(self, *args, **kwargs):
noise = kwargs.get('noise')
effective_noise = self.effective_noise
@@ -305,25 +290,27 @@ class CogView4(BaseModel):
if noise is None:
raise ValueError("Noise is not provided")
# return batch.latents
# return (batch.latents - noise).detach()
return (noise - batch.latents).detach()
# return (batch.latents).detach()
# return (effective_noise - batch.latents).detach()
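The kept target follows from the scheduler's forward interpolation (last file below): with x_t = (1 - t) * x0 + t * noise, the velocity dx_t/dt = noise - x0, so the model regresses noise minus latents. A standalone numeric check of that algebra (not repo code):

import torch

x0 = torch.randn(1, 16, 128, 128)  # clean latents
noise = torch.randn_like(x0)
t = torch.rand(1)                  # timestep scaled to [0, 1]

x_t = (1 - t) * x0 + t * noise     # forward flow-matching interpolation
velocity = noise - x0              # the regression target

# stepping along the velocity from x_t recovers both endpoints
assert torch.allclose(x_t + (1 - t) * velocity, noise, atol=1e-5)
assert torch.allclose(x_t - t * velocity, x0, atol=1e-5)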
def _get_low_res_latents(self, latents):
# TODO: avoid this round trip and grab the tensor another way.
with torch.no_grad():
# Decode latents to image space
-images = self.decode_latents(latents, device=latents.device, dtype=latents.dtype)
+images = self.decode_latents(
+    latents, device=latents.device, dtype=latents.dtype)
# Downsample by a factor of 2 using bilinear interpolation
B, C, H, W = images.shape
low_res_images = torch.nn.functional.interpolate(
images,
-size=(H // 4, W // 4),
+size=(H // 2, W // 2),
mode="bilinear",
align_corners=False
)
# Upsample back to original resolution to match expected VAE input dimensions
upsampled_low_res_images = torch.nn.functional.interpolate(
low_res_images,
@@ -331,12 +318,12 @@ class CogView4(BaseModel):
mode="bilinear",
align_corners=False
)
# Encode the low-resolution images back to latent space
-low_res_latents = self.encode_images(upsampled_low_res_images, device=latents.device, dtype=latents.dtype)
+low_res_latents = self.encode_images(
+    upsampled_low_res_images, device=latents.device, dtype=latents.dtype)
return low_res_latents
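The decode / downsample / upsample / encode round trip above acts as a low-pass filter in image space before re-encoding. A standalone sketch of the tensor shapes involved, with a random tensor standing in for decoded images:

import torch
import torch.nn.functional as F

images = torch.randn(1, 3, 1024, 1024)  # stand-in for decoded images
B, C, H, W = images.shape
low_res = F.interpolate(images, size=(H // 2, W // 2), mode="bilinear", align_corners=False)
blurred = F.interpolate(low_res, size=(H, W), mode="bilinear", align_corners=False)
assert blurred.shape == images.shape  # matches the expected VAE input dims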
# def add_noise(
# self,
# original_samples: torch.FloatTensor,
@@ -345,19 +332,19 @@ class CogView4(BaseModel):
# **kwargs,
# ) -> torch.FloatTensor:
# relay_start_point = 500
# # Store original samples for loss calculation
# self.original_samples = original_samples
# # Prepare chunks for batch processing
# original_samples_chunks = torch.chunk(
# original_samples, original_samples.shape[0], dim=0)
# noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
# timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)
# # Get the low res latents only if needed
# low_res_latents_chunks = None
# # Handle case where timesteps is a single value for all samples
# if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks):
# timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks)
@@ -368,7 +355,7 @@ class CogView4(BaseModel):
# for idx in range(original_samples.shape[0]):
# t = timesteps_chunks[idx]
# t_01 = (t / 1000).to(original_samples_chunks[idx].device)
# # Flowmatching interpolation between original and noise
# if t > relay_start_point:
# # Standard flowmatching - direct linear interpolation
@@ -379,30 +366,29 @@ class CogView4(BaseModel):
# if low_res_latents_chunks is None:
# low_res_latents = self._get_low_res_latents(original_samples)
# low_res_latents_chunks = torch.chunk(low_res_latents, low_res_latents.shape[0], dim=0)
# # Calculate the relay ratio (0 to 1)
# t_ratio = t.float() / relay_start_point
# t_ratio = torch.clamp(t_ratio, 0.0, 1.0)
# # First blend between original and low-res based on t_ratio
# z0_t = (1 - t_ratio) * original_samples_chunks[idx] + t_ratio * low_res_latents_chunks[idx]
# added_low_res_noise = z0_t - original_samples_chunks[idx]
# # Then apply flowmatching interpolation between this blended state and noise
# noisy_latents = (1 - t_01) * z0_t + t_01 * noise_chunks[idx]
# # For prediction target, we need to store the effective "source"
# effective_noise_chunks.append(noise_chunks[idx] + added_low_res_noise)
# noisy_latents_chunks.append(noisy_latents)
# noisy_latents = torch.cat(noisy_latents_chunks, dim=0)
# self.effective_noise = torch.cat(effective_noise_chunks, dim=0) # Store for loss calculation
# return noisy_latents
# def add_noise(
# self,
# original_samples: torch.FloatTensor,
@@ -411,20 +397,20 @@ class CogView4(BaseModel):
# **kwargs,
# ) -> torch.FloatTensor:
# relay_start_point = 500
# # Store original samples for loss calculation
# self.original_samples = original_samples
# # Prepare chunks for batch processing
# original_samples_chunks = torch.chunk(
# original_samples, original_samples.shape[0], dim=0)
# noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
# timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)
# # Get the low res latents only if needed
# low_res_latents = self._get_low_res_latents(original_samples)
# low_res_latents_chunks = torch.chunk(low_res_latents, low_res_latents.shape[0], dim=0)
# # Handle case where timesteps is a single value for all samples
# if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks):
# timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks)
@@ -435,24 +421,25 @@ class CogView4(BaseModel):
# for idx in range(original_samples.shape[0]):
# t = timesteps_chunks[idx]
# t_01 = (t / 1000).to(original_samples_chunks[idx].device)
# lrln = low_res_latents_chunks[idx] - original_samples_chunks[idx]
-# lrln = lrln * (1 - t_01)
-# # make the noise an interpolation between noise and low_res_latents with
+# # lrln = lrln * (1 - t_01)
+# # make the noise an interpolation between noise and low_res_latents with
# # being noise at t_01=1 and low_res_latents at t_01=0
-# # new_noise = t_01 * noise_chunks[idx] + (1 - t_01) * lrln
-# new_noise = noise_chunks[idx] + lrln
+# new_noise = t_01 * noise_chunks[idx] + (1 - t_01) * lrln
+# # new_noise = noise_chunks[idx] + lrln
+# # new_noise = noise_chunks[idx] + lrln
# # Then apply flowmatching interpolation between this blended state and noise
# noisy_latents = (1 - t_01) * original_samples_chunks[idx] + t_01 * new_noise
# # For prediction target, we need to store the effective "source"
# effective_noise_chunks.append(new_noise)
# noisy_latents_chunks.append(noisy_latents)
# noisy_latents = torch.cat(noisy_latents_chunks, dim=0)
# self.effective_noise = torch.cat(effective_noise_chunks, dim=0) # Store for loss calculation
# return noisy_latents
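Both commented-out add_noise variants experiment with a relay-style schedule (relay_start_point, the "relay ratio"): at small timesteps the interpolation source blends toward the low-resolution latents rather than pure noise. A compact standalone sketch of what the second variant's active lines compute; relay_noise is an illustrative name, and t is assumed to be in [0, 1000]:

import torch

def relay_noise(x0, noise, low_res, t):
    # blend the noise source toward the low-res residual as t -> 0,
    # then apply the usual flow-matching interpolation
    t_01 = t / 1000.0
    lrln = low_res - x0                             # low-res residual
    new_noise = t_01 * noise + (1.0 - t_01) * lrln  # noise at t=1000, residual at t=0
    noisy = (1.0 - t_01) * x0 + t_01 * new_noise    # forward interpolation
    return noisy, new_noise                         # effective source for the loss target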

View File: toolkit/samplers/custom_flowmatch_sampler.py

@@ -89,7 +89,7 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
) -> torch.Tensor:
t_01 = (timesteps / 1000).to(original_samples.device)
# forward ODE
-noisy_model_input = (1 - t_01) * original_samples + t_01 * noise
+noisy_model_input = (1.0 - t_01) * original_samples + t_01 * noise
# reverse ODE
# noisy_model_input = (1 - t_01) * noise + t_01 * original_samples
return noisy_model_input
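For contrast with the commented "reverse ODE" line: the forward form yields pure noise at t_01 = 1 and the clean sample at t_01 = 0, while the reverse form swaps the endpoints. A quick standalone check:

import torch

x0, noise = torch.randn(8), torch.randn(8)
t_01 = torch.tensor(1.0)
forward = (1.0 - t_01) * x0 + t_01 * noise  # equals noise at t_01 = 1
reverse = (1.0 - t_01) * noise + t_01 * x0  # equals x0 at t_01 = 1
assert torch.allclose(forward, noise) and torch.allclose(reverse, x0)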