mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
Add initial support for chroma radiance
This commit is contained in:
@@ -15,7 +15,7 @@ from toolkit.accelerator import unwrap_model
|
||||
from optimum.quanto import freeze, QTensor
|
||||
from toolkit.util.quantize import quantize, get_qtype
|
||||
from transformers import T5TokenizerFast, T5EncoderModel, CLIPTextModel, CLIPTokenizer
|
||||
from .pipeline import ChromaPipeline
|
||||
from .pipeline import ChromaPipeline, prepare_latent_image_ids
|
||||
from einops import rearrange, repeat
|
||||
import random
|
||||
import torch.nn.functional as F
|
||||
@@ -324,12 +324,19 @@ class ChromaModel(BaseModel):
|
||||
ph=2,
|
||||
pw=2
|
||||
)
|
||||
|
||||
img_ids = prepare_latent_image_ids(
|
||||
bs,
|
||||
h,
|
||||
w,
|
||||
patch_size=2
|
||||
).to(device=self.device_torch)
|
||||
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c",
|
||||
b=bs).to(self.device_torch)
|
||||
# img_ids = torch.zeros(h // 2, w // 2, 3)
|
||||
# img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
||||
# img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
||||
# img_ids = repeat(img_ids, "h w c -> b (h w) c",
|
||||
# b=bs).to(self.device_torch)
|
||||
|
||||
txt_ids = torch.zeros(
|
||||
bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch)
|
||||
|
||||
Reference in New Issue
Block a user