model_diff.py: Add device argument

This commit is contained in:
turboderp
2025-05-30 19:09:12 +02:00
parent dc69b16752
commit f8dc9975fe

View File

@@ -40,6 +40,8 @@ def get_test_tokens(tokenizer, rows, eval_len = 2048, eval_stride = 512):
@torch.inference_mode()
def main(args):
device = torch.device(args.device)
config_a = Config.from_directory(args.model_a)
config_a.override_dynamic_seq_len(2048)
tokenizer = Tokenizer.from_config(config_a)
@@ -57,7 +59,7 @@ def main(args):
for idx, (module_a, module_b) in enumerate(zip(model_a.modules, model_b.modules)):
config_a.stc.begin_deferred_load()
module_a.load("cuda:0" if not module_a.caps.get("prefer_cpu") else "cpu")
module_a.load(device if not module_a.caps.get("prefer_cpu") else "cpu")
config_a.stc.end_deferred_load()
params_a = {}
state_a = module_a.prepare_for_device(state_a, params_a)
@@ -67,7 +69,7 @@ def main(args):
free_mem()
config_b.stc.begin_deferred_load()
module_b.load("cuda:0" if not module_b.caps.get("prefer_cpu") else "cpu")
module_b.load(device if not module_b.caps.get("prefer_cpu") else "cpu")
config_b.stc.end_deferred_load()
params_b = {}
state_b = module_b.prepare_for_device(state_b, params_b)
@@ -199,6 +201,7 @@ if __name__ == "__main__":
parser.add_argument("-r", "--rows", type = int, help = "Number of rows", default = 100)
parser.add_argument("-kb", "--keep_b", type = int, help = "Maintain B state for number of modules", default = 0)
parser.add_argument("-tkm", "--topk_max", type = int, default = 5, help = "Max top-K interval to test")
parser.add_argument("-d", "--device", type = int, help = "CUDA device index", default = 0)
_args = parser.parse_args()
main(_args)