Add initial support for chroma radiance

This commit is contained in:
Jaret Burkett
2025-09-10 08:41:05 -06:00
parent af6fdaaaf9
commit b95c17dc17
9 changed files with 1339 additions and 20 deletions

View File

@@ -15,7 +15,7 @@ from toolkit.accelerator import unwrap_model
from optimum.quanto import freeze, QTensor
from toolkit.util.quantize import quantize, get_qtype
from transformers import T5TokenizerFast, T5EncoderModel, CLIPTextModel, CLIPTokenizer
from .pipeline import ChromaPipeline
from .pipeline import ChromaPipeline, prepare_latent_image_ids
from einops import rearrange, repeat
import random
import torch.nn.functional as F
@@ -324,12 +324,19 @@ class ChromaModel(BaseModel):
ph=2,
pw=2
)
img_ids = prepare_latent_image_ids(
bs,
h,
w,
patch_size=2
).to(device=self.device_torch)
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c",
b=bs).to(self.device_torch)
# img_ids = torch.zeros(h // 2, w // 2, 3)
# img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
# img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
# img_ids = repeat(img_ids, "h w c -> b (h w) c",
# b=bs).to(self.device_torch)
txt_ids = torch.zeros(
bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch)