mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-29 02:41:27 +00:00
feat: Add basic text generation support with native models, initially supporting Gemma3 (#12392)
This commit is contained in:
29
comfy/sd.py
29
comfy/sd.py
@@ -423,6 +423,19 @@ class CLIP:
|
||||
def get_key_patches(self):
|
||||
return self.patcher.get_key_patches()
|
||||
|
||||
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
|
||||
self.cond_stage_model.reset_clip_options()
|
||||
|
||||
if self.layer_idx is not None:
|
||||
self.cond_stage_model.set_clip_options({"layer": self.layer_idx})
|
||||
|
||||
self.load_model()
|
||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
|
||||
|
||||
def decode(self, token_ids, skip_special_tokens=True):
|
||||
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
|
||||
|
||||
class VAE:
|
||||
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
|
||||
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||
@@ -1182,6 +1195,7 @@ class TEModel(Enum):
|
||||
JINA_CLIP_2 = 19
|
||||
QWEN3_8B = 20
|
||||
QWEN3_06B = 21
|
||||
GEMMA_3_4B_VISION = 22
|
||||
|
||||
|
||||
def detect_te_model(sd):
|
||||
@@ -1210,7 +1224,10 @@ def detect_te_model(sd):
|
||||
if 'model.layers.47.self_attn.q_norm.weight' in sd:
|
||||
return TEModel.GEMMA_3_12B
|
||||
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
||||
return TEModel.GEMMA_3_4B
|
||||
if 'vision_model.embeddings.patch_embedding.weight' in sd:
|
||||
return TEModel.GEMMA_3_4B_VISION
|
||||
else:
|
||||
return TEModel.GEMMA_3_4B
|
||||
return TEModel.GEMMA_2_2B
|
||||
if 'model.layers.0.self_attn.k_proj.bias' in sd:
|
||||
weight = sd['model.layers.0.self_attn.k_proj.bias']
|
||||
@@ -1270,6 +1287,8 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
else:
|
||||
if "text_projection" in clip_data[i]:
|
||||
clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node
|
||||
if "lm_head.weight" in clip_data[i]:
|
||||
clip_data[i]["model.lm_head.weight"] = clip_data[i].pop("lm_head.weight") # prefix missing in some models
|
||||
|
||||
tokenizer_data = {}
|
||||
clip_target = EmptyClass()
|
||||
@@ -1335,6 +1354,14 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
|
||||
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
|
||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
elif te_model == TEModel.GEMMA_3_4B_VISION:
|
||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b_vision")
|
||||
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
|
||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
elif te_model == TEModel.GEMMA_3_12B:
|
||||
clip_target.clip = comfy.text_encoders.lt.gemma3_te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.lt.Gemma3_12BTokenizer
|
||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
elif te_model == TEModel.LLAMA3_8:
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
|
||||
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None)
|
||||
|
||||
Reference in New Issue
Block a user