mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-20 14:29:28 +00:00
TP mode for attn layer, non-paged
This commit is contained in:
@@ -10,6 +10,7 @@ from exllamav2 import(
|
||||
ExLlamaV2Cache_Q4,
|
||||
ExLlamaV2Cache_Q6,
|
||||
ExLlamaV2Cache_Q8,
|
||||
ExLlamaV2Cache_TP,
|
||||
ExLlamaV2Tokenizer,
|
||||
model_init,
|
||||
)
|
||||
@@ -142,15 +143,20 @@ if args.draft_model_dir:
|
||||
# Create cache
|
||||
|
||||
if args.cache_8bit:
|
||||
cache = ExLlamaV2Cache_8bit(model, lazy = not model.loaded)
|
||||
cache_type = ExLlamaV2Cache_8bit
|
||||
elif args.cache_q4:
|
||||
cache = ExLlamaV2Cache_Q4(model, lazy = not model.loaded)
|
||||
cache_type = ExLlamaV2Cache_Q4
|
||||
elif args.cache_q6:
|
||||
cache = ExLlamaV2Cache_Q6(model, lazy=not model.loaded)
|
||||
cache_type = ExLlamaV2Cache_Q6
|
||||
elif args.cache_q8:
|
||||
cache = ExLlamaV2Cache_Q8(model, lazy = not model.loaded)
|
||||
cache_type = ExLlamaV2Cache_Q8
|
||||
else:
|
||||
cache = ExLlamaV2Cache(model, lazy = not model.loaded)
|
||||
cache_type = ExLlamaV2Cache
|
||||
|
||||
if model.tp_context:
|
||||
cache = ExLlamaV2Cache_TP(model, base = cache_type)
|
||||
else:
|
||||
cache = cache_type(model, lazy = not model.loaded)
|
||||
|
||||
# Load model now if auto split enabled
|
||||
|
||||
|
||||
@@ -970,6 +970,17 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if self.is_tp:
|
||||
return self.forward_tp(
|
||||
hidden_states,
|
||||
cache,
|
||||
attn_params,
|
||||
past_len,
|
||||
intermediates,
|
||||
loras,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if self.q_handle is None or intermediates:
|
||||
return self.forward_torch(
|
||||
hidden_states,
|
||||
@@ -1091,6 +1102,112 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
def forward_tp(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cache: ExLlamaV2CacheBase | None = None,
|
||||
attn_params: ExLlamaV2Attention.Params | None = None,
|
||||
past_len: int | None = None,
|
||||
intermediates: bool = False,
|
||||
loras: list[ExLlamaV2Lora] | None = None,
|
||||
** kwargs
|
||||
):
|
||||
cfg = self.model.config
|
||||
split = self.model.tp_context.get_split(BROADCAST_KV)
|
||||
batch_size, q_len, _ = hidden_states.shape
|
||||
attn_params.prep_tp(self.model)
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
past_len = 0 if cache is None else cache.current_seq_len
|
||||
|
||||
assert self.q_handle is not None
|
||||
use_flash_attn = has_flash_attn and not cfg.no_flash_attn
|
||||
assert use_flash_attn, "Tensor parallel inference requires flash-attn"
|
||||
|
||||
hidden_states = self.model.tp_context.broadcast(0, hidden_states, BROADCAST_KV, dim = cfg.head_dim)
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
post_norm = self.pre_layernorm.forward(hidden_states) if self.has_norm else hidden_states
|
||||
q = self.q_proj.forward_tp(post_norm, loras = loras, output_split = True, dim = cfg.head_dim)
|
||||
k = self.k_proj.forward_tp(post_norm, loras = loras, output_split = True, dim = cfg.head_dim)
|
||||
v = self.v_proj.forward_tp(post_norm, loras = loras, output_split = True, dim = cfg.head_dim)
|
||||
|
||||
q = [q_.view(batch_size, q_len, q_.shape[1] // cfg.head_dim, cfg.head_dim) for q_ in q]
|
||||
k = [k_.view(batch_size, q_len, k_.shape[1] // cfg.head_dim, cfg.head_dim) for k_ in k]
|
||||
v = [v_.view(batch_size, q_len, v_.shape[1] // cfg.head_dim, cfg.head_dim) for v_ in v]
|
||||
|
||||
if cache:
|
||||
k_cache, v_cache = cache.get_kv_state(self.layer_idx, batch_size, 0, past_len)
|
||||
else:
|
||||
k_cache, v_cache = None, None
|
||||
|
||||
if cfg.arch.rope_style != RopeStyle.NONE:
|
||||
for idx, (dev, a, b) in enumerate(split):
|
||||
constants = self.model.get_device_context(dev, scratch = True)
|
||||
context = self.model.get_device_context(dev)
|
||||
torch.cuda.set_stream(context.stream)
|
||||
for t, heads in [(q[idx], cfg.num_key_value_groups), (k[idx], 1)]:
|
||||
ext_c.rope_(
|
||||
t,
|
||||
constants.sin,
|
||||
constants.cos,
|
||||
past_len,
|
||||
(b - a) * heads,
|
||||
cfg.head_dim,
|
||||
attn_params.position_offsets_tp[idx] if attn_params.position_offsets is not None else none_tensor,
|
||||
cfg.arch.rope_style == RopeStyle.NEOX
|
||||
)
|
||||
|
||||
attn_outputs = []
|
||||
for idx in range(len(split)):
|
||||
dev, a, b = split[idx]
|
||||
context = self.model.get_device_context(dev)
|
||||
torch.cuda.set_stream(context.stream)
|
||||
|
||||
if k_cache is not None:
|
||||
attn_output = flash_attn_with_kvcache(
|
||||
q = q[idx],
|
||||
k = k[idx],
|
||||
v = v[idx],
|
||||
k_cache = k_cache[idx],
|
||||
v_cache = v_cache[idx],
|
||||
causal = True,
|
||||
softmax_scale = self.scaling,
|
||||
cache_seqlens = attn_params.past_len_tp[idx]
|
||||
)
|
||||
else:
|
||||
attn_output = flash_attn_func(
|
||||
q[idx],
|
||||
k[idx],
|
||||
v[idx],
|
||||
causal = True,
|
||||
softmax_scale=self.scaling,
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(batch_size * q_len, (b - a) * cfg.head_dim * cfg.num_key_value_groups)
|
||||
attn_outputs.append(attn_output)
|
||||
|
||||
if cache is not None:
|
||||
cache.store_kv_state(self.layer_idx, batch_size, past_len, q_len)
|
||||
|
||||
# Output projection
|
||||
|
||||
attn_outputs = self.model.tp_context.allgather(1, attn_outputs, BROADCAST_Q, BROADCAST_Q, dim = cfg.head_dim)
|
||||
|
||||
hidden_states = self.o_proj.forward_tp(attn_outputs, loras = loras, dim = cfg.head_dim, output_split = True)
|
||||
|
||||
if self.has_residual:
|
||||
self.model.tp_context.add_residual(hidden_states, residual, BROADCAST_Q, dim = cfg.head_dim)
|
||||
|
||||
hidden_states = self.model.tp_context.gather(0, hidden_states, BROADCAST_Q, dim = cfg.head_dim)
|
||||
|
||||
# if self.post_layernorm: # TODO: ...
|
||||
# hidden_states = self.post_layernorm.forward(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.view(batch_size, q_len, hidden_states.shape[-1])
|
||||
return hidden_states
|
||||
|
||||
|
||||
def forward_torch(self,
|
||||
hidden_states: torch.Tensor,
|
||||
cache: ExLlamaV2CacheBase | None = None,
|
||||
|
||||
@@ -7,12 +7,14 @@ class Params:
|
||||
batch_size: int
|
||||
seq_len: int
|
||||
past_len: int | None
|
||||
past_len_tp: list[torch.Tensor | None] | None
|
||||
past_lens: list[int] | None
|
||||
input_mask: torch.Tensor | None
|
||||
multi_cache: bool
|
||||
attn_mask: torch.Tensor | None
|
||||
attn_masks: torch.Tensor | None
|
||||
position_offsets: torch.Tensor | None
|
||||
position_offsets_tp: list[torch.Tensor | None] | None
|
||||
past_lens_tensor: torch.Tensor | None
|
||||
paged: bool
|
||||
|
||||
@@ -46,7 +48,9 @@ class Params:
|
||||
self.attn_masks = None
|
||||
|
||||
self.position_offsets = position_offsets
|
||||
self.position_offsets_tp = None
|
||||
self.past_lens_tensor = None
|
||||
self.past_len_tp = None
|
||||
self.paged = paged
|
||||
|
||||
def is_causal(self) -> bool:
|
||||
@@ -106,6 +110,25 @@ class Params:
|
||||
attn_masks.append(self.build_single_attn_mask(1, self.seq_len, past_len, device, self.input_mask[i]))
|
||||
return attn_masks
|
||||
|
||||
def prep_tp(self, model):
|
||||
if self.position_offsets_tp is not None:
|
||||
return
|
||||
split = model.tp_context.get_split(BROADCAST_KV)
|
||||
self.position_offsets_tp = []
|
||||
self.past_len_tp = []
|
||||
pl = torch.tensor([self.past_len] * self.batch_size, dtype = torch.int)
|
||||
for dev, a, b in split:
|
||||
context = model.get_device_context(dev)
|
||||
torch.cuda.set_stream(context.stream)
|
||||
if self.position_offsets is None:
|
||||
self.position_offsets_tp.append(None)
|
||||
else:
|
||||
self.position_offsets_tp.append(safe_move_tensor(self.position_offsets, dev, non_blocking = True))
|
||||
if self.past_len is None:
|
||||
self.past_len_tp.append(None)
|
||||
else:
|
||||
self.past_len_tp.append(safe_move_tensor(pl, dev, non_blocking = True))
|
||||
|
||||
|
||||
class PagedParams(Params):
|
||||
|
||||
|
||||
@@ -799,8 +799,8 @@ class ExLlamaV2Cache_TP(ExLlamaV2CacheBase):
|
||||
offset,
|
||||
width,
|
||||
page_size,
|
||||
cache_seqlens[idx],
|
||||
block_table[idx]
|
||||
cache_seqlens[idx] if cache_seqlens else None,
|
||||
block_table[idx] if block_table else None
|
||||
)
|
||||
kc.append(k)
|
||||
vc.append(v)
|
||||
@@ -824,8 +824,8 @@ class ExLlamaV2Cache_TP(ExLlamaV2CacheBase):
|
||||
offset,
|
||||
width,
|
||||
page_size,
|
||||
cache_seqlens[idx],
|
||||
block_table[idx]
|
||||
cache_seqlens[idx] if cache_seqlens else None,
|
||||
block_table[idx] if block_table else None
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,8 @@ from exllamav2 import(
|
||||
def add_args(parser):
|
||||
|
||||
parser.add_argument("-m", "--model_dir", type = str, help = "Path to model directory")
|
||||
parser.add_argument("-gs", "--gpu_split", type = str, help = "\"auto\", or VRAM allocation per GPU in GB")
|
||||
parser.add_argument("-gs", "--gpu_split", type = str, help = "\"auto\", or VRAM allocation per GPU in GB. \"auto\" is implied by default in tensor-parallel mode.")
|
||||
parser.add_argument("-tp", "--tensor_parallel", action = "store_true", help = "Load in tensor-parallel mode (not fully supported for all models)")
|
||||
parser.add_argument("-l", "--length", type = int, help = "Maximum sequence length")
|
||||
parser.add_argument("-rs", "--rope_scale", type = float, help = "RoPE scaling factor")
|
||||
parser.add_argument("-ra", "--rope_alpha", type = float, help = "RoPE alpha value (NTK)")
|
||||
@@ -33,6 +34,7 @@ def print_options(args):
|
||||
|
||||
print_opts = []
|
||||
if args.gpu_split is not None: print_opts += [f"gpu_split: {args.gpu_split}"]
|
||||
if args.tensor_parallel is not None: print_opts += ["tensor_parallel"]
|
||||
if args.length is not None: print_opts += [f"length: {args.length}"]
|
||||
if args.rope_scale is not None: print_opts += [f"rope_scale: {args.rope_scale}"]
|
||||
if args.rope_alpha is not None: print_opts += [f"rope_alpha: {args.rope_alpha}"]
|
||||
@@ -131,13 +133,18 @@ def init(args,
|
||||
if args.gpu_split and args.gpu_split != "auto":
|
||||
split = [float(alloc) for alloc in args.gpu_split.split(",")]
|
||||
|
||||
if args.gpu_split != "auto" and not skip_load:
|
||||
if args.tensor_parallel:
|
||||
if args.gpu_split == "auto": split = None
|
||||
model.load_tp(split, progress = progress)
|
||||
|
||||
elif args.gpu_split != "auto" and not skip_load:
|
||||
if not quiet and not progress: print(" -- Loading model...")
|
||||
t = time.time()
|
||||
model.load(split, progress = progress)
|
||||
t = time.time() - t
|
||||
if benchmark and not quiet:
|
||||
print(f" -- Loaded model in {t:.4f} seconds")
|
||||
|
||||
else:
|
||||
assert allow_auto_split, "Auto split not allowed."
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from exllamav2 import(
|
||||
ExLlamaV2Cache_Q4,
|
||||
ExLlamaV2Cache_Q6,
|
||||
ExLlamaV2Cache_Q8,
|
||||
ExLlamaV2Cache_TP,
|
||||
ExLlamaV2Tokenizer,
|
||||
model_init,
|
||||
)
|
||||
@@ -96,11 +97,14 @@ if args.stream_layers:
|
||||
|
||||
model_init.check_args(args)
|
||||
model_init.print_options(args)
|
||||
model, tokenizer = model_init.init(args,
|
||||
allow_auto_split = True,
|
||||
skip_load = args.stream_layers,
|
||||
benchmark = True,
|
||||
max_output_len = args.max_output_len)
|
||||
model, tokenizer = model_init.init(
|
||||
args,
|
||||
allow_auto_split = True,
|
||||
skip_load = args.stream_layers,
|
||||
benchmark = True,
|
||||
max_output_len = args.max_output_len,
|
||||
progress = True
|
||||
)
|
||||
cache = None
|
||||
|
||||
# Auto split
|
||||
@@ -113,7 +117,7 @@ if not model.loaded and not args.stream_layers:
|
||||
print(" -- Loading model...")
|
||||
cache = ExLlamaV2Cache(model, lazy = True)
|
||||
t = time.time()
|
||||
model.load_autosplit(cache)
|
||||
model.load_autosplit(cache, progress = True)
|
||||
t = time.time() - t
|
||||
print(f" -- Loaded model in {t:.4f} seconds")
|
||||
|
||||
@@ -185,7 +189,7 @@ if args.prompt:
|
||||
with torch.inference_mode():
|
||||
|
||||
if cache is None:
|
||||
cache = ExLlamaV2Cache(model)
|
||||
cache = ExLlamaV2Cache(model) if not model.tp_context else ExLlamaV2Cache_TP(model)
|
||||
|
||||
ids = tokenizer.encode(args.prompt)
|
||||
tokens_prompt = ids.shape[-1]
|
||||
@@ -292,7 +296,8 @@ if args.eval_dataset or args.standard_perplexity:
|
||||
|
||||
def ppl(input_ids__, logits__, lengths__, bins = False):
|
||||
|
||||
logits_device = model.modules[-1].device()
|
||||
logits_device = model.modules[-1].device() if not model.tp_context else \
|
||||
torch.device(model.tp_context.device)
|
||||
|
||||
if bins:
|
||||
num_bins = (max(lengths__) + 255) // 256
|
||||
@@ -389,7 +394,10 @@ if args.eval_dataset or args.standard_perplexity:
|
||||
sys.stdout.flush()
|
||||
|
||||
if cache is None:
|
||||
cache = ExLlamaV2Cache(model, max_seq_len = eval_length) if eval_length > model.config.max_input_len else None
|
||||
if eval_length > model.config.max_input_len:
|
||||
cache = ExLlamaV2Cache(model, max_seq_len = eval_length) if not model.tp_context else ExLlamaV2Cache_TP(model, max_seq_len = eval_length)
|
||||
else:
|
||||
cache = None
|
||||
|
||||
for i in range(eval_tokens.shape[0]):
|
||||
|
||||
@@ -470,7 +478,8 @@ if args.eval_dataset or args.standard_perplexity:
|
||||
else:
|
||||
print(f" -- Inference (token)", end = "")
|
||||
sys.stdout.flush()
|
||||
cache = ExLlamaV2Cache(model, max_seq_len = eval_length)
|
||||
cache = ExLlamaV2Cache(model, max_seq_len = eval_length) if not model.tp_context else \
|
||||
ExLlamaV2Cache_TP(model, max_seq_len = eval_length)
|
||||
test_ppl_token()
|
||||
|
||||
if args.eval_token_8bit:
|
||||
@@ -479,7 +488,8 @@ if args.eval_dataset or args.standard_perplexity:
|
||||
else:
|
||||
print(f" -- Inference (token, 8-bit cache)", end = "")
|
||||
sys.stdout.flush()
|
||||
cache = ExLlamaV2Cache_8bit(model, max_seq_len = eval_length)
|
||||
cache = ExLlamaV2Cache_8bit(model, max_seq_len = eval_length) if not model.tp_context else \
|
||||
ExLlamaV2Cache_TP(model, max_seq_len = eval_length, base = ExLlamaV2Cache_8bit)
|
||||
test_ppl_token()
|
||||
|
||||
if args.eval_token_q4:
|
||||
@@ -488,7 +498,8 @@ if args.eval_dataset or args.standard_perplexity:
|
||||
else:
|
||||
print(f" -- Inference (token, Q4 cache)", end = "")
|
||||
sys.stdout.flush()
|
||||
cache = ExLlamaV2Cache_Q4(model, max_seq_len = eval_length)
|
||||
cache = ExLlamaV2Cache_Q4(model, max_seq_len = eval_length) if not model.tp_context else \
|
||||
ExLlamaV2Cache_TP(model, max_seq_len = eval_length, base = ExLlamaV2Cache_Q4)
|
||||
# cache.calibrate(tokenizer)
|
||||
test_ppl_token()
|
||||
|
||||
@@ -498,7 +509,8 @@ if args.eval_dataset or args.standard_perplexity:
|
||||
else:
|
||||
print(f" -- Inference (token, Q6 cache)", end = "")
|
||||
sys.stdout.flush()
|
||||
cache = ExLlamaV2Cache_Q6(model, max_seq_len = eval_length)
|
||||
cache = ExLlamaV2Cache_Q6(model, max_seq_len = eval_length) if not model.tp_context else \
|
||||
ExLlamaV2Cache_TP(model, max_seq_len = eval_length, base = ExLlamaV2Cache_Q6)
|
||||
# cache.calibrate(tokenizer)
|
||||
test_ppl_token()
|
||||
|
||||
@@ -508,7 +520,8 @@ if args.eval_dataset or args.standard_perplexity:
|
||||
else:
|
||||
print(f" -- Inference (token, Q8 cache)", end = "")
|
||||
sys.stdout.flush()
|
||||
cache = ExLlamaV2Cache_Q8(model, max_seq_len = eval_length)
|
||||
cache = ExLlamaV2Cache_Q8(model, max_seq_len = eval_length) if not model.tp_context else \
|
||||
ExLlamaV2Cache_TP(model, max_seq_len = eval_length, base = ExLlamaV2Cache_Q8)
|
||||
# cache.calibrate(tokenizer)
|
||||
test_ppl_token()
|
||||
|
||||
@@ -520,7 +533,7 @@ if args.prompt_speed:
|
||||
with torch.inference_mode():
|
||||
|
||||
if cache is None:
|
||||
cache = ExLlamaV2Cache(model)
|
||||
cache = ExLlamaV2Cache(model) if not model.tp_context else ExLlamaV2Cache_TP(model)
|
||||
|
||||
ids = torch.randint(0, model.config.vocab_size - 1, (1, model.config.max_seq_len))
|
||||
|
||||
@@ -571,7 +584,7 @@ if args.speed:
|
||||
with torch.inference_mode():
|
||||
|
||||
if cache is None:
|
||||
cache = ExLlamaV2Cache(model)
|
||||
cache = ExLlamaV2Cache(model) if not model.tp_context else ExLlamaV2Cache_TP(model)
|
||||
cache.current_seq_len = 0
|
||||
|
||||
print(f" -- Measuring token speed...")
|
||||
|
||||
Reference in New Issue
Block a user