WIP Gemma3: not working

This commit is contained in:
Iwan Kawrakow
2025-03-21 17:20:11 +02:00
parent 3d6e25c82d
commit 0989a9f76f
3 changed files with 304 additions and 1 deletions

View File

@@ -70,7 +70,8 @@ class Model:
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False,
use_temp_file: bool = False, eager: bool = False,
metadata_override: Path | None = None, model_name: str | None = None,
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False,
hparams: dict[str, Any] | None = None):
if type(self) is Model:
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
@@ -2856,6 +2857,82 @@ class OlmoModel(Model):
return [(self.map_tensor_name(name), data_torch)]
@Model.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration")
class Gemma3Model(Model):
model_arch = gguf.MODEL_ARCH.GEMMA3
has_vision: bool = False
# we need to merge the text_config into the root level of hparams
def __init__(self, *args, **kwargs):
hparams = Model.load_hparams(kwargs["dir_model"])
if "text_config" in hparams:
hparams = {**hparams, **hparams["text_config"]}
kwargs["hparams"] = hparams
super().__init__(*args, **kwargs)
if "vision_config" in hparams:
logger.info("Has vision encoder, but it will be ignored")
self.has_vision = True
def write(self):
super().write()
if self.has_vision:
logger.info("NOTE: this script only convert the language model to GGUF")
logger.info(" for the vision model, please use gemma3_convert_encoder_to_gguf.py")
def set_vocab(self):
self._set_vocab_sentencepiece()
self.gguf_writer.add_add_space_prefix(False)
def set_gguf_parameters(self):
hparams = self.hparams
block_count = hparams["num_hidden_layers"]
# some default values are not specified in the hparams
self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 131072))
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 8))
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-6))
self.gguf_writer.add_key_length(hparams.get("head_dim", 256))
self.gguf_writer.add_value_length(hparams.get("head_dim", 256))
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0)) # for global layers
# both attn_logit_softcapping and final_logit_softcapping are removed in Gemma3
assert hparams.get("attn_logit_softcapping") is None
assert hparams.get("final_logit_softcapping") is None
self.gguf_writer.add_sliding_window(hparams["sliding_window"])
self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4))
if hparams.get("rope_scaling") is not None:
assert hparams["rope_scaling"]["rope_type"] == "linear"
# important: this rope_scaling is only applied for global layers, and not used by 1B model
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
if name.startswith("language_model."):
name = name.replace("language_model.", "")
elif name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \
or name.startswith("multimodal_projector.") or name.startswith("vision_model."): # this is for old HF model, should be removed later
# ignore vision tensors
return []
# remove OOV (out-of-vocabulary) rows in token_embd
if "embed_tokens.weight" in name:
vocab = self._create_vocab_sentencepiece()
tokens = vocab[0]
data_torch = data_torch[:len(tokens)]
# ref code in Gemma3RMSNorm
# output = output * (1.0 + self.weight.float())
if name.endswith("norm.weight"):
data_torch = data_torch + 1
return [(self.map_tensor_name(name), data_torch)]
@Model.register("JinaBertModel", "JinaBertForMaskedLM")
class JinaBertV2Model(BertModel):

View File

@@ -207,6 +207,7 @@ class MODEL_ARCH(IntEnum):
MINICPM = auto()
GEMMA = auto()
GEMMA2 = auto()
GEMMA3 = auto()
STARCODER2 = auto()
MAMBA = auto()
XVERSE = auto()
@@ -338,6 +339,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.MINICPM: "minicpm",
MODEL_ARCH.GEMMA: "gemma",
MODEL_ARCH.GEMMA2: "gemma2",
MODEL_ARCH.GEMMA3: "gemma3",
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.XVERSE: "xverse",

View File

