TP mode for attn layer, non-paged

This commit is contained in:
turboderp
2024-08-14 23:41:10 +02:00
parent 65b9e17c4f
commit b30f796690
6 changed files with 193 additions and 27 deletions

View File

@@ -10,6 +10,7 @@ from exllamav2 import(
ExLlamaV2Cache_Q4,
ExLlamaV2Cache_Q6,
ExLlamaV2Cache_Q8,
ExLlamaV2Cache_TP,
ExLlamaV2Tokenizer,
model_init,
)
@@ -142,15 +143,20 @@ if args.draft_model_dir:
# Create cache
if args.cache_8bit:
cache = ExLlamaV2Cache_8bit(model, lazy = not model.loaded)
cache_type = ExLlamaV2Cache_8bit
elif args.cache_q4:
cache = ExLlamaV2Cache_Q4(model, lazy = not model.loaded)
cache_type = ExLlamaV2Cache_Q4
elif args.cache_q6:
cache = ExLlamaV2Cache_Q6(model, lazy=not model.loaded)
cache_type = ExLlamaV2Cache_Q6
elif args.cache_q8:
cache = ExLlamaV2Cache_Q8(model, lazy = not model.loaded)
cache_type = ExLlamaV2Cache_Q8
else:
cache = ExLlamaV2Cache(model, lazy = not model.loaded)
cache_type = ExLlamaV2Cache
if model.tp_context:
cache = ExLlamaV2Cache_TP(model, base = cache_type)
else:
cache = cache_type(model, lazy = not model.loaded)
# Load model now if auto split enabled

View File

@@ -970,6 +970,17 @@ class ExLlamaV2Attention(ExLlamaV2Module):
**kwargs
)
if self.is_tp:
return self.forward_tp(
hidden_states,
cache,
attn_params,
past_len,
intermediates,
loras,
**kwargs,
)
if self.q_handle is None or intermediates:
return self.forward_torch(
hidden_states,
@@ -1091,6 +1102,112 @@ class ExLlamaV2Attention(ExLlamaV2Module):
return hidden_states
def forward_tp(
    self,
    hidden_states: torch.Tensor,
    cache: ExLlamaV2CacheBase | None = None,
    attn_params: ExLlamaV2Attention.Params | None = None,
    past_len: int | None = None,
    intermediates: bool = False,
    loras: list[ExLlamaV2Lora] | None = None,
    ** kwargs
):
    """
    Tensor-parallel forward pass for the attention layer (non-paged caches only).

    Each device in the TP split owns a contiguous slice of KV heads; Q/K/V are
    projected per-device, RoPE and flash-attn run on each device's own stream,
    and the partial attention outputs are allgathered before the output
    projection.

    :param hidden_states: input activations, shape (batch, seq_len, hidden_dim)
    :param cache: KV cache, or None for cache-less attention
    :param attn_params: attention parameters; per-device copies are prepared
        via attn_params.prep_tp()
    :param past_len: ignored here — recomputed below from the cache
        (0 when cache is None)
    :param intermediates: unused in the TP path
    :param loras: optional LoRA adapters forwarded to the projections
    :return: output activations, shape (batch, seq_len, hidden_dim)
    """
    cfg = self.model.config
    # Per-device (device, head_start, head_end) split over KV heads
    split = self.model.tp_context.get_split(BROADCAST_KV)
    batch_size, q_len, _ = hidden_states.shape
    # Ensure per-device copies of position offsets / past lengths exist
    attn_params.prep_tp(self.model)
    # Flatten (batch, seq) into rows for the broadcast/projection path
    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
    # NOTE: the past_len argument is deliberately overwritten here
    past_len = 0 if cache is None else cache.current_seq_len
    assert self.q_handle is not None
    # TP attention is only implemented on top of flash-attn
    use_flash_attn = has_flash_attn and not cfg.no_flash_attn
    assert use_flash_attn, "Tensor parallel inference requires flash-attn"
    # Replicate the input across all devices in the KV split
    hidden_states = self.model.tp_context.broadcast(0, hidden_states, BROADCAST_KV, dim = cfg.head_dim)
    residual = hidden_states
    post_norm = self.pre_layernorm.forward(hidden_states) if self.has_norm else hidden_states
    # Per-device projections; each returns a list of one tensor per device
    q = self.q_proj.forward_tp(post_norm, loras = loras, output_split = True, dim = cfg.head_dim)
    k = self.k_proj.forward_tp(post_norm, loras = loras, output_split = True, dim = cfg.head_dim)
    v = self.v_proj.forward_tp(post_norm, loras = loras, output_split = True, dim = cfg.head_dim)
    # Reshape flat projections to (batch, seq, heads_on_device, head_dim)
    q = [q_.view(batch_size, q_len, q_.shape[1] // cfg.head_dim, cfg.head_dim) for q_ in q]
    k = [k_.view(batch_size, q_len, k_.shape[1] // cfg.head_dim, cfg.head_dim) for k_ in k]
    v = [v_.view(batch_size, q_len, v_.shape[1] // cfg.head_dim, cfg.head_dim) for v_ in v]
    if cache:
        # Per-device K/V cache tensors (lists indexed like `split`)
        k_cache, v_cache = cache.get_kv_state(self.layer_idx, batch_size, 0, past_len)
    else:
        k_cache, v_cache = None, None
    if cfg.arch.rope_style != RopeStyle.NONE:
        # Apply RoPE in-place on each device, on that device's stream
        for idx, (dev, a, b) in enumerate(split):
            constants = self.model.get_device_context(dev, scratch = True)
            context = self.model.get_device_context(dev)
            torch.cuda.set_stream(context.stream)
            # Q has num_key_value_groups query heads per KV head on this slice
            for t, heads in [(q[idx], cfg.num_key_value_groups), (k[idx], 1)]:
                ext_c.rope_(
                    t,
                    constants.sin,
                    constants.cos,
                    past_len,
                    (b - a) * heads,
                    cfg.head_dim,
                    attn_params.position_offsets_tp[idx] if attn_params.position_offsets is not None else none_tensor,
                    cfg.arch.rope_style == RopeStyle.NEOX
                )
    attn_outputs = []
    for idx in range(len(split)):
        dev, a, b = split[idx]
        context = self.model.get_device_context(dev)
        torch.cuda.set_stream(context.stream)
        if k_cache is not None:
            # Cached path: flash-attn appends new K/V into the cache tensors
            attn_output = flash_attn_with_kvcache(
                q = q[idx],
                k = k[idx],
                v = v[idx],
                k_cache = k_cache[idx],
                v_cache = v_cache[idx],
                causal = True,
                softmax_scale = self.scaling,
                cache_seqlens = attn_params.past_len_tp[idx]
            )
        else:
            # Cache-less path
            attn_output = flash_attn_func(
                q[idx],
                k[idx],
                v[idx],
                causal = True,
                softmax_scale=self.scaling,
            )
        # Flatten back to (rows, features_on_device) for the allgather
        attn_output = attn_output.view(batch_size * q_len, (b - a) * cfg.head_dim * cfg.num_key_value_groups)
        attn_outputs.append(attn_output)
    if cache is not None:
        # Commit updated K/V state (advances cache bookkeeping)
        cache.store_kv_state(self.layer_idx, batch_size, past_len, q_len)
    # Output projection
    attn_outputs = self.model.tp_context.allgather(1, attn_outputs, BROADCAST_Q, BROADCAST_Q, dim = cfg.head_dim)
    hidden_states = self.o_proj.forward_tp(attn_outputs, loras = loras, dim = cfg.head_dim, output_split = True)
    if self.has_residual:
        # In-place add of the pre-norm input on each device
        self.model.tp_context.add_residual(hidden_states, residual, BROADCAST_Q, dim = cfg.head_dim)
    # Collect the per-device splits back into one tensor
    hidden_states = self.model.tp_context.gather(0, hidden_states, BROADCAST_Q, dim = cfg.head_dim)
    # if self.post_layernorm:  # TODO: post-layernorm not yet supported in TP mode
    #     hidden_states = self.post_layernorm.forward(hidden_states)
    hidden_states = hidden_states.view(batch_size, q_len, hidden_states.shape[-1])
    return hidden_states
def forward_torch(self,
hidden_states: torch.Tensor,
cache: ExLlamaV2CacheBase | None = None,

View File

@@ -7,12 +7,14 @@ class Params:
batch_size: int
seq_len: int
past_len: int | None
past_len_tp: list[torch.Tensor | None] | None
past_lens: list[int] | None
input_mask: torch.Tensor | None
multi_cache: bool
attn_mask: torch.Tensor | None
attn_masks: torch.Tensor | None
position_offsets: torch.Tensor | None
position_offsets_tp: list[torch.Tensor | None] | None
past_lens_tensor: torch.Tensor | None
paged: bool
@@ -46,7 +48,9 @@ class Params:
self.attn_masks = None
self.position_offsets = position_offsets
self.position_offsets_tp = None
self.past_lens_tensor = None
self.past_len_tp = None
self.paged = paged
def is_causal(self) -> bool:
@@ -106,6 +110,25 @@ class Params:
attn_masks.append(self.build_single_attn_mask(1, self.seq_len, past_len, device, self.input_mask[i]))
return attn_masks
def prep_tp(self, model):
    """
    Prepare per-device copies of attention parameters for tensor-parallel
    inference.

    For every device in the model's KV-head split, builds device-local copies
    of the position offsets and the per-batch past-length tensor, appended to
    self.position_offsets_tp / self.past_len_tp in split order (None entries
    where the source value is None). Idempotent: returns immediately if the
    lists were already built.

    :param model: model whose tp_context defines the device split
    """
    if self.position_offsets_tp is not None:
        return
    split = model.tp_context.get_split(BROADCAST_KV)
    self.position_offsets_tp = []
    self.past_len_tp = []
    # Only materialize the past-length tensor when past_len is set;
    # torch.tensor([None] * batch_size, dtype = torch.int) would raise
    # TypeError before the None branch below is ever reached.
    pl = None
    if self.past_len is not None:
        pl = torch.tensor([self.past_len] * self.batch_size, dtype = torch.int)
    for dev, a, b in split:
        context = model.get_device_context(dev)
        # Issue the (non-blocking) copies on the target device's stream
        torch.cuda.set_stream(context.stream)
        if self.position_offsets is None:
            self.position_offsets_tp.append(None)
        else:
            self.position_offsets_tp.append(safe_move_tensor(self.position_offsets, dev, non_blocking = True))
        if self.past_len is None:
            self.past_len_tp.append(None)
        else:
            self.past_len_tp.append(safe_move_tensor(pl, dev, non_blocking = True))
class PagedParams(Params):

View File

@@ -799,8 +799,8 @@ class ExLlamaV2Cache_TP(ExLlamaV2CacheBase):
offset,
width,
page_size,
cache_seqlens[idx],
block_table[idx]
cache_seqlens[idx] if cache_seqlens else None,
block_table[idx] if block_table else None
)
kc.append(k)
vc.append(v)
@@ -824,8 +824,8 @@ class ExLlamaV2Cache_TP(ExLlamaV2CacheBase):
offset,
width,
page_size,
cache_seqlens[idx],
block_table[idx]
cache_seqlens[idx] if cache_seqlens else None,
block_table[idx] if block_table else None
)

View File

@@ -10,7 +10,8 @@ from exllamav2 import(
def add_args(parser):
parser.add_argument("-m", "--model_dir", type = str, help = "Path to model directory")
parser.add_argument("-gs", "--gpu_split", type = str, help = "\"auto\", or VRAM allocation per GPU in GB")
parser.add_argument("-gs", "--gpu_split", type = str, help = "\"auto\", or VRAM allocation per GPU in GB. \"auto\" is implied by default in tensor-parallel mode.")
parser.add_argument("-tp", "--tensor_parallel", action = "store_true", help = "Load in tensor-parallel mode (not fully supported for all models)")
parser.add_argument("-l", "--length", type = int, help = "Maximum sequence length")
parser.add_argument("-rs", "--rope_scale", type = float, help = "RoPE scaling factor")
parser.add_argument("-ra", "--rope_alpha", type = float, help = "RoPE alpha value (NTK)")
@@ -33,6 +34,7 @@ def print_options(args):
print_opts = []
if args.gpu_split is not None: print_opts += [f"gpu_split: {args.gpu_split}"]
if args.tensor_parallel is not None: print_opts += ["tensor_parallel"]
if args.length is not None: print_opts += [f"length: {args.length}"]
if args.rope_scale is not None: print_opts += [f"rope_scale: {args.rope_scale}"]
if args.rope_alpha is not None: print_opts += [f"rope_alpha: {args.rope_alpha}"]
@@ -131,13 +133,18 @@ def init(args,
if args.gpu_split and args.gpu_split != "auto":
split = [float(alloc) for alloc in args.gpu_split.split(",")]
if args.gpu_split != "auto" and not skip_load:
if args.tensor_parallel:
if args.gpu_split == "auto": split = None
model.load_tp(split, progress = progress)
elif args.gpu_split != "auto" and not skip_load:
if not quiet and not progress: print(" -- Loading model...")
t = time.time()
model.load(split, progress = progress)
t = time.time() - t
if benchmark and not quiet:
print(f" -- Loaded model in {t:.4f} seconds")
else:
assert allow_auto_split, "Auto split not allowed."

View File

@@ -7,6 +7,7 @@ from exllamav2 import(
ExLlamaV2Cache_Q4,
ExLlamaV2Cache_Q6,
ExLlamaV2Cache_Q8,
ExLlamaV2Cache_TP,
ExLlamaV2Tokenizer,
model_init,
)
@@ -96,11 +97,14 @@ if args.stream_layers:
model_init.check_args(args)
model_init.print_options(args)
model, tokenizer = model_init.init(args,
allow_auto_split = True,
skip_load = args.stream_layers,
benchmark = True,
max_output_len = args.max_output_len)
model, tokenizer = model_init.init(
args,
allow_auto_split = True,
skip_load = args.stream_layers,
benchmark = True,
max_output_len = args.max_output_len,
progress = True
)
cache = None
# Auto split
@@ -113,7 +117,7 @@ if not model.loaded and not args.stream_layers:
print(" -- Loading model...")
cache = ExLlamaV2Cache(model, lazy = True)
t = time.time()
model.load_autosplit(cache)
model.load_autosplit(cache, progress = True)
t = time.time() - t
print(f" -- Loaded model in {t:.4f} seconds")
@@ -185,7 +189,7 @@ if args.prompt:
with torch.inference_mode():
if cache is None:
cache = ExLlamaV2Cache(model)
cache = ExLlamaV2Cache(model) if not model.tp_context else ExLlamaV2Cache_TP(model)
ids = tokenizer.encode(args.prompt)
tokens_prompt = ids.shape[-1]
@@ -292,7 +296,8 @@ if args.eval_dataset or args.standard_perplexity:
def ppl(input_ids__, logits__, lengths__, bins = False):
logits_device = model.modules[-1].device()
logits_device = model.modules[-1].device() if not model.tp_context else \
torch.device(model.tp_context.device)
if bins:
num_bins = (max(lengths__) + 255) // 256
@@ -389,7 +394,10 @@ if args.eval_dataset or args.standard_perplexity:
sys.stdout.flush()
if cache is None:
cache = ExLlamaV2Cache(model, max_seq_len = eval_length) if eval_length > model.config.max_input_len else None
if eval_length > model.config.max_input_len:
cache = ExLlamaV2Cache(model, max_seq_len = eval_length) if not model.tp_context else ExLlamaV2Cache_TP(model, max_seq_len = eval_length)
else:
cache = None
for i in range(eval_tokens.shape[0]):
@@ -470,7 +478,8 @@ if args.eval_dataset or args.standard_perplexity:
else:
print(f" -- Inference (token)", end = "")
sys.stdout.flush()
cache = ExLlamaV2Cache(model, max_seq_len = eval_length)
cache = ExLlamaV2Cache(model, max_seq_len = eval_length) if not model.tp_context else \
ExLlamaV2Cache_TP(model, max_seq_len = eval_length)
test_ppl_token()
if args.eval_token_8bit:
@@ -479,7 +488,8 @@ if args.eval_dataset or args.standard_perplexity:
else:
print(f" -- Inference (token, 8-bit cache)", end = "")
sys.stdout.flush()
cache = ExLlamaV2Cache_8bit(model, max_seq_len = eval_length)
cache = ExLlamaV2Cache_8bit(model, max_seq_len = eval_length) if not model.tp_context else \
ExLlamaV2Cache_TP(model, max_seq_len = eval_length, base = ExLlamaV2Cache_8bit)
test_ppl_token()
if args.eval_token_q4:
@@ -488,7 +498,8 @@ if args.eval_dataset or args.standard_perplexity:
else:
print(f" -- Inference (token, Q4 cache)", end = "")
sys.stdout.flush()
cache = ExLlamaV2Cache_Q4(model, max_seq_len = eval_length)
cache = ExLlamaV2Cache_Q4(model, max_seq_len = eval_length) if not model.tp_context else \
ExLlamaV2Cache_TP(model, max_seq_len = eval_length, base = ExLlamaV2Cache_Q4)
# cache.calibrate(tokenizer)
test_ppl_token()
@@ -498,7 +509,8 @@ if args.eval_dataset or args.standard_perplexity:
else:
print(f" -- Inference (token, Q6 cache)", end = "")
sys.stdout.flush()
cache = ExLlamaV2Cache_Q6(model, max_seq_len = eval_length)
cache = ExLlamaV2Cache_Q6(model, max_seq_len = eval_length) if not model.tp_context else \
ExLlamaV2Cache_TP(model, max_seq_len = eval_length, base = ExLlamaV2Cache_Q6)
# cache.calibrate(tokenizer)
test_ppl_token()
@@ -508,7 +520,8 @@ if args.eval_dataset or args.standard_perplexity:
else:
print(f" -- Inference (token, Q8 cache)", end = "")
sys.stdout.flush()
cache = ExLlamaV2Cache_Q8(model, max_seq_len = eval_length)
cache = ExLlamaV2Cache_Q8(model, max_seq_len = eval_length) if not model.tp_context else \
ExLlamaV2Cache_TP(model, max_seq_len = eval_length, base = ExLlamaV2Cache_Q8)
# cache.calibrate(tokenizer)
test_ppl_token()
@@ -520,7 +533,7 @@ if args.prompt_speed:
with torch.inference_mode():
if cache is None:
cache = ExLlamaV2Cache(model)
cache = ExLlamaV2Cache(model) if not model.tp_context else ExLlamaV2Cache_TP(model)
ids = torch.randint(0, model.config.vocab_size - 1, (1, model.config.max_seq_len))
@@ -571,7 +584,7 @@ if args.speed:
with torch.inference_mode():
if cache is None:
cache = ExLlamaV2Cache(model)
cache = ExLlamaV2Cache(model) if not model.tp_context else ExLlamaV2Cache_TP(model)
cache.current_seq_len = 0
print(f" -- Measuring token speed...")