Make Qwen 8B work with TextGenerate node. (#13160)

This commit is contained in:
comfyanonymous
2026-03-25 20:21:44 -07:00
committed by GitHub
parent 3eba2dcf2d
commit 2a1f402601
2 changed files with 8 additions and 1 deletions

View File

@@ -928,6 +928,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
weight = state_dict.pop(weight_key, None)
if weight is None:
logging.warning(f"Missing weight for layer {layer_name}")
self.weight = None
return
manually_loaded_keys = [weight_key]
@@ -1034,6 +1035,9 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
if self.bias is not None:
sd["{}bias".format(prefix)] = self.bias
if self.weight is None:
return sd
if isinstance(self.weight, QuantizedTensor):
sd_out = self.weight.state_dict("{}weight".format(prefix))
for k in sd_out:

View File

@@ -224,7 +224,7 @@ class Qwen3_8BConfig:
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
lm_head: bool = False
lm_head: bool = True
stop_tokens = [151643, 151645]
@dataclass
@@ -912,6 +912,9 @@ class BaseGenerate:
class BaseQwen3:
def logits(self, x):
input = x[:, -1:]
if self.model.config.lm_head:
return self.model.lm_head(input)
module = self.model.embed_tokens
offload_stream = None