mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-22 23:39:01 +00:00
Refactor to consolidate attn params
This commit is contained in:
@@ -13,6 +13,10 @@ from exllamav2.generator import (
|
||||
ExLlamaV2Sampler
|
||||
)
|
||||
|
||||
from exllamav2.attn import (
|
||||
ExLlamaV2Attention
|
||||
)
|
||||
|
||||
import argparse, os, math, time
|
||||
import pandas, fastparquet
|
||||
import torch
|
||||
@@ -262,7 +266,8 @@ if args.eval_dataset or args.standard_perplexity:
|
||||
sys.stdout.flush()
|
||||
|
||||
batch_size, seq_len = eval_tokens.shape
|
||||
attn_mask = model.build_attn_mask(stream_batch_size, seq_len, 0, None, "cuda:0")
|
||||
attn_params = ExLlamaV2Attention.Params(stream_batch_size, seq_len, 0, None, None)
|
||||
# attn_mask = model.build_attn_mask(stream_batch_size, seq_len, 0, None, "cuda:0")
|
||||
|
||||
for idx, module in enumerate(model.modules):
|
||||
module.set_device_idx(-1 if idx == 0 else 0)
|
||||
@@ -283,7 +288,7 @@ if args.eval_dataset or args.standard_perplexity:
|
||||
a = b
|
||||
b = min(b + stream_batch_size, eval_tokens.shape[0])
|
||||
x = hidden_state[a:b, :, :].to("cuda:0")
|
||||
x = module.forward(x, cache = None, attn_mask = attn_mask, past_len = 0, loras = None, position_offsets = None)
|
||||
x = module.forward(x, cache = None, attn_params = attn_params, past_len = 0, loras = None)
|
||||
|
||||
if idx < len(model.modules) - 1:
|
||||
hidden_state[a:b, :, :] = x.to("cpu")
|
||||
|
||||
Reference in New Issue
Block a user