Mirror of https://github.com/turboderp-org/exllamav3.git, synced 2026-04-20 14:29:51 +00:00.
model_diff.py: Add device argument
This commit is contained in the following branches (list truncated in this capture):
@@ -40,6 +40,8 @@ def get_test_tokens(tokenizer, rows, eval_len = 2048, eval_stride = 512):
|
||||
@torch.inference_mode()
|
||||
def main(args):
|
||||
|
||||
device = torch.device(args.device)
|
||||
|
||||
config_a = Config.from_directory(args.model_a)
|
||||
config_a.override_dynamic_seq_len(2048)
|
||||
tokenizer = Tokenizer.from_config(config_a)
|
||||
@@ -57,7 +59,7 @@ def main(args):
|
||||
for idx, (module_a, module_b) in enumerate(zip(model_a.modules, model_b.modules)):
|
||||
|
||||
config_a.stc.begin_deferred_load()
|
||||
module_a.load("cuda:0" if not module_a.caps.get("prefer_cpu") else "cpu")
|
||||
module_a.load(device if not module_a.caps.get("prefer_cpu") else "cpu")
|
||||
config_a.stc.end_deferred_load()
|
||||
params_a = {}
|
||||
state_a = module_a.prepare_for_device(state_a, params_a)
|
||||
@@ -67,7 +69,7 @@ def main(args):
|
||||
free_mem()
|
||||
|
||||
config_b.stc.begin_deferred_load()
|
||||
module_b.load("cuda:0" if not module_b.caps.get("prefer_cpu") else "cpu")
|
||||
module_b.load(device if not module_b.caps.get("prefer_cpu") else "cpu")
|
||||
config_b.stc.end_deferred_load()
|
||||
params_b = {}
|
||||
state_b = module_b.prepare_for_device(state_b, params_b)
|
||||
@@ -199,6 +201,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("-r", "--rows", type = int, help = "Number of rows", default = 100)
|
||||
parser.add_argument("-kb", "--keep_b", type = int, help = "Maintain B state for number of modules", default = 0)
|
||||
parser.add_argument("-tkm", "--topk_max", type = int, default = 5, help = "Max top-K interval to test")
|
||||
parser.add_argument("-d", "--device", type = int, help = "CUDA device index", default = 0)
|
||||
|
||||
_args = parser.parse_args()
|
||||
main(_args)
|
||||
|
||||
Reference in New Issue
Block a user