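"""
Measurement pass of the ExLlamaV2 conversion pipeline
(https://github.com/turboderp-org/exllamav2).

For each module in the model, run a reference forward pass over the calibration hidden
states, collect GPTQ Hessians for every linear projection, quantize each projection with a
grid of candidate quantization parameters, and record the resulting accuracy
(1 - relative Frobenius error, see test_error below) together with the total bit count.
The results are written to measurement.json for use by the later stages of the conversion.
"""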
from exllamav2.model import (
    ExLlamaV2Embedding,
    ExLlamaV2Attention,
    ExLlamaV2MLP,
    ExLlamaV2MoEMLP,
    ExLlamaV2Linear,
    ExLlamaV2RMSNorm,
    ExLlamaV2LayerNorm
)

from safetensors import safe_open
from safetensors.torch import save_file
from conversion.qparams import QParams, qparams_headoptions, qparams_attn, qparams_mlp, get_qparams_reduced
from conversion.adaptivegptq import AdaptiveGPTQ
import torch
from torch import nn
import os, time, math, json
import torch.nn.functional as F
import gc

# Graceful exiting

import signal
import sys

interrupted = False

def signal_handler(sig, frame):
    global interrupted
    if interrupted:
        print("\nGracefully exiting...")
        sys.exit(0)
    else:
        interrupted = True
        print("\nCTRL-C again to quit or type 'exit'. You can always resume the process at a later time.")
        user_input = input("\nPress Enter to continue processing or type 'exit' to quit: ").strip().lower()
        if user_input == 'exit':
            print("Gracefully exiting...")
            sys.exit(0)
        interrupted = False

signal.signal(signal.SIGINT, signal_handler)

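# The first Ctrl-C prompts the user (Enter to continue, 'exit' to quit); a second Ctrl-C
# while the prompt is active terminates immediately. measure_quant also polls the
# `interrupted` flag between modules so the job state can be saved before exiting.
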
def list_live_tensors():

    tensors = {}
    gc.collect()
    torch.cuda.empty_cache()

    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                d = str(obj.size()) + ", " + str(obj.dtype) + ", " + str(obj.device)
                if d in tensors.keys():
                    tensors[d] += 1
                else:
                    tensors[d] = 1
        except:
            pass

    print("-----------")
    for k, v in tensors.items():
        print(f"{v} : {k}")

# Get initial token embeddings

def embeddings(job, save_fn, model, measure = False):

    module = model.modules[0]
    assert isinstance(module, ExLlamaV2Embedding)

    with safe_open(job["cal_filename"], framework = "pt", device = "cpu") as f:
        input_ids = f.get_tensor("input_ids")

    module.load()
    input_ids[input_ids >= module.native_vocab_size] = 0
    hidden_state = module.forward(input_ids)
    module.unload()

    embeddings_dict = { f"row.{i:05}": hidden_state[i:i+1, :, :].contiguous() for i in range(hidden_state.shape[0]) }
    save_file(embeddings_dict, os.path.join(job["out_dir"], "hidden_states.safetensors"))

# Test quantization options

def test_quant(source: ExLlamaV2Linear,
               lq: AdaptiveGPTQ,
               qparams: list):

    variants = []
    variants_bits = []

    original = nn.Linear(source.in_features, source.out_features, source.has_bias, device = "meta", dtype = torch.float16)
    original.weight = nn.Parameter(source.linear.weight.clone())
    if source.has_bias: original.bias.weight = nn.Parameter(source.linear.bias.clone())

    for qp in qparams:

        lq.configure(qp.group_size, qp.bits, qp.bits_prop, qp.scale_bits)
        lq.quantize()
        quantized = lq.apply_temp()
        quantized.to("cpu")

        variants.append(quantized)
        total_bits = qp.total_bits(quantized.weight.T.shape, original.bias.weight.shape if source.has_bias else None)
        variants_bits.append(total_bits)

        numel = quantized.weight.numel()
        if source.has_bias: numel += original.bias.numel()
        bpw = total_bits / numel
        desc = qp.desc

        print(f" -- {source.key:50} {desc:50} {bpw:2.2f} bpw")

    return variants, variants_bits

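# test_error computes a scalar "accuracy" for the currently-installed (quantized) weights:
# for each calibration row it takes the relative Frobenius norm of the difference between
# the module's output and the reference (unquantized) output, averages over rows, and
# returns 1 - mean_error, clamped to a minimum of 1e-6. Higher is better; 1.0 would mean
# the quantized module reproduces the reference outputs exactly.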
def test_error(module, hidden_states, target_states, cache, attn_params):

    rfn_sum = 0
    rfn_count = 0
    for x, xref in zip(hidden_states, target_states):
        x = x.cuda()
        xref = xref.cuda()
        xtest = module.forward(x, cache, attn_params)
        xtest = xtest[0].float()
        xref = xref[0].float()
        rfn_sum += (torch.linalg.norm(xtest - xref, 'fro') / torch.linalg.norm(xref, 'fro')).item()
        rfn_count += 1

    return max(1e-6, 1 - (rfn_sum / rfn_count))

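# measure_attn quantizes the four attention projections with every candidate QParams set,
# then walks the reduced grid of (q, k, v, o) combinations from get_qparams_reduced(),
# swapping the corresponding pre-quantized weight variants into the live module and scoring
# each combination with test_error. K and V reuse the Hessian collected for Q, since all
# three projections consume the same post-norm input activations.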
def measure_attn(module, hidden_states, target_states, quantizers, cache, attn_params):

    qjobs, qmaps = get_qparams_reduced(qparams_attn)
    results = []

    quantizers["q_proj"].prepare()
    quantizers["k_proj"].reuse_h(quantizers["q_proj"])
    quantizers["v_proj"].reuse_h(quantizers["q_proj"])
    quantizers["o_proj"].prepare()

    options_q, bits_q = test_quant(module.q_proj, quantizers["q_proj"], qjobs[0])
    options_k, bits_k = test_quant(module.k_proj, quantizers["k_proj"], qjobs[1])
    options_v, bits_v = test_quant(module.v_proj, quantizers["v_proj"], qjobs[2])
    options_o, bits_o = test_quant(module.o_proj, quantizers["o_proj"], qjobs[3])

    total_numel = module.q_proj.numel()
    total_numel += module.k_proj.numel()
    total_numel += module.v_proj.numel()
    total_numel += module.o_proj.numel()

    (q_, k_, v_, o_) = (-1, -1, -1, -1)
    for (q, k, v, o) in qmaps:

        if q != q_: module.q_proj.linear.weight = nn.Parameter(options_q[q].weight.cuda())
        if k != k_: module.k_proj.linear.weight = nn.Parameter(options_k[k].weight.cuda())
        if v != v_: module.v_proj.linear.weight = nn.Parameter(options_v[v].weight.cuda())
        if o != o_: module.o_proj.linear.weight = nn.Parameter(options_o[o].weight.cuda())
        (q_, k_, v_, o_) = (q, k, v, o)

        total_bits = bits_q[q]
        total_bits += bits_k[k]
        total_bits += bits_v[v]
        total_bits += bits_o[o]
        total_bpw = total_bits / total_numel

        accuracy = test_error(module, hidden_states, target_states, cache, attn_params)
        print(f" -- {total_bpw:1.4f} bpw accuracy: {accuracy:1.8f}")

        torch.cuda.empty_cache()

        r = { "accuracy": accuracy,
              "total_bits": total_bits,
              "q_proj": qjobs[0][q].get_dict(),
              "k_proj": qjobs[1][k].get_dict(),
              "v_proj": qjobs[2][v].get_dict(),
              "o_proj": qjobs[3][o].get_dict() }
        results.append(r)

    return results

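# measure_mlp performs the same grid search for the MLP projections. For gated
# architectures (gate_proj present) the grid covers (gate, up, down); otherwise only
# (up, down). gate_proj reuses the Hessian from up_proj because both consume the
# post-norm input.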
def measure_mlp(module, hidden_states, target_states, quantizers, cache, attn_params):

    has_gate = module.model.config.arch.mlp_gate

    qjobs, qmaps = get_qparams_reduced(qparams_mlp, not has_gate)
    results = []

    quantizers["up_proj"].prepare()
    if has_gate: quantizers["gate_proj"].reuse_h(quantizers["up_proj"])
    quantizers["down_proj"].prepare()

    options_g, bits_g = test_quant(module.gate_proj, quantizers["gate_proj"], qjobs[0]) if has_gate else (None, None)
    options_u, bits_u = test_quant(module.up_proj, quantizers["up_proj"], qjobs[1])
    options_d, bits_d = test_quant(module.down_proj, quantizers["down_proj"], qjobs[2])

    total_numel = module.gate_proj.numel() if has_gate else 0
    total_numel += module.up_proj.numel()
    total_numel += module.down_proj.numel()

    if has_gate:

        (g_, u_, d_) = (-1, -1, -1)
        for (g, u, d) in qmaps:

            if g != g_: module.gate_proj.linear.weight = nn.Parameter(options_g[g].weight.cuda())
            if u != u_: module.up_proj.linear.weight = nn.Parameter(options_u[u].weight.cuda())
            if d != d_: module.down_proj.linear.weight = nn.Parameter(options_d[d].weight.cuda())
            (g_, u_, d_) = (g, u, d)

            total_bits = bits_g[g]
            total_bits += bits_u[u]
            total_bits += bits_d[d]
            total_bpw = total_bits / total_numel

            accuracy = test_error(module, hidden_states, target_states, cache, attn_params)
            print(f" -- {total_bpw:1.4f} bpw accuracy: {accuracy:1.8f}")

            torch.cuda.empty_cache()

            r = { "accuracy": accuracy,
                  "total_bits": total_bits,
                  "gate_proj": qjobs[0][g].get_dict(),
                  "up_proj": qjobs[1][u].get_dict(),
                  "down_proj": qjobs[2][d].get_dict() }
            results.append(r)

    else:

        (u_, d_) = (-1, -1)
        for (u, d) in qmaps:

            if u != u_: module.up_proj.linear.weight = nn.Parameter(options_u[u].weight.cuda())
            if d != d_: module.down_proj.linear.weight = nn.Parameter(options_d[d].weight.cuda())
            (u_, d_) = (u, d)

            total_bits = bits_u[u]
            total_bits += bits_d[d]
            total_bpw = total_bits / total_numel

            accuracy = test_error(module, hidden_states, target_states, cache, attn_params)
            print(f" -- {total_bpw:1.4f} bpw accuracy: {accuracy:1.8f}")

            torch.cuda.empty_cache()

            r = { "accuracy": accuracy,
                  "total_bits": total_bits,
                  "up_proj": qjobs[1][u].get_dict(),
                  "down_proj": qjobs[2][d].get_dict() }
            results.append(r)

    return results

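# measure_moe_mlp repeats the search for a block-sparse MoE layer: each expert's w1/w2/w3
# gets its own quantizer, with the Hessian from w1.0 reused for every other w1 and for all
# w3 projections, while each w2 is prepared separately. Weight variants are swapped in per
# expert, bit totals are summed over all experts, and each quantizer is freed as soon as
# its variants have been produced to keep memory in check.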
def measure_moe_mlp(module, hidden_states, target_states, quantizers, cache, attn_params):

    qjobs, qmaps = get_qparams_reduced(qparams_mlp)
    num_experts = module.model.config.num_experts
    results = []

    quantizers["w1.0"].prepare()
    for i in range(num_experts):
        if i > 0: quantizers[f"w1.{i}"].reuse_h(quantizers["w1.0"])
        quantizers[f"w2.{i}"].prepare()
        quantizers[f"w3.{i}"].reuse_h(quantizers["w1.0"])

    options_g, bits_g = [], []
    options_u, bits_u = [], []
    options_d, bits_d = [], []
    for i in range(num_experts):
        options_g_, bits_g_ = test_quant(module.w1[i], quantizers[f"w1.{i}"], qjobs[0])
        del quantizers[f"w1.{i}"]
        options_u_, bits_u_ = test_quant(module.w3[i], quantizers[f"w3.{i}"], qjobs[1])
        del quantizers[f"w3.{i}"]
        options_d_, bits_d_ = test_quant(module.w2[i], quantizers[f"w2.{i}"], qjobs[2])
        del quantizers[f"w2.{i}"]
        options_g.append(options_g_)
        options_u.append(options_u_)
        options_d.append(options_d_)
        bits_g.append(bits_g_)
        bits_u.append(bits_u_)
        bits_d.append(bits_d_)

    quantizers.clear()
    gc.collect()
    torch.cuda.empty_cache()

    total_numel = sum(module.w1[i].numel() for i in range(num_experts))
    total_numel += sum(module.w3[i].numel() for i in range(num_experts))
    total_numel += sum(module.w2[i].numel() for i in range(num_experts))

    (g_, u_, d_) = (-1, -1, -1)
    for (g, u, d) in qmaps:

        for i in range(num_experts):
            if g != g_: module.w1[i].linear.weight = nn.Parameter(options_g[i][g].weight.cuda())
            if u != u_: module.w3[i].linear.weight = nn.Parameter(options_u[i][u].weight.cuda())
            if d != d_: module.w2[i].linear.weight = nn.Parameter(options_d[i][d].weight.cuda())
        (g_, u_, d_) = (g, u, d)

        total_bits = sum(bits_g[i][g] for i in range(num_experts))
        total_bits += sum(bits_u[i][u] for i in range(num_experts))
        total_bits += sum(bits_d[i][d] for i in range(num_experts))
        total_bpw = total_bits / total_numel

        accuracy = test_error(module, hidden_states, target_states, cache, attn_params)
        print(f" -- {total_bpw:1.4f} bpw accuracy: {accuracy:1.8f}")

        torch.cuda.empty_cache()

        r = { "accuracy": accuracy,
              "total_bits": total_bits,
              "w1": qjobs[0][g].get_dict(),
              "w3": qjobs[1][u].get_dict(),
              "w2": qjobs[2][d].get_dict() }
        results.append(r)

    return results

# Helpful status box for insight into the conversion's progress

def get_remaining_time_str(estimated_time_remaining):
    remaining_minutes = int(estimated_time_remaining // 60)
    remaining_seconds = int(estimated_time_remaining % 60)
    return f"{remaining_minutes}min {remaining_seconds}sec"

def format_line(label, box_width):
    return f"| {label.ljust(box_width - 3)}|"

def print_status_box(*content_lines):
    max_content_width = max(len(line) for line in content_lines)
    box_width = max_content_width + 4

    print('-' * box_width)
    for line in content_lines:
        print(format_line(line, box_width))
    print('-' * box_width)

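# measure_quant is the driver for the measurement pass. It streams the calibration hidden
# states through the model one module at a time: load a module, run a reference forward
# pass to collect target states and GPTQ Hessians, call the appropriate measure_* function,
# store the results under job["measurement"], then feed the module's outputs to the next
# module. Progress is checkpointed to hidden_states.safetensors every snapshot_interval
# modules so an interrupted job can resume from job["last_module_idx"].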
@torch.inference_mode()
def measure_quant(job, save_fn, model):

    # Vars for the status box

    time_spent_list = []
    rolling_window_size = 10  # increase to average over a larger window
    completed_steps = 0
    accuracy_sum = 0
    accuracy_count = 0
    overall_rolling_accuracy = 0

    snapshot_interval = 10
    temp_filename = os.path.join(job["out_dir"], "hidden_states_temp.safetensors")
    states_filename = os.path.join(job["out_dir"], "hidden_states.safetensors")
    measurement = job.get("measurement", {})

    # Quantize

    last_ckpt_layer_name = "None"

    if "last_module_idx" not in job:
        job["last_module_idx"] = 0
    else:
        i = job["last_module_idx"]
        if i < len(model.modules):
            last_ckpt_layer_name = f"{model.modules[i].key} ({model.modules[i].name})"
            print(f" -- Resuming from layer: {last_ckpt_layer_name}")

    # Vars to support the status box

    total_modules = len(model.modules)
    last_module_idx = job["last_module_idx"]  # resume tracking steps where it stopped previously
    remaining_steps = total_modules - last_module_idx

    hidden_states = []
    with safe_open(states_filename, framework = "pt", device = "cpu") as f:
        for k in sorted(f.keys()):
            hidden_states.append(f.get_tensor(k))

    index = job["last_module_idx"]

    while True:

        # The signal handler usually catches the interrupt first; this is a fallback check
        if interrupted:
            print("Measurement process was interrupted. Exiting after saving the current state.")
            job["measurement"] = measurement.copy()
            job["last_module_idx"] = index
            save_fn()
            return "interrupted"

        index += 1
        if index >= len(model.modules): break

        # Timer

        begin_time = time.time()

        # Prepare module

        module = model.modules[index]
        module.load()

        print(f" -- Layer: {module.key} ({module.name})")

        # Create quantizers

        quantizers = {}

        if isinstance(module, ExLlamaV2Attention):
            mode = "self_attn"
            quantizers["q_proj"] = AdaptiveGPTQ(module.q_proj.linear)
            quantizers["k_proj"] = AdaptiveGPTQ(module.k_proj.linear)
            quantizers["v_proj"] = AdaptiveGPTQ(module.v_proj.linear)
            quantizers["o_proj"] = AdaptiveGPTQ(module.o_proj.linear)

        elif isinstance(module, ExLlamaV2MLP):
            mode = "mlp"
            has_gate = module.model.config.arch.mlp_gate
            if has_gate: quantizers["gate_proj"] = AdaptiveGPTQ(module.gate_proj.linear)
            quantizers["up_proj"] = AdaptiveGPTQ(module.up_proj.linear)
            quantizers["down_proj"] = AdaptiveGPTQ(module.down_proj.linear)

        elif isinstance(module, ExLlamaV2MoEMLP):
            mode = "block_sparse_moe"
            for i in range(model.config.num_experts):
                quantizers[f"w1.{i}"] = AdaptiveGPTQ(module.w1[i].linear)
                quantizers[f"w3.{i}"] = AdaptiveGPTQ(module.w3[i].linear)
                quantizers[f"w2.{i}"] = AdaptiveGPTQ(module.w2[i].linear)

        elif isinstance(module, ExLlamaV2Linear):
            mode = "linear"
            # Don't measure head layer

        elif isinstance(module, ExLlamaV2RMSNorm) or isinstance(module, ExLlamaV2LayerNorm):
            mode = "norm"

        # Reference forward pass

        cache = None
        attn_params = ExLlamaV2Attention.Params(1, hidden_states[0].shape[1], 0, None, None) if mode == "self_attn" else None

        target_states = []
        if mode == "block_sparse_moe":
            uncalibrated_experts = [0 for _ in range(model.config.num_experts)]

        for i in range(len(hidden_states)):

            x = hidden_states[i].to("cuda:0")
            outputs = module.forward(x, cache, attn_params, intermediates = True)

            # Hessians

            if mode == "self_attn":
                quantizers["q_proj"].add_batch(outputs["post_norm"])  # Reuse H for K and V
                quantizers["o_proj"].add_batch(outputs["attn_output"])

            if mode == "mlp":
                quantizers["up_proj"].add_batch(outputs["post_norm"])  # Reuse H for gate_proj
                quantizers["down_proj"].add_batch(outputs["pre_down"])

            if mode == "block_sparse_moe":
                for j in range(model.config.num_experts):
                    if f"pre_down.{j}" in outputs:
                        quantizers[f"w1.{j}"].add_batch(outputs["post_norm"])
                        quantizers[f"w2.{j}"].add_batch(outputs[f"pre_down.{j}"])
                        if outputs[f"pre_down.{j}"].shape[0] < outputs["post_norm"].shape[0] / 10:
                            uncalibrated_experts[j] += 1
                    else:
                        uncalibrated_experts[j] += 1

            target_states.append(outputs["hidden_states"].to("cpu"))

        # For MoE layers, warn if any expert received less than 10% of a calibration batch

        if mode == "block_sparse_moe":
            for j in range(model.config.num_experts):
                ue = uncalibrated_experts[j]
                if ue > 0:
                    print(f" !! Warning: w2.{j} has less than 10% calibration for {ue}/{len(hidden_states)} rows")

        # Measurement

        m = None

        if mode == "self_attn":
            m = measure_attn(module, hidden_states, target_states, quantizers, cache, attn_params)

        if mode == "mlp":
            m = measure_mlp(module, hidden_states, target_states, quantizers, cache, attn_params)

        if mode == "block_sparse_moe":
            m = measure_moe_mlp(module, hidden_states, target_states, quantizers, cache, attn_params)

        measurement[module.key + "." + mode] = m

        # # Track overall accuracy for the status box
        # if m is not None and len(m) > 0:
        #     layer_accuracies = [result['accuracy'] for result in m]
        #     layer_accuracy_sum = sum(layer_accuracies)
        #     layer_accuracy_count = len(layer_accuracies)
        #
        #     accuracy_sum += layer_accuracy_sum
        #     accuracy_count += layer_accuracy_count
        #     overall_rolling_accuracy = accuracy_sum / accuracy_count

        # Unload module

        module.unload()
        torch.cuda.empty_cache()

        # Advance

        hidden_states = target_states

        # Timing and status box

        end_time = time.time()
        duration = end_time - begin_time

        time_spent_list.append(duration)
        if len(time_spent_list) > rolling_window_size:
            time_spent_list.pop(0)
        average_time_per_step = sum(time_spent_list) / len(time_spent_list)

        remaining_steps = total_modules - index
        estimated_time_remaining = average_time_per_step * remaining_steps
        completed_steps = index

        completed_module_name_str = f"Measured: {module.key} ({module.name})"
        duration_str = f"Duration: {duration:.2f} seconds"
        completed_step_str = f"Completed step: {completed_steps}/{total_modules}"
        avg_time_str = f"Avg time / step (rolling): {average_time_per_step:.2f} seconds"
        remaining_time_str = f"Estimated remaining time: {get_remaining_time_str(estimated_time_remaining)}"
        # overall_accuracy_str = f"Overall avg accuracy: {overall_rolling_accuracy:.8f}" if accuracy_count > 0 else ""
        last_ckpt_str = f"Last checkpoint layer: {last_ckpt_layer_name}"

        content_lines = [completed_module_name_str,
                         duration_str,
                         completed_step_str,
                         avg_time_str,
                         remaining_time_str,
                         last_ckpt_str]

        # if accuracy_count > 0:
        #     content_lines.append(overall_accuracy_str)

        print_status_box(*content_lines)

        # Checkpoint

        if index % snapshot_interval == 0 or index == len(model.modules) - 1:

            save_dict = {f"row.{idx:05}": h for idx, h in enumerate(hidden_states)}
            save_file(save_dict, temp_filename)
            save_dict = None

            job["invalid"] = True
            save_fn()

            os.replace(temp_filename, states_filename)

            job["measurement"] = measurement.copy()
            job["last_module_idx"] = index

            last_ckpt_layer_name = f"{module.key} ({module.name})"

            del job["invalid"]
            save_fn()

    # Export measurement

    exp_measurement = { "measurement": job["measurement"],
                        "last_module_idx": job["last_module_idx"] }

    measurement_files = [os.path.join(job["out_dir"], "measurement.json")]
    if job["output_measurement"] is not None:
        measurement_files += [job["output_measurement"]]
        print(f" -- Writing {job['output_measurement']}")

    for filename in measurement_files:
        with open(filename, "w", encoding = "utf8") as f:
            f.write(json.dumps(exp_measurement, indent = 4))

    return "completed"  # graceful exiting
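
# For reference, the exported measurement.json has the shape
#   { "measurement": { "<module.key>.<mode>": [ { "accuracy": ..., "total_bits": ...,
#       "<proj name>": <QParams dict>, ... }, ... ], ... },
#     "last_module_idx": <int> }
# where each list entry corresponds to one point of the quantization grid tried above and
# the per-projection dicts come from QParams.get_dict().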