Add more quant options

This commit is contained in:
turboderp
2023-09-16 15:04:12 +02:00
parent e25cfdac80
commit 2f72437fcb
3 changed files with 33 additions and 3 deletions

View File

@@ -93,8 +93,10 @@ qparams_options = \
[
QParams(32, [3, 2], [0.05, 0.95], 4),
QParams(32, [3, 2], [0.25, 0.75], 4),
QParams(32, [4, 2], [0.25, 0.75], 4),
QParams(32, [4, 3, 2], [0.1, 0.4, 0.5], 4),
QParams(32, [4, 3], [0.1, 0.9], 4),
QParams(32, [6, 3], [0.2, 0.8], 4),
QParams(128, [3], [1.0], 4),
QParams(32, [3], [1.0], 4),
QParams(32, [4, 3], [0.05, 0.95], 4),
@@ -102,10 +104,15 @@ qparams_options = \
QParams(64, [4, 3], [0.6, 0.4], 4),
QParams(128, [4], [1.0], 4),
QParams(32, [4], [1.0], 4),
QParams(32, [5, 4], [0.1, 0.9], 4),
QParams(32, [6, 4], [0.1, 0.9], 4),
QParams(128, [5], [1.0], 4),
QParams(32, [6, 5], [0.1, 0.9], 4),
QParams(32, [8, 6, 5], [0.05, 0.05, 0.9], 4),
QParams(32, [6, 5], [0.4, 0.6], 4),
QParams(32, [8, 6, 5], [0.1, 0.3, 0.6], 4),
QParams(128, [6], [1.0], 4),
QParams(32, [6], [1.0], 4),
QParams(128, [8, 6], [0.1, 0.9], 4),
QParams(32, [8], [1.0], 4),
]

View File

@@ -6,6 +6,7 @@ from conversion.tokenize import tokenize
from conversion.quantize import embeddings, measure_quant, quant
from conversion.optimize import optimize
from conversion.compile import compile_model
from conversion.qparams import qparams_headoptions
# import tracemalloc
# tracemalloc.start()
@@ -26,6 +27,30 @@ parser.add_argument("-ss", "--shard_size", type = str, help = "Max shard size in
args = parser.parse_args()
# Check some args
#
# Missing or unsupported arguments are fatal: the problem is printed and the
# script exits with a non-zero status so that calling scripts can detect the
# failure. Settings that are merely risky only print a warning and let the
# run continue.

if not args.in_dir:
    print(" ## Please specify input model directory (-i, --in_dir)")
    sys.exit(1)

if not args.out_dir:
    print(" ## Please specify output/working directory (-o, --out_dir)")
    sys.exit(1)

if not args.cal_dataset:
    print(" ## Please specify dataset Parquet file (-c, --cal_dataset)")
    sys.exit(1)

# Long calibration rows are allowed, but the user is warned about VRAM use.
if args.length > 2048 or args.measurement_length > 2048:
    print(" !! Warning: calibration rows > 2048 tokens may result in excessive VRAM use")

# The head layer bitrate must be one of the entries in qparams_headoptions.
if args.head_bits not in qparams_headoptions:
    print(f" ## Error: {args.head_bits} is not a supported option for head layer bitrate")
    sys.exit(1)

# An out-of-range target bitrate is not rejected, only flagged.
if args.bits < 2 or args.bits > 8:
    print(f" !! Warning: target bitrate {args.bits} will likely not be attainable")
# Arguments
in_dir = None if args.in_dir == "" else os.path.abspath(args.in_dir)

View File

@@ -117,7 +117,6 @@ if args.eval_dataset:
logprob_count = 0
for i in range(eval_tokens.shape[0]):
#for i in range(126, 127):
if i % 10 == 0: print(".", end = "")
sys.stdout.flush()
@@ -126,8 +125,7 @@ if args.eval_dataset:
input_ids = input_ids[:, :-1]
logits = model.forward(input_ids)
# print (tokenizer.decode(input_ids))
logits = logits.float() + 1e-10
target_ids = input_ids[:, 1:].to(logits.device)