Make BOS preference arch-dependent

This commit is contained in:
turboderp
2024-03-19 17:59:02 +01:00
parent e79ed6011e
commit efd20eec03
2 changed files with 12 additions and 2 deletions

View File

@@ -64,6 +64,7 @@ class ExLlamaV2ArchParams:
self.norm_key_1 = ".input_layernorm"
self.norm_key_2 = ".post_attention_layernorm"
self.norm_constant_bias = 0
self.requires_bos = False
self.rope_neox_style = True
# Mixtral
@@ -94,6 +95,7 @@ class ExLlamaV2ArchParams:
self.norm_key_1 = ".input_layernorm"
self.norm_key_2 = ".post_attention_layernorm"
self.norm_constant_bias = 0
self.requires_bos = False
self.rope_neox_style = True
# Yi
@@ -122,6 +124,7 @@ class ExLlamaV2ArchParams:
self.norm_key_1 = ".ln1"
self.norm_key_2 = ".ln2"
self.norm_constant_bias = 0
self.requires_bos = False
self.rope_neox_style = True
# Orion
@@ -150,6 +153,7 @@ class ExLlamaV2ArchParams:
self.norm_key_1 = ".input_layernorm"
self.norm_key_2 = ".post_attention_layernorm"
self.norm_constant_bias = 0
self.requires_bos = False
self.rope_neox_style = True
# Qwen2 (1.5)
@@ -178,6 +182,7 @@ class ExLlamaV2ArchParams:
self.norm_key_1 = ".input_layernorm"
self.norm_key_2 = ".post_attention_layernorm"
self.norm_constant_bias = 0
self.requires_bos = False
self.rope_neox_style = True
# Gemma
@@ -206,6 +211,7 @@ class ExLlamaV2ArchParams:
self.norm_key_1 = ".input_layernorm"
self.norm_key_2 = ".post_attention_layernorm"
self.norm_constant_bias = 1
self.requires_bos = True
self.rope_neox_style = True
# StarCoder2
@@ -233,6 +239,7 @@ class ExLlamaV2ArchParams:
self.norm_key_1 = ".input_layernorm"
self.norm_key_2 = ".post_attention_layernorm"
self.norm_constant_bias = 0
self.requires_bos = False
self.rope_neox_style = True
# GemMoE
@@ -264,6 +271,7 @@ class ExLlamaV2ArchParams:
self.norm_key_1 = ".input_layernorm"
self.norm_key_2 = ".post_attention_layernorm"
self.norm_constant_bias = 1
self.requires_bos = True
self.rope_neox_style = True
# Llama (default + fallback)
@@ -295,6 +303,7 @@ class ExLlamaV2ArchParams:
self.norm_key_1 = ".input_layernorm"
self.norm_key_2 = ".post_attention_layernorm"
self.norm_constant_bias = 0
self.requires_bos = False
self.rope_neox_style = True
# Arch overrides

View File

@@ -43,7 +43,7 @@ parser.add_argument("-el", "--eval_length", type = int, default = 2048, help = "
parser.add_argument("-et", "--eval_token", action = "store_true", help = "Evaluate perplexity on token-by-token inference using cache")
parser.add_argument("-e8", "--eval_token_8bit", action = "store_true", help = "Evaluate perplexity on token-by-token inference using 8-bit (FP8) cache")
parser.add_argument("-eq4", "--eval_token_q4", action = "store_true", help = "Evaluate perplexity on token-by-token inference using Q4 cache")
parser.add_argument("-eb", "--eval_bos", action = "store_true", help = "Add BOS token to every row in perplexity test (required by Gemma and maybe other models.)")
# parser.add_argument("-eb", "--eval_bos", action = "store_true", help = "Add BOS token to every row in perplexity test (required by Gemma and maybe other models.)")
parser.add_argument("-p", "--prompt", type = str, help = "Generate from prompt (basic sampling settings)")
parser.add_argument("-pnb", "--prompt_no_bos", action = "store_true", help = "Don't add BOS token to prompt")
parser.add_argument("-t", "--tokens", type = int, default = 128, help = "Max no. tokens")
@@ -260,7 +260,8 @@ if args.eval_dataset or args.standard_perplexity:
eval_tokens = get_tokens(eval_rows, eval_length, eval_dataset, tokenizer)
eval_len = [eval_tokens.shape[1]] * eval_tokens.shape[0]
if args.eval_bos:
# if args.eval_bos:
if model.config.arch.requires_bos:
boss = torch.full((eval_tokens.shape[0], 1), tokenizer.bos_token_id, dtype = torch.long)
eval_tokens = torch.cat((boss, eval_tokens[:, :-1]), dim = 1)