Update TP example

This commit is contained in:
turboderp
2024-08-22 13:32:25 +02:00
parent 4117daa546
commit 555c360798

View File

@@ -10,9 +10,22 @@ config = ExLlamaV2Config(model_dir)
config.arch_compat_overrides()
config.no_graphs = True
model = ExLlamaV2(config)
model.load_tp(progress = True)
# Load the model in tensor-parallel mode. With no gpu_split specified, the model will attempt to split across
# all visible devices according to the currently available VRAM on each. expect_cache_tokens is necessary for
# balancing the split, in case the GPUs are of uneven sizes, or if the number of GPUs doesn't divide the number
# of KV heads in the model
#
# The cache type for a TP model is always ExLlamaV2Cache_TP and should be allocated after the model. To use a
# quantized cache, add a `base = ExLlamaV2Cache_Q6` etc. argument to the cache constructor. It's advisable
# to also add `expect_cache_base = ExLlamaV2Cache_Q6` to load_tp() as well so the size can be correctly
# accounted for when splitting the model.
model.load_tp(progress = True, expect_cache_tokens = 16384)
cache = ExLlamaV2Cache_TP(model, max_seq_len = 16384)
# After loading the model, all other functions should work the same
print("Loading tokenizer...")
tokenizer = ExLlamaV2Tokenizer(config)