Small cleanup and try to get Qwen 3 to work with the text gen. (#12537)

Author: comfyanonymous (committed by GitHub)
Date:   2026-02-19 19:42:28 -08:00
Parent: 4d172e9ad7 · Commit: 0301ccf745
5 changed files with 22 additions and 16 deletions

@@ -426,10 +426,8 @@ class CLIP:
     def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
         self.cond_stage_model.reset_clip_options()
 
-        if self.layer_idx is not None:
-            self.cond_stage_model.set_clip_options({"layer": self.layer_idx})
-
         self.load_model()
+        self.cond_stage_model.set_clip_options({"layer": None})
         self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
         return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
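This change makes text generation independent of any "clip skip" setting: rather than restoring a stored layer_idx, generate() now always clears the layer option before running, since autoregressive decoding needs the final layer's output feeding the LM head. A minimal usage sketch, assuming a loaded `clip` instance of the CLIP class above (the prompt and parameter values are illustrative):

    tokens = clip.tokenize("a short caption to continue")
    output = clip.generate(tokens, do_sample=True, max_length=128,
                           temperature=0.8, top_k=50, top_p=0.95, seed=42)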

@@ -308,14 +308,14 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     def load_sd(self, sd):
         return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
 
-    def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[]):
+    def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
         if isinstance(tokens, dict):
             tokens_only = next(iter(tokens.values()))  # todo: get this better?
         else:
             tokens_only = tokens
         tokens_only = [[t[0] for t in b] for b in tokens_only]
         embeds = self.process_tokens(tokens_only, device=self.execution_device)[0]
-        return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens)
+        return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
 
 def parse_parentheses(string):
     result = []
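For context on the list comprehension in generate(): ComfyUI tokenizers produce batches of (token_id, weight) pairs, and generation only needs the raw ids, so the weights are stripped. A small illustration with made-up ids:

    tokens_only = [[(9906, 1.0), (1917, 1.2)]]      # one batch entry of (id, weight) pairs
    ids = [[t[0] for t in b] for b in tokens_only]  # -> [[9906, 1917]]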

@@ -33,6 +33,8 @@ class AnimaTokenizer:
     def state_dict(self):
         return {}
 
+    def decode(self, token_ids, **kwargs):
+        return self.qwen3_06b.decode(token_ids, **kwargs)
 
 class Qwen3_06BModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):

@@ -105,6 +105,7 @@ class Qwen3_06BConfig:
     rope_scale = None
     final_norm: bool = True
     lm_head: bool = False
+    stop_tokens = [151643, 151645]
 
 @dataclass
 class Qwen3_06B_ACE15_Config:
@@ -128,6 +129,7 @@ class Qwen3_06B_ACE15_Config:
     rope_scale = None
     final_norm: bool = True
     lm_head: bool = False
+    stop_tokens = [151643, 151645]
 
 @dataclass
 class Qwen3_2B_ACE15_lm_Config:
@@ -151,6 +153,7 @@ class Qwen3_2B_ACE15_lm_Config:
     rope_scale = None
     final_norm: bool = True
     lm_head: bool = False
+    stop_tokens = [151643, 151645]
 
 @dataclass
 class Qwen3_4B_ACE15_lm_Config:
@@ -174,6 +177,7 @@ class Qwen3_4B_ACE15_lm_Config:
     rope_scale = None
     final_norm: bool = True
     lm_head: bool = False
+    stop_tokens = [151643, 151645]
 
 @dataclass
 class Qwen3_4BConfig:
@@ -197,6 +201,7 @@ class Qwen3_4BConfig:
     rope_scale = None
     final_norm: bool = True
     lm_head: bool = False
+    stop_tokens = [151643, 151645]
 
 @dataclass
 class Qwen3_8BConfig:
@@ -220,6 +225,7 @@ class Qwen3_8BConfig:
     rope_scale = None
     final_norm: bool = True
     lm_head: bool = False
+    stop_tokens = [151643, 151645]
 
 @dataclass
 class Ovis25_2BConfig:
@@ -290,6 +296,7 @@ class Gemma2_2B_Config:
     rope_scale = None
     final_norm: bool = True
     lm_head: bool = False
+    stop_tokens = [1]
 
 @dataclass
 class Gemma3_4B_Config:
@@ -314,6 +321,7 @@ class Gemma3_4B_Config:
     rope_scale = [8.0, 1.0]
     final_norm: bool = True
     lm_head: bool = False
+    stop_tokens = [1, 106]
 
 GEMMA3_VISION_CONFIG = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
@@ -347,6 +355,7 @@ class Gemma3_12B_Config:
     lm_head: bool = False
     vision_config = GEMMA3_VISION_CONFIG
     mm_tokens_per_image = 256
+    stop_tokens = [1, 106]
 
 class RMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
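The ids correspond to the tokenizers' end markers: in the Qwen3 vocabulary 151643 is <|endoftext|> and 151645 is <|im_end|>, while in the Gemma vocabulary 1 is <eos> and 106 is Gemma 3's <end_of_turn>. One Python subtlety worth noting: `stop_tokens` is written without a type annotation, so @dataclass treats it as a plain class attribute rather than a field; it stays out of __init__ and repr and is shared as a default across instances. A toy sketch of that behavior:

    from dataclasses import dataclass, fields

    @dataclass
    class ToyConfig:
        lm_head: bool = False             # annotated -> becomes an __init__ field
        stop_tokens = [151643, 151645]    # unannotated -> plain class attribute

    cfg = ToyConfig()                     # __init__ takes no stop_tokens argument
    print(cfg.stop_tokens)                # [151643, 151645], read from the class
    print([f.name for f in fields(ToyConfig)])  # ['lm_head'] only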
@@ -803,10 +812,13 @@ class BaseGenerate:
             comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
             return x
 
-    def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=[], initial_tokens=[], execution_dtype=None, min_tokens=0):
+    def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0):
         device = embeds.device
         model_config = self.model.config
+        if stop_tokens is None:
+            stop_tokens = self.model.config.stop_tokens
+
         if execution_dtype is None:
             if comfy.model_management.should_use_bf16(device):
                 execution_dtype = torch.bfloat16
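With the default flipped to stop_tokens=None, generate() now resolves the stop set from the model's own config unless the caller overrides it, which is what lets the per-model overrides further down be deleted. For reference, a minimal self-contained sketch of what the sampling knobs conventionally mean; this is an illustration under the standard definitions, not ComfyUI's implementation, and all names are local to the example:

    import torch

    def sample_next_token(logits, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0,
                          repetition_penalty=1.0, prev_ids=None, generator=None):
        logits = logits.float().clone()
        if repetition_penalty != 1.0 and prev_ids is not None:
            # Penalize already-generated tokens (prev_ids: 1-D LongTensor of ids).
            scores = logits[prev_ids]
            logits[prev_ids] = torch.where(scores > 0, scores / repetition_penalty,
                                           scores * repetition_penalty)
        probs = torch.softmax(logits / max(temperature, 1e-8), dim=-1)
        if min_p > 0.0:
            # min-p: drop tokens whose probability is below min_p * the top probability.
            probs[probs < min_p * probs.max()] = 0.0
        if top_k > 0:
            # top-k: keep only the k most likely tokens.
            kth = torch.topk(probs, min(top_k, probs.numel())).values[-1]
            probs[probs < kth] = 0.0
        if top_p < 1.0:
            # nucleus: keep the smallest set whose cumulative probability reaches top_p.
            sorted_probs, sorted_idx = torch.sort(probs, descending=True)
            cum = torch.cumsum(sorted_probs, dim=-1)
            sorted_probs[cum - sorted_probs > top_p] = 0.0
            probs = torch.zeros_like(probs).scatter_(0, sorted_idx, sorted_probs)
        return torch.multinomial(probs / probs.sum(), 1, generator=generator).item()

    logits = torch.randn(151936)      # Qwen3-sized vocab, random demo logits
    prev = torch.tensor([151645])     # pretend <|im_end|> was already emitted
    tok = sample_next_token(logits, prev_ids=prev,
                            generator=torch.Generator().manual_seed(42))

Greedy decoding (do_sample=False) would instead take the argmax of the raw logits and skip all of the above.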
@@ -925,7 +937,7 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype
 
-class Qwen3_06B(BaseLlama, BaseQwen3, torch.nn.Module):
+class Qwen3_06B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
         config = Qwen3_06BConfig(**config_dict)
@@ -952,7 +964,7 @@ class Qwen3_2B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype
 
-class Qwen3_4B(BaseLlama, BaseQwen3, torch.nn.Module):
+class Qwen3_4B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
         config = Qwen3_4BConfig(**config_dict)
@@ -970,7 +982,7 @@ class Qwen3_4B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype
 
-class Qwen3_8B(BaseLlama, BaseQwen3, torch.nn.Module):
+class Qwen3_8B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
         config = Qwen3_8BConfig(**config_dict)
@@ -1034,7 +1046,7 @@ class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module):
         return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids)
 
-class Gemma2_2B(BaseLlama, torch.nn.Module):
+class Gemma2_2B(BaseLlama, BaseGenerate, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
         config = Gemma2_2B_Config(**config_dict)
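Adding BaseGenerate to the base-class lists above is what exposes generate() on these models; no per-class code is needed because the method reads everything model-specific, including stop_tokens, off self.model.config. A toy sketch of the mixin shape with stand-in names:

    from types import SimpleNamespace

    class BaseGenerate:
        # Mixin: any class with self.model.config.stop_tokens gains generate().
        def generate(self):
            return f"stopping on {self.model.config.stop_tokens}"

    class ToyQwen3(BaseGenerate):  # stands in for Qwen3_06B / Qwen3_4B / Qwen3_8B
        def __init__(self):
            self.model = SimpleNamespace(config=SimpleNamespace(stop_tokens=[151643, 151645]))

    print(ToyQwen3().generate())   # stopping on [151643, 151645]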

@@ -31,9 +31,6 @@ class Gemma2_2BModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
 
-    def generate(self, embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
-        return super().generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[107])
-
 class Gemma3_4BModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
         llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
@@ -43,9 +40,6 @@ class Gemma3_4BModel(sd1_clip.SDClipModel):
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
 
-    def generate(self, embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
-        return super().generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106])
-
 class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
         llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
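With these overrides gone, stop-token handling lives in one place and the call path is, schematically:

    # CLIP.generate -> SDClipModel.generate -> BaseGenerate.generate
    #                  (stop_tokens=None falls back to model.config.stop_tokens)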