mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 12:17:00 +00:00
Refactor JSON scripts.
This commit is contained in:
@@ -63,10 +63,10 @@ def parse_xdl_cshuffle_params(params: List[str]) -> Dict[str, Any]:
|
||||
"shuffle": map_data_type(params[8]) if len(params) > 8 else "FP32",
|
||||
"output": map_data_type(params[10]) if len(params) > 10 else "FP32"
|
||||
},
|
||||
"elementwise_operation": "PASS_THROUGH",
|
||||
"device_operation": "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"
|
||||
"elementwise_operation": "PASS_THROUGH"
|
||||
},
|
||||
"algorithm": {
|
||||
"device_operation": "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle",
|
||||
"algorithm_type": "XDL",
|
||||
"thread_block": {
|
||||
"block_size": int(params[17]) if len(params) > 17 else 256,
|
||||
@@ -183,10 +183,10 @@ def parse_xdl_cshuffle_v3_params(params: List[str]) -> Dict[str, Any]:
|
||||
"shuffle": map_data_type(params[8]) if len(params) > 8 else "FP32",
|
||||
"output": map_data_type(params[10]) if len(params) > 10 else "FP32"
|
||||
},
|
||||
"elementwise_operation": "PASS_THROUGH",
|
||||
"device_operation": "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"
|
||||
"elementwise_operation": "PASS_THROUGH"
|
||||
},
|
||||
"algorithm": {
|
||||
"device_operation": "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"algorithm_type": "XDL",
|
||||
"thread_block": {
|
||||
"block_size": int(params[16]) if len(params) > 16 else 256,
|
||||
@@ -311,10 +311,10 @@ def parse_xdl_cshuffle_params_with_lds_extra(params: List[str]) -> Dict[str, Any
|
||||
"shuffle": map_data_type(params[8]) if len(params) > 8 else "FP32",
|
||||
"output": map_data_type(params[10]) if len(params) > 10 else "FP32"
|
||||
},
|
||||
"elementwise_operation": "PASS_THROUGH",
|
||||
"device_operation": "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"
|
||||
"elementwise_operation": "PASS_THROUGH"
|
||||
},
|
||||
"algorithm": {
|
||||
"device_operation": "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle",
|
||||
"algorithm_type": "XDL",
|
||||
"thread_block": {
|
||||
"block_size": int(params[17]) if len(params) > 17 else 256,
|
||||
@@ -427,10 +427,10 @@ def parse_wmma_cshuffle_params(params: List[str]) -> Dict[str, Any]:
|
||||
"shuffle": map_data_type(params[8]) if len(params) > 8 else "FP16",
|
||||
"output": map_data_type(params[10]) if len(params) > 10 else "FP16"
|
||||
},
|
||||
"elementwise_operation": "PASS_THROUGH",
|
||||
"device_operation": "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle"
|
||||
"elementwise_operation": "PASS_THROUGH"
|
||||
},
|
||||
"algorithm": {
|
||||
"device_operation": "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle",
|
||||
"algorithm_type": "WMMA",
|
||||
"thread_block": {
|
||||
"block_size": int(params[17]) if len(params) > 17 else 128,
|
||||
@@ -648,11 +648,11 @@ def convert_instantiations(input_file: str, output_file: str):
|
||||
"shuffle": "enum (FP32, FP16, BF16, FP8, I8, I32, U8)",
|
||||
"output": "enum (FP32, FP16, BF16, FP8, I8, I32, U8)"
|
||||
},
|
||||
"elementwise_operation": "enum (BIAS, BIAS_CLAMP, BIAS_BNORM_CLAMP, BILINEAR, CLAMP, SCALE, PASS_THROUGH)",
|
||||
"device_operation": "string (DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, etc.)"
|
||||
"elementwise_operation": "enum (BIAS, BIAS_CLAMP, BIAS_BNORM_CLAMP, BILINEAR, CLAMP, SCALE, PASS_THROUGH)"
|
||||
},
|
||||
"algorithm_xdl": {
|
||||
"description": "Algorithm schema for XDL-based operations (algorithm_type = 'XDL')",
|
||||
"device_operation": "string (DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, etc.)",
|
||||
"algorithm_type": "string literal 'XDL'",
|
||||
"gridwise_xdl_gemm": {
|
||||
"ak1": "integer - A matrix K dimension vectorization",
|
||||
@@ -666,6 +666,7 @@ def convert_instantiations(input_file: str, output_file: str):
|
||||
},
|
||||
"algorithm_wmma": {
|
||||
"description": "Algorithm schema for WMMA-based operations (algorithm_type = 'WMMA')",
|
||||
"device_operation": "string (DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, etc.)",
|
||||
"algorithm_type": "string literal 'WMMA'",
|
||||
"gridwise_wmma_gemm": {
|
||||
"k1": "integer - K dimension vectorization",
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,67 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Script to move elementwise_operation back to signature section
|
||||
while keeping device_operation in algorithm section.
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
def refactor_instantiation(inst: dict) -> dict:
|
||||
"""Move elementwise_operation back to signature, keep device_operation in algorithm"""
|
||||
|
||||
# Extract elementwise_operation from algorithm
|
||||
elementwise_op = inst["algorithm"].pop("elementwise_operation", "PASS_THROUGH")
|
||||
|
||||
# Add to signature
|
||||
inst["signature"]["elementwise_operation"] = elementwise_op
|
||||
|
||||
return inst
|
||||
|
||||
def refactor_json(input_file: str, output_file: str):
|
||||
"""Main refactoring function"""
|
||||
|
||||
print(f"Loading {input_file}...")
|
||||
with open(input_file, 'r') as f:
|
||||
data = json.load(f)
|
||||
|
||||
print(f"Refactoring {len(data['instantiations'])} instantiations...")
|
||||
|
||||
# Refactor each instantiation
|
||||
for inst in data["instantiations"]:
|
||||
refactor_instantiation(inst)
|
||||
|
||||
# Update schema documentation
|
||||
if "schemas" in data:
|
||||
# Add elementwise_operation back to signature schema
|
||||
if "signature" in data["schemas"]:
|
||||
sig_schema = data["schemas"]["signature"]
|
||||
# Add after data_type
|
||||
new_sig_schema = {}
|
||||
for key, value in sig_schema.items():
|
||||
new_sig_schema[key] = value
|
||||
if key == "data_type":
|
||||
new_sig_schema["elementwise_operation"] = "enum (BIAS, BIAS_CLAMP, BIAS_BNORM_CLAMP, BILINEAR, CLAMP, SCALE, PASS_THROUGH)"
|
||||
data["schemas"]["signature"] = new_sig_schema
|
||||
|
||||
# Remove elementwise_operation from algorithm schemas (keep device_operation)
|
||||
for algo_schema_key in ["algorithm_xdl", "algorithm_wmma"]:
|
||||
if algo_schema_key in data["schemas"]:
|
||||
data["schemas"][algo_schema_key].pop("elementwise_operation", None)
|
||||
|
||||
print(f"Writing to {output_file}...")
|
||||
with open(output_file, 'w') as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
print("Done!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
input_file = "experimental/builder/instances/forward_conv_structured_instantiations.json"
|
||||
output_file = "experimental/builder/instances/forward_conv_structured_instantiations.json"
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
input_file = sys.argv[1]
|
||||
if len(sys.argv) > 2:
|
||||
output_file = sys.argv[2]
|
||||
|
||||
refactor_json(input_file, output_file)
|
||||
82
experimental/builder/instances/refactor_json_structure.py
Normal file
82
experimental/builder/instances/refactor_json_structure.py
Normal file
@@ -0,0 +1,82 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Script to refactor forward_conv_structured_instantiations.json
|
||||
to move elementwise_operation and device_operation from signature to algorithm.
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
def refactor_instantiation(inst: dict) -> dict:
|
||||
"""Move elementwise_operation and device_operation from signature to algorithm"""
|
||||
|
||||
# Extract fields from signature
|
||||
elementwise_op = inst["signature"].pop("elementwise_operation", "PASS_THROUGH")
|
||||
device_op = inst["signature"].pop("device_operation", "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle")
|
||||
|
||||
# Add to algorithm at the beginning for better readability
|
||||
algorithm = inst["algorithm"]
|
||||
new_algorithm = {
|
||||
"elementwise_operation": elementwise_op,
|
||||
"device_operation": device_op
|
||||
}
|
||||
|
||||
# Copy rest of algorithm fields
|
||||
for key, value in algorithm.items():
|
||||
new_algorithm[key] = value
|
||||
|
||||
inst["algorithm"] = new_algorithm
|
||||
|
||||
return inst
|
||||
|
||||
def refactor_json(input_file: str, output_file: str):
|
||||
"""Main refactoring function"""
|
||||
|
||||
print(f"Loading {input_file}...")
|
||||
with open(input_file, 'r') as f:
|
||||
data = json.load(f)
|
||||
|
||||
print(f"Refactoring {len(data['instantiations'])} instantiations...")
|
||||
|
||||
# Refactor each instantiation
|
||||
for inst in data["instantiations"]:
|
||||
refactor_instantiation(inst)
|
||||
|
||||
# Update schema documentation
|
||||
if "schemas" in data:
|
||||
# Update signature schema - remove elementwise_operation and device_operation
|
||||
if "signature" in data["schemas"]:
|
||||
sig_schema = data["schemas"]["signature"]
|
||||
sig_schema.pop("elementwise_operation", None)
|
||||
sig_schema.pop("device_operation", None)
|
||||
|
||||
# Add these fields to algorithm schemas
|
||||
for algo_schema_key in ["algorithm_xdl", "algorithm_wmma"]:
|
||||
if algo_schema_key in data["schemas"]:
|
||||
algo_schema = data["schemas"][algo_schema_key]
|
||||
# Add at the beginning of the description
|
||||
if "elementwise_operation" not in algo_schema:
|
||||
# Insert before other fields
|
||||
new_schema = {
|
||||
"elementwise_operation": "enum (BIAS, BIAS_CLAMP, BIAS_BNORM_CLAMP, BILINEAR, CLAMP, SCALE, PASS_THROUGH)",
|
||||
"device_operation": "string (DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, etc.)"
|
||||
}
|
||||
new_schema.update(algo_schema)
|
||||
data["schemas"][algo_schema_key] = new_schema
|
||||
|
||||
print(f"Writing to {output_file}...")
|
||||
with open(output_file, 'w') as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
print("Done!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
input_file = "experimental/builder/instances/forward_conv_structured_instantiations.json"
|
||||
output_file = "experimental/builder/instances/forward_conv_structured_instantiations.json"
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
input_file = sys.argv[1]
|
||||
if len(sys.argv) > 2:
|
||||
output_file = sys.argv[2]
|
||||
|
||||
refactor_json(input_file, output_file)
|
||||
Reference in New Issue
Block a user