Add parallel decoder block

This commit is contained in:
turboderp
2024-03-19 18:12:57 +01:00
parent 21772adaf9
commit 9c47269913
8 changed files with 275 additions and 20 deletions

View File

@@ -17,6 +17,7 @@ from exllamav2.generator import (
from exllamav2.attn import ExLlamaV2Attention
from exllamav2.mlp import ExLlamaV2MLP
from exllamav2.moe_mlp import ExLlamaV2MoEMLP
from exllamav2.parallel_decoder import ExLlamaV2ParallelDecoder
import argparse, os, math, time
import torch
@@ -123,6 +124,7 @@ if args.rank_reduce:
while True:
idx -= 1
module = model.modules[idx]
if isinstance(module, ExLlamaV2ParallelDecoder): break
if isinstance(module, ExLlamaV2MLP): break
if isinstance(module, ExLlamaV2MoEMLP): break
if idx < 0: