mirror of https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
WIP more work on cogview4
@@ -1,7 +1,7 @@
 torch==2.5.1
 torchvision==0.20.1
 safetensors
-git+https://github.com/huggingface/diffusers@97fda1b75c70705b245a462044fedb47abb17e56
+git+https://github.com/huggingface/diffusers@24c062aaa19f5626d03d058daf8afffa2dfd49f7
 transformers==4.49.0
 lycoris-lora==1.8.3
 flatten_json
@@ -20,7 +20,6 @@ from transformers import GlmModel, AutoTokenizer
 from diffusers import FlowMatchEulerDiscreteScheduler
 from typing import TYPE_CHECKING
 from toolkit.accelerator import unwrap_model
-
 from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler

 if TYPE_CHECKING:
@@ -71,7 +70,7 @@ class CogView4(BaseModel):
         self.is_flow_matching = True
         self.is_transformer = True
         self.target_lora_modules = ['CogView4Transformer2DModel']
-
+
         # cache for holding noise
         self.effective_noise = None

@@ -86,7 +85,6 @@ class CogView4(BaseModel):
         base_model_path = "THUDM/CogView4-6B"
         model_path = self.model_config.name_or_path

         # pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
         self.print_and_status_update("Loading CogView4 model")
-        # base_model_path = "black-forest-labs/FLUX.1-schnell"
         base_model_path = self.model_config.name_or_path_original
@@ -213,19 +211,6 @@ class CogView4(BaseModel):
         generator: torch.Generator,
         extra: dict,
     ):
-        # there is a bug in a check in the diffusers code that requires the prompt embeds to be the same length for conditional and unconditional
-        # they are processed in 2 passes, so the check shouldn't be needed, and the encoding code doesn't enforce it. For now we zero pad the shorter one; this is inference only, so it should be fine.
-        if conditional_embeds.text_embeds.shape[1] < unconditional_embeds.text_embeds.shape[1]:
-            pad_len = unconditional_embeds.text_embeds.shape[1] - \
-                conditional_embeds.text_embeds.shape[1]
-            conditional_embeds.text_embeds = torch.cat([conditional_embeds.text_embeds, torch.zeros(conditional_embeds.text_embeds.shape[0], pad_len,
-                conditional_embeds.text_embeds.shape[2], device=conditional_embeds.text_embeds.device, dtype=conditional_embeds.text_embeds.dtype)], dim=1)
-        elif conditional_embeds.text_embeds.shape[1] > unconditional_embeds.text_embeds.shape[1]:
-            pad_len = conditional_embeds.text_embeds.shape[1] - \
-                unconditional_embeds.text_embeds.shape[1]
-            unconditional_embeds.text_embeds = torch.cat([unconditional_embeds.text_embeds, torch.zeros(unconditional_embeds.text_embeds.shape[0], pad_len,
-                unconditional_embeds.text_embeds.shape[2], device=unconditional_embeds.text_embeds.device, dtype=unconditional_embeds.text_embeds.dtype)], dim=1)
-
         img = pipeline(
             prompt_embeds=conditional_embeds.text_embeds.to(
                 self.device_torch, dtype=self.torch_dtype),
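The deleted block worked around a strict length check in the diffusers pipeline by zero-padding the shorter of the conditional/unconditional prompt-embedding tensors; presumably the pinned diffusers bump above makes the workaround unnecessary. For reference, a minimal standalone sketch of the same trick (the helper name and plain-tensor interface are illustrative, not part of the toolkit):

import torch

def pad_embeds_to_match(cond: torch.Tensor, uncond: torch.Tensor):
    # Zero-pad the shorter of two [batch, seq, dim] embedding tensors along
    # the sequence dimension so both end up with the same sequence length.
    pad_len = cond.shape[1] - uncond.shape[1]
    if pad_len < 0:
        # conditional is shorter: pad it on the right with zeros
        pad = torch.zeros(cond.shape[0], -pad_len, cond.shape[2],
                          device=cond.device, dtype=cond.dtype)
        cond = torch.cat([cond, pad], dim=1)
    elif pad_len > 0:
        # unconditional is shorter: pad it instead
        pad = torch.zeros(uncond.shape[0], pad_len, uncond.shape[2],
                          device=uncond.device, dtype=uncond.dtype)
        uncond = torch.cat([uncond, pad], dim=1)
    return cond, uncond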
@@ -259,12 +244,12 @@ class CogView4(BaseModel):
             [target_size], dtype=self.torch_dtype, device=self.device_torch)
         target_size = original_size.clone()
         noise_pred_cond = self.model(
-            hidden_states=latent_model_input,  # torch.Size([1, 16, 128, 128])
-            encoder_hidden_states=text_embeddings.text_embeds,  # torch.Size([1, 16, 4096])
+            hidden_states=latent_model_input,
+            encoder_hidden_states=text_embeddings.text_embeds,
             timestep=timestep,
-            original_size=original_size,  # [[1024., 1024.]]
-            target_size=target_size,  # [[1024., 1024.]]
-            crop_coords=crops_coords_top_left,  # [[0., 0.]]
+            original_size=original_size,
+            target_size=target_size,
+            crop_coords=crops_coords_top_left,
             return_dict=False,
         )[0]
         return noise_pred_cond
@@ -283,9 +268,9 @@ class CogView4(BaseModel):

     def get_te_has_grad(self):
         return self.text_encoder.layers[0].mlp.down_proj.weight.requires_grad
-
+
     def save_model(self, output_path, meta, save_dtype):
-        # only save the unet
+        # only save the unet
         transformer: CogView4Transformer2DModel = unwrap_model(self.model)
         transformer.save_pretrained(
             save_directory=os.path.join(output_path, 'transformer'),
@@ -295,7 +280,7 @@ class CogView4(BaseModel):
         meta_path = os.path.join(output_path, 'aitk_meta.yaml')
         with open(meta_path, 'w') as f:
             yaml.dump(meta, f)
-
+
     def get_loss_target(self, *args, **kwargs):
         noise = kwargs.get('noise')
         effective_noise = self.effective_noise
@@ -305,25 +290,27 @@ class CogView4(BaseModel):
         if noise is None:
             raise ValueError("Noise is not provided")
         # return batch.latents
         # return (batch.latents - noise).detach()
         return (noise - batch.latents).detach()
         # return (batch.latents).detach()
+        # return (effective_noise - batch.latents).detach()

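For the forward ODE x_t = (1 - t) * x_0 + t * noise used by this model, the velocity the transformer regresses is d(x_t)/dt = noise - x_0, which is why the active branch returns (noise - batch.latents). A minimal sketch of the matching loss on plain tensors (names are illustrative, not the toolkit's API):

import torch
import torch.nn.functional as F

def flow_match_loss(model_pred: torch.Tensor,
                    noise: torch.Tensor,
                    latents: torch.Tensor) -> torch.Tensor:
    # Velocity target for x_t = (1 - t) * x_0 + t * noise is (noise - x_0),
    # mirroring the active return in get_loss_target above.
    target = (noise - latents).detach()
    return F.mse_loss(model_pred, target)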
     def _get_low_res_latents(self, latents):
-        # todo prevent needing to do this and grab the tensor another way.
+        # todo prevent needing to do this and grab the tensor another way.
         with torch.no_grad():
             # Decode latents to image space
-            images = self.decode_latents(latents, device=latents.device, dtype=latents.dtype)
+            images = self.decode_latents(
+                latents, device=latents.device, dtype=latents.dtype)

             # Downsample by a factor of 2 using bilinear interpolation
             B, C, H, W = images.shape
             low_res_images = torch.nn.functional.interpolate(
                 images,
-                size=(H // 4, W // 4),
+                size=(H // 2, W // 2),
                 mode="bilinear",
                 align_corners=False
             )

             # Upsample back to original resolution to match expected VAE input dimensions
             upsampled_low_res_images = torch.nn.functional.interpolate(
                 low_res_images,
@@ -331,12 +318,12 @@ class CogView4(BaseModel):
                 mode="bilinear",
                 align_corners=False
             )
-
+
             # Encode the low-resolution images back to latent space
-            low_res_latents = self.encode_images(upsampled_low_res_images, device=latents.device, dtype=latents.dtype)
+            low_res_latents = self.encode_images(
+                upsampled_low_res_images, device=latents.device, dtype=latents.dtype)
             return low_res_latents

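_get_low_res_latents round-trips through pixel space: decode the latents, blur them by downsampling and upsampling, then re-encode. A self-contained sketch of the same degradation on plain image tensors, with the VAE decode/encode steps omitted (purely illustrative):

import torch
import torch.nn.functional as F

def low_frequency_version(images: torch.Tensor, factor: int = 2) -> torch.Tensor:
    # Strip high frequencies by downsampling then upsampling back to the
    # original size, mirroring the decode -> interpolate -> encode round trip.
    B, C, H, W = images.shape
    low = F.interpolate(images, size=(H // factor, W // factor),
                        mode="bilinear", align_corners=False)
    return F.interpolate(low, size=(H, W),
                         mode="bilinear", align_corners=False)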
     # def add_noise(
     #     self,
     #     original_samples: torch.FloatTensor,
@@ -345,19 +332,19 @@ class CogView4(BaseModel):
     #     **kwargs,
     # ) -> torch.FloatTensor:
     #     relay_start_point = 500

     #     # Store original samples for loss calculation
     #     self.original_samples = original_samples

     #     # Prepare chunks for batch processing
     #     original_samples_chunks = torch.chunk(
     #         original_samples, original_samples.shape[0], dim=0)
     #     noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
     #     timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)

     #     # Get the low res latents only if needed
     #     low_res_latents_chunks = None

     #     # Handle case where timesteps is a single value for all samples
     #     if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks):
     #         timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks)
@@ -368,7 +355,7 @@ class CogView4(BaseModel):
     #     for idx in range(original_samples.shape[0]):
     #         t = timesteps_chunks[idx]
     #         t_01 = (t / 1000).to(original_samples_chunks[idx].device)

     #         # Flowmatching interpolation between original and noise
     #         if t > relay_start_point:
     #             # Standard flowmatching - direct linear interpolation
@@ -379,30 +366,29 @@ class CogView4(BaseModel):
     #             if low_res_latents_chunks is None:
     #                 low_res_latents = self._get_low_res_latents(original_samples)
     #                 low_res_latents_chunks = torch.chunk(low_res_latents, low_res_latents.shape[0], dim=0)

     #             # Calculate the relay ratio (0 to 1)
     #             t_ratio = t.float() / relay_start_point
     #             t_ratio = torch.clamp(t_ratio, 0.0, 1.0)

     #             # First blend between original and low-res based on t_ratio
     #             z0_t = (1 - t_ratio) * original_samples_chunks[idx] + t_ratio * low_res_latents_chunks[idx]

     #             added_lor_res_noise = z0_t - original_samples_chunks[idx]

     #             # Then apply flowmatching interpolation between this blended state and noise
     #             noisy_latents = (1 - t_01) * z0_t + t_01 * noise_chunks[idx]

     #             # For prediction target, we need to store the effective "source"
     #             effective_noise_chunks.append(noise_chunks[idx] + added_lor_res_noise)

     #         noisy_latents_chunks.append(noisy_latents)

     #     noisy_latents = torch.cat(noisy_latents_chunks, dim=0)
     #     self.effective_noise = torch.cat(effective_noise_chunks, dim=0)  # Store for loss calculation

     #     return noisy_latents

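This commented-out method sketches a relay-style noising schedule: below relay_start_point the clean latents are first blended toward their low-frequency version, and the flow-matching interpolation then runs from that blended state to the noise. A compact sketch of the core math for a single sample, assuming timesteps on a 0-1000 scale and precomputed low-res latents (illustrative, not the toolkit's API):

import torch

def relay_noisy_latents(x0: torch.Tensor, x0_low: torch.Tensor,
                        noise: torch.Tensor, t: float,
                        relay_start: float = 500.0):
    # Blend the clean latents toward their low-res version as t falls
    # below relay_start, then flow-match between the blend and the noise.
    t01 = t / 1000.0
    ratio = min(max(t / relay_start, 0.0), 1.0)
    z0_t = (1 - ratio) * x0 + ratio * x0_low
    noisy = (1 - t01) * z0_t + t01 * noise
    # The WIP code stores noise + (z0_t - x0) as the effective "source"
    # so the loss target can account for what was actually mixed in.
    effective_noise = noise + (z0_t - x0)
    return noisy, effective_noise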
     # def add_noise(
     #     self,
     #     original_samples: torch.FloatTensor,
@@ -411,20 +397,20 @@ class CogView4(BaseModel):
     #     **kwargs,
     # ) -> torch.FloatTensor:
     #     relay_start_point = 500

     #     # Store original samples for loss calculation
     #     self.original_samples = original_samples

     #     # Prepare chunks for batch processing
     #     original_samples_chunks = torch.chunk(
     #         original_samples, original_samples.shape[0], dim=0)
     #     noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
     #     timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)

     #     # Get the low res latents only if needed
     #     low_res_latents = self._get_low_res_latents(original_samples)
     #     low_res_latents_chunks = torch.chunk(low_res_latents, low_res_latents.shape[0], dim=0)

     #     # Handle case where timesteps is a single value for all samples
     #     if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks):
     #         timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks)
@@ -435,24 +421,25 @@ class CogView4(BaseModel):
     #     for idx in range(original_samples.shape[0]):
     #         t = timesteps_chunks[idx]
     #         t_01 = (t / 1000).to(original_samples_chunks[idx].device)

     #         lrln = low_res_latents_chunks[idx] - original_samples_chunks[idx]
-    #         lrln = lrln * (1 - t_01)
+    #         # lrln = lrln * (1 - t_01)

     #         # make the noise an interpolation between noise and low_res_latents with
     #         # being noise at t_01=1 and low_res_latents at t_01=0
-    #         # new_noise = t_01 * noise_chunks[idx] + (1 - t_01) * lrln
-    #         new_noise = noise_chunks[idx] + lrln
+    #         new_noise = t_01 * noise_chunks[idx] + (1 - t_01) * lrln
+    #         # new_noise = noise_chunks[idx] + lrln

     #         # Then apply flowmatching interpolation between this blended state and noise
     #         noisy_latents = (1 - t_01) * original_samples + t_01 * new_noise

     #         # For prediction target, we need to store the effective "source"
     #         effective_noise_chunks.append(new_noise)

     #         noisy_latents_chunks.append(noisy_latents)

     #     noisy_latents = torch.cat(noisy_latents_chunks, dim=0)
     #     self.effective_noise = torch.cat(effective_noise_chunks, dim=0)  # Store for loss calculation

     #     return noisy_latents

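With the modified noise source, the usual (noise - x_0) target no longer matches what was actually mixed into the latents, which is why the method caches self.effective_noise and get_loss_target carries a commented (effective_noise - batch.latents) variant. A small sketch of the pairing on plain tensors (illustrative only, per-sample rather than chunked):

import torch

def noisy_input_and_target(x0: torch.Tensor, noise: torch.Tensor,
                           x0_low: torch.Tensor, t: float):
    # Fold the low-res residual into the noise source, then flow-match.
    t01 = t / 1000.0
    lrln = x0_low - x0                          # low-res "noise" direction
    new_noise = t01 * noise + (1 - t01) * lrln  # the active line in the hunk
    noisy = (1 - t01) * x0 + t01 * new_noise
    # The training target must use the same source that was mixed in:
    target = (new_noise - x0).detach()
    return noisy, target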
@@ -89,7 +89,7 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
     ) -> torch.Tensor:
         t_01 = (timesteps / 1000).to(original_samples.device)
         # forward ODE
-        noisy_model_input = (1 - t_01) * original_samples + t_01 * noise
+        noisy_model_input = (1.0 - t_01) * original_samples + t_01 * noise
         # reverse ODE
         # noisy_model_input = (1 - t_01) * noise + t_01 * original_samples
         return noisy_model_input
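A quick sanity check of the forward ODE above: at t=0 the model input is the clean sample, and at t=1000 it is pure noise. Standalone and illustrative only:

import torch

samples = torch.randn(2, 16, 8, 8)
noise = torch.randn_like(samples)

def forward_ode(t: torch.Tensor) -> torch.Tensor:
    # Broadcast per-sample timesteps over the latent dimensions.
    t_01 = (t / 1000).view(-1, 1, 1, 1)
    return (1.0 - t_01) * samples + t_01 * noise

assert torch.allclose(forward_ode(torch.zeros(2)), samples)
assert torch.allclose(forward_ode(torch.full((2,), 1000.0)), noise)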