From 0301ccf745d24f41abdf05ba84ffb61a26ddaff7 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 19 Feb 2026 19:42:28 -0800 Subject: [PATCH] Small cleanup and try to get qwen 3 work with the text gen. (#12537) --- comfy/sd.py | 4 +--- comfy/sd1_clip.py | 4 ++-- comfy/text_encoders/anima.py | 2 ++ comfy/text_encoders/llama.py | 22 +++++++++++++++++----- comfy/text_encoders/lumina2.py | 6 ------ 5 files changed, 22 insertions(+), 16 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 164f30803..ce6ca5d17 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -426,10 +426,8 @@ class CLIP: def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None): self.cond_stage_model.reset_clip_options() - if self.layer_idx is not None: - self.cond_stage_model.set_clip_options({"layer": self.layer_idx}) - self.load_model() + self.cond_stage_model.set_clip_options({"layer": None}) self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index d9d014055..17e2b4816 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -308,14 +308,14 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): def load_sd(self, sd): return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False)) - def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[]): + def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed): if isinstance(tokens, dict): tokens_only = next(iter(tokens.values())) # todo: get this better? else: tokens_only = tokens tokens_only = [[t[0] for t in b] for b in tokens_only] embeds = self.process_tokens(tokens_only, device=self.execution_device)[0] - return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens) + return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed) def parse_parentheses(string): result = [] diff --git a/comfy/text_encoders/anima.py b/comfy/text_encoders/anima.py index d8c5a6f92..2e31b2b04 100644 --- a/comfy/text_encoders/anima.py +++ b/comfy/text_encoders/anima.py @@ -33,6 +33,8 @@ class AnimaTokenizer: def state_dict(self): return {} + def decode(self, token_ids, **kwargs): + return self.qwen3_06b.decode(token_ids, **kwargs) class Qwen3_06BModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index e5d21fa74..ccc200b7a 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -105,6 +105,7 @@ class Qwen3_06BConfig: rope_scale = None final_norm: bool = True lm_head: bool = False + stop_tokens = [151643, 151645] @dataclass class Qwen3_06B_ACE15_Config: @@ -128,6 +129,7 @@ class Qwen3_06B_ACE15_Config: rope_scale = None final_norm: bool = True lm_head: bool = False + stop_tokens = [151643, 151645] @dataclass class Qwen3_2B_ACE15_lm_Config: @@ -151,6 +153,7 @@ class Qwen3_2B_ACE15_lm_Config: rope_scale = None final_norm: bool = True lm_head: bool = False + stop_tokens = [151643, 151645] @dataclass class Qwen3_4B_ACE15_lm_Config: @@ -174,6 +177,7 @@ class Qwen3_4B_ACE15_lm_Config: rope_scale = None final_norm: bool = True lm_head: bool = False + stop_tokens = [151643, 151645] @dataclass class Qwen3_4BConfig: @@ -197,6 +201,7 @@ class Qwen3_4BConfig: rope_scale = None final_norm: bool = True lm_head: bool = False + stop_tokens = [151643, 151645] @dataclass class Qwen3_8BConfig: @@ -220,6 +225,7 @@ class Qwen3_8BConfig: rope_scale = None final_norm: bool = True lm_head: bool = False + stop_tokens = [151643, 151645] @dataclass class Ovis25_2BConfig: @@ -290,6 +296,7 @@ class Gemma2_2B_Config: rope_scale = None final_norm: bool = True lm_head: bool = False + stop_tokens = [1] @dataclass class Gemma3_4B_Config: @@ -314,6 +321,7 @@ class Gemma3_4B_Config: rope_scale = [8.0, 1.0] final_norm: bool = True lm_head: bool = False + stop_tokens = [1, 106] GEMMA3_VISION_CONFIG = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14} @@ -347,6 +355,7 @@ class Gemma3_12B_Config: lm_head: bool = False vision_config = GEMMA3_VISION_CONFIG mm_tokens_per_image = 256 + stop_tokens = [1, 106] class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None): @@ -803,10 +812,13 @@ class BaseGenerate: comfy.ops.uncast_bias_weight(module, weight, None, offload_stream) return x - def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=[], initial_tokens=[], execution_dtype=None, min_tokens=0): + def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0): device = embeds.device model_config = self.model.config + if stop_tokens is None: + stop_tokens = self.model.config.stop_tokens + if execution_dtype is None: if comfy.model_management.should_use_bf16(device): execution_dtype = torch.bfloat16 @@ -925,7 +937,7 @@ class Qwen25_3B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype -class Qwen3_06B(BaseLlama, BaseQwen3, torch.nn.Module): +class Qwen3_06B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() config = Qwen3_06BConfig(**config_dict) @@ -952,7 +964,7 @@ class Qwen3_2B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype -class Qwen3_4B(BaseLlama, BaseQwen3, torch.nn.Module): +class Qwen3_4B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() config = Qwen3_4BConfig(**config_dict) @@ -970,7 +982,7 @@ class Qwen3_4B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype -class Qwen3_8B(BaseLlama, BaseQwen3, torch.nn.Module): +class Qwen3_8B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() config = Qwen3_8BConfig(**config_dict) @@ -1034,7 +1046,7 @@ class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module): return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids) -class Gemma2_2B(BaseLlama, torch.nn.Module): +class Gemma2_2B(BaseLlama, BaseGenerate, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() config = Gemma2_2B_Config(**config_dict) diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py index 1b731e094..01ebdfabe 100644 --- a/comfy/text_encoders/lumina2.py +++ b/comfy/text_encoders/lumina2.py @@ -31,9 +31,6 @@ class Gemma2_2BModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) - def generate(self, embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed): - return super().generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[107]) - class Gemma3_4BModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) @@ -43,9 +40,6 @@ class Gemma3_4BModel(sd1_clip.SDClipModel): super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) - def generate(self, embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed): - return super().generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) - class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)