@@ -203,6 +203,7 @@ enum llm_arch {
LLM_ARCH_MINICPM,
LLM_ARCH_GEMMA,
LLM_ARCH_GEMMA2,
LLM_ARCH_GEMMA3,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
LLM_ARCH_XVERSE,
@@ -250,6 +251,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_MINICPM, "minicpm" },
{ LLM_ARCH_GEMMA, "gemma" },
{ LLM_ARCH_GEMMA2, "gemma2" },
{ LLM_ARCH_GEMMA3, "gemma3" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_XVERSE, "xverse" },
@@ -1056,6 +1058,26 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{
LLM_ARCH_GEMMA3,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{
LLM_ARCH_STARCODER2,
{
@@ -2313,6 +2335,7 @@ struct llama_hparams {
uint32_t n_layer;
uint32_t n_rot;
uint32_t n_swa = 0; // sliding window attention (SWA)
uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
uint32_t n_expert = 0;
@@ -2342,7 +2365,9 @@ struct llama_hparams {
float rope_attn_factor = 1.0f;
float rope_freq_base_train;
float rope_freq_base_train_swa;
float rope_freq_scale_train;
float rope_freq_scale_train_swa;
uint32_t n_ctx_orig_yarn;
float rope_yarn_log_mul;
@@ -4963,6 +4988,10 @@ static void llm_load_hparams(
}
hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;
// by default assume that the sliding-window layers use the same scaling type as the non-sliding-window layers
hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train;
hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;
ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false);
// non-transformer models do not have attention heads
@@ -5310,6 +5339,28 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_GEMMA3:
{
hparams.n_swa_pattern = 6;
hparams.rope_freq_base_train_swa = 10000.0f;
hparams.rope_freq_scale_train_swa = 1.0f;
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 26: model.type = e_model::MODEL_2B; break;
case 34: model.type = e_model::MODEL_4B; break;
case 48: model.type = e_model::MODEL_12B; break;
case 62: model.type = e_model::MODEL_27B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
hparams.f_attention_scale = model.type == e_model::MODEL_27B
? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
: 1.0f / std::sqrt(float(hparams.n_embd_head_k));
} break;
case LLM_ARCH_STARCODER2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -7411,6 +7462,38 @@ static bool llm_load_tensors(
layer.ffn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd});
}
} break;
case LLM_ARCH_GEMMA3:
{
model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
// output
model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab},
llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
for (int i = 0; i < n_layer; ++i) {
ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i);
auto & layer = model.layers[i];
layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
layer.attn_q_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
}
} break;
case LLM_ARCH_STARCODER2:
{
model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@@ -12664,6 +12747,142 @@ struct llm_build_context {
return gf;
}
struct ggml_cgraph * build_gemma3() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
const int64_t n_embd_head = hparams.n_embd_head_k;
ggml_tensor * cur;
ggml_tensor * inpL;
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
// important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings)
// no vision yet, so we need to scale
//if (ubatch.token) {
inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
cb(inpL, "inp_scaled", -1);
//}
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
//// TODO: is causal == true correct? might need some changes
//auto * inp_attn = build_attn_inp_kv_unified();
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
// gemma 2 requires different mask for layers using sliding window (SWA)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask(true);
struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(true);
for (int il = 0; il < n_layer; ++il) {
const bool is_swa = hparams.n_swa > 0 && hparams.n_swa_pattern > 0 && il % hparams.n_swa_pattern < (hparams.n_swa_pattern - 1);
auto KQ_mask_l = is_swa ? KQ_mask_swa : KQ_mask;
const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
// self-attention
{
// compute Q and K and RoPE them
ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
//Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
//Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
//Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
cb(Qcur, "Qcur_normed", il);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
cb(Kcur, "Kcur_normed", il);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, hparams.f_attention_scale, cb, il);
//cur = build_attn(inp_attn, gf,
// model.layers[il].wo, NULL,
// Qcur, Kcur, Vcur, nullptr, hparams.f_attention_scale, il);
}
cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, cb, il);
cb(cur, "attn_post_norm", il);
if (il == n_layer - 1) {
// skip computing output for unused tokens
ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
cb(sa_out, "sa_out", il);
cur = llm_build_norm(ctx0, sa_out, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
// feed-forward network
{
cur = llm_build_ffn(ctx0, lctx, cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il);
}
cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, cb, -1);
cb(cur, "ffn_post_norm", -1);
cur = ggml_add(ctx0, cur, sa_out);
// cvec?
//cur = build_cvec(cur, il);
//cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);
// lm_head
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
struct ggml_cgraph * build_starcoder2() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
@@ -14996,6 +15215,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_gemma2();
} break;
case LLM_ARCH_GEMMA3:
{
result = llm.build_gemma3();
} break;
case LLM_ARCH_STARCODER2:
{
result = llm.build_starcoder2();
@@ -18826,6 +19049,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_PHI3:
case LLM_ARCH_GEMMA:
case LLM_ARCH_GEMMA2:
case LLM_ARCH_GEMMA3:
case LLM_ARCH_STARCODER2:
case LLM_ARCH_OPENELM:
case LLM_ARCH_GPTNEOX: