Refactor JSON scripts.

This commit is contained in:
Ville Pietilä
2025-11-05 04:08:53 -06:00
parent 116e0c1c61
commit 3c6aae58f7
4 changed files with 1585 additions and 1434 deletions

View File

@@ -63,10 +63,10 @@ def parse_xdl_cshuffle_params(params: List[str]) -> Dict[str, Any]:
"shuffle": map_data_type(params[8]) if len(params) > 8 else "FP32",
"output": map_data_type(params[10]) if len(params) > 10 else "FP32"
},
"elementwise_operation": "PASS_THROUGH",
"device_operation": "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"
"elementwise_operation": "PASS_THROUGH"
},
"algorithm": {
"device_operation": "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle",
"algorithm_type": "XDL",
"thread_block": {
"block_size": int(params[17]) if len(params) > 17 else 256,
@@ -183,10 +183,10 @@ def parse_xdl_cshuffle_v3_params(params: List[str]) -> Dict[str, Any]:
"shuffle": map_data_type(params[8]) if len(params) > 8 else "FP32",
"output": map_data_type(params[10]) if len(params) > 10 else "FP32"
},
"elementwise_operation": "PASS_THROUGH",
"device_operation": "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"
"elementwise_operation": "PASS_THROUGH"
},
"algorithm": {
"device_operation": "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
"algorithm_type": "XDL",
"thread_block": {
"block_size": int(params[16]) if len(params) > 16 else 256,
@@ -311,10 +311,10 @@ def parse_xdl_cshuffle_params_with_lds_extra(params: List[str]) -> Dict[str, Any
"shuffle": map_data_type(params[8]) if len(params) > 8 else "FP32",
"output": map_data_type(params[10]) if len(params) > 10 else "FP32"
},
"elementwise_operation": "PASS_THROUGH",
"device_operation": "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"
"elementwise_operation": "PASS_THROUGH"
},
"algorithm": {
"device_operation": "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle",
"algorithm_type": "XDL",
"thread_block": {
"block_size": int(params[17]) if len(params) > 17 else 256,
@@ -427,10 +427,10 @@ def parse_wmma_cshuffle_params(params: List[str]) -> Dict[str, Any]:
"shuffle": map_data_type(params[8]) if len(params) > 8 else "FP16",
"output": map_data_type(params[10]) if len(params) > 10 else "FP16"
},
"elementwise_operation": "PASS_THROUGH",
"device_operation": "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle"
"elementwise_operation": "PASS_THROUGH"
},
"algorithm": {
"device_operation": "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle",
"algorithm_type": "WMMA",
"thread_block": {
"block_size": int(params[17]) if len(params) > 17 else 128,
@@ -648,11 +648,11 @@ def convert_instantiations(input_file: str, output_file: str):
"shuffle": "enum (FP32, FP16, BF16, FP8, I8, I32, U8)",
"output": "enum (FP32, FP16, BF16, FP8, I8, I32, U8)"
},
"elementwise_operation": "enum (BIAS, BIAS_CLAMP, BIAS_BNORM_CLAMP, BILINEAR, CLAMP, SCALE, PASS_THROUGH)",
"device_operation": "string (DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, etc.)"
"elementwise_operation": "enum (BIAS, BIAS_CLAMP, BIAS_BNORM_CLAMP, BILINEAR, CLAMP, SCALE, PASS_THROUGH)"
},
"algorithm_xdl": {
"description": "Algorithm schema for XDL-based operations (algorithm_type = 'XDL')",
"device_operation": "string (DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, etc.)",
"algorithm_type": "string literal 'XDL'",
"gridwise_xdl_gemm": {
"ak1": "integer - A matrix K dimension vectorization",
@@ -666,6 +666,7 @@ def convert_instantiations(input_file: str, output_file: str):
},
"algorithm_wmma": {
"description": "Algorithm schema for WMMA-based operations (algorithm_type = 'WMMA')",
"device_operation": "string (DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, etc.)",
"algorithm_type": "string literal 'WMMA'",
"gridwise_wmma_gemm": {
"k1": "integer - K dimension vectorization",

View File

@@ -0,0 +1,67 @@
#!/usr/bin/env python3
"""
Script to move elementwise_operation back to signature section
while keeping device_operation in algorithm section.
"""
import json
import sys
def refactor_instantiation(inst: dict) -> dict:
"""Move elementwise_operation back to signature, keep device_operation in algorithm"""
# Extract elementwise_operation from algorithm
elementwise_op = inst["algorithm"].pop("elementwise_operation", "PASS_THROUGH")
# Add to signature
inst["signature"]["elementwise_operation"] = elementwise_op
return inst
def refactor_json(input_file: str, output_file: str):
"""Main refactoring function"""
print(f"Loading {input_file}...")
with open(input_file, 'r') as f:
data = json.load(f)
print(f"Refactoring {len(data['instantiations'])} instantiations...")
# Refactor each instantiation
for inst in data["instantiations"]:
refactor_instantiation(inst)
# Update schema documentation
if "schemas" in data:
# Add elementwise_operation back to signature schema
if "signature" in data["schemas"]:
sig_schema = data["schemas"]["signature"]
# Add after data_type
new_sig_schema = {}
for key, value in sig_schema.items():
new_sig_schema[key] = value
if key == "data_type":
new_sig_schema["elementwise_operation"] = "enum (BIAS, BIAS_CLAMP, BIAS_BNORM_CLAMP, BILINEAR, CLAMP, SCALE, PASS_THROUGH)"
data["schemas"]["signature"] = new_sig_schema
# Remove elementwise_operation from algorithm schemas (keep device_operation)
for algo_schema_key in ["algorithm_xdl", "algorithm_wmma"]:
if algo_schema_key in data["schemas"]:
data["schemas"][algo_schema_key].pop("elementwise_operation", None)
print(f"Writing to {output_file}...")
with open(output_file, 'w') as f:
json.dump(data, f, indent=2)
print("Done!")
if __name__ == "__main__":
input_file = "experimental/builder/instances/forward_conv_structured_instantiations.json"
output_file = "experimental/builder/instances/forward_conv_structured_instantiations.json"
if len(sys.argv) > 1:
input_file = sys.argv[1]
if len(sys.argv) > 2:
output_file = sys.argv[2]
refactor_json(input_file, output_file)

View File

@@ -0,0 +1,82 @@
#!/usr/bin/env python3
"""
Script to refactor forward_conv_structured_instantiations.json
to move elementwise_operation and device_operation from signature to algorithm.
"""
import json
import sys
def refactor_instantiation(inst: dict) -> dict:
"""Move elementwise_operation and device_operation from signature to algorithm"""
# Extract fields from signature
elementwise_op = inst["signature"].pop("elementwise_operation", "PASS_THROUGH")
device_op = inst["signature"].pop("device_operation", "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle")
# Add to algorithm at the beginning for better readability
algorithm = inst["algorithm"]
new_algorithm = {
"elementwise_operation": elementwise_op,
"device_operation": device_op
}
# Copy rest of algorithm fields
for key, value in algorithm.items():
new_algorithm[key] = value
inst["algorithm"] = new_algorithm
return inst
def refactor_json(input_file: str, output_file: str):
"""Main refactoring function"""
print(f"Loading {input_file}...")
with open(input_file, 'r') as f:
data = json.load(f)
print(f"Refactoring {len(data['instantiations'])} instantiations...")
# Refactor each instantiation
for inst in data["instantiations"]:
refactor_instantiation(inst)
# Update schema documentation
if "schemas" in data:
# Update signature schema - remove elementwise_operation and device_operation
if "signature" in data["schemas"]:
sig_schema = data["schemas"]["signature"]
sig_schema.pop("elementwise_operation", None)
sig_schema.pop("device_operation", None)
# Add these fields to algorithm schemas
for algo_schema_key in ["algorithm_xdl", "algorithm_wmma"]:
if algo_schema_key in data["schemas"]:
algo_schema = data["schemas"][algo_schema_key]
# Add at the beginning of the description
if "elementwise_operation" not in algo_schema:
# Insert before other fields
new_schema = {
"elementwise_operation": "enum (BIAS, BIAS_CLAMP, BIAS_BNORM_CLAMP, BILINEAR, CLAMP, SCALE, PASS_THROUGH)",
"device_operation": "string (DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, etc.)"
}
new_schema.update(algo_schema)
data["schemas"][algo_schema_key] = new_schema
print(f"Writing to {output_file}...")
with open(output_file, 'w') as f:
json.dump(data, f, indent=2)
print("Done!")
if __name__ == "__main__":
input_file = "experimental/builder/instances/forward_conv_structured_instantiations.json"
output_file = "experimental/builder/instances/forward_conv_structured_instantiations.json"
if len(sys.argv) > 1:
input_file = sys.argv[1]
if len(sys.argv) > 2:
output_file = sys.argv[2]
refactor_json(input_file, output_file)