Refactor to consolidate attn params

turboderp
2024-01-04 04:52:49 +01:00
parent f2e7648d98
commit 41b15dd1c3
11 changed files with 184 additions and 120 deletions


@@ -13,6 +13,10 @@ from exllamav2.generator import (
     ExLlamaV2Sampler
 )
 
+from exllamav2.attn import (
+    ExLlamaV2Attention
+)
+
 import argparse, os, math, time
 import pandas, fastparquet
 import torch
@@ -262,7 +266,8 @@ if args.eval_dataset or args.standard_perplexity:
     sys.stdout.flush()
 
     batch_size, seq_len = eval_tokens.shape
-    attn_mask = model.build_attn_mask(stream_batch_size, seq_len, 0, None, "cuda:0")
+    attn_params = ExLlamaV2Attention.Params(stream_batch_size, seq_len, 0, None, None)
+    # attn_mask = model.build_attn_mask(stream_batch_size, seq_len, 0, None, "cuda:0")
 
     for idx, module in enumerate(model.modules):
         module.set_device_idx(-1 if idx == 0 else 0)
@@ -283,7 +288,7 @@ if args.eval_dataset or args.standard_perplexity:
             a = b
             b = min(b + stream_batch_size, eval_tokens.shape[0])
             x = hidden_state[a:b, :, :].to("cuda:0")
-            x = module.forward(x, cache = None, attn_mask = attn_mask, past_len = 0, loras = None, position_offsets = None)
+            x = module.forward(x, cache = None, attn_params = attn_params, past_len = 0, loras = None)
             if idx < len(model.modules) - 1:
                 hidden_state[a:b, :, :] = x.to("cpu")
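For reference, a minimal before/after sketch of the call pattern this commit consolidates. The positional order of ExLlamaV2Attention.Params is inferred from the diff above as (batch_size, seq_len, past_len, ...); the meaning of the two trailing None arguments (presumably the input mask and position offsets that were previously passed separately) is an assumption, not something this diff confirms.

from exllamav2.attn import ExLlamaV2Attention

# Before: the caller built the attention mask eagerly on a fixed device
# and passed each attention setting to forward() as a separate argument.
attn_mask = model.build_attn_mask(stream_batch_size, seq_len, 0, None, "cuda:0")
x = module.forward(x, cache = None, attn_mask = attn_mask, past_len = 0,
                   loras = None, position_offsets = None)

# After: a single Params object carries the same settings, so the forward()
# signature stays stable as attention options grow, and the mask no longer
# has to be materialized by the caller. The two trailing None arguments are
# assumed to replace the old mask and position_offsets.
attn_params = ExLlamaV2Attention.Params(stream_batch_size, seq_len, 0, None, None)
x = module.forward(x, cache = None, attn_params = attn_params, past_len = 0,
                   loras = None)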