model_diff.py: Use deferred load and close file handles between modules

This commit is contained in:
turboderp
2025-05-12 21:23:48 +02:00
parent a905cffb1a
commit 5c3ff204c4

View File

@@ -56,18 +56,24 @@ def main(args):
for idx, (module_a, module_b) in enumerate(zip(model_a.modules, model_b.modules)):
config_a.stc.begin_deferred_load()
module_a.load("cuda:0" if not module_a.caps.get("prefer_cpu") else "cpu")
config_a.stc.end_deferred_load()
params_a = {}
state_a = module_a.prepare_for_device(state_a, params_a)
state_a = module_a.forward(state_a, params_a)
module_a.unload()
config_a.stc.close()
free_mem()
config_b.stc.begin_deferred_load()
module_b.load("cuda:0" if not module_b.caps.get("prefer_cpu") else "cpu")
config_b.stc.end_deferred_load()
params_b = {}
state_b = module_b.prepare_for_device(state_b, params_b)
state_b = module_b.forward(state_b, params_b)
module_b.unload()
config_b.stc.close()
free_mem()
if idx < args.keep_b: