diff --git a/Jenkinsfile b/Jenkinsfile index a7dc8360ee..7cfd3c1c90 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1229,11 +1229,24 @@ pipeline { -D CMAKE_BUILD_TYPE=Release \ -D GPU_TARGETS="gfx90a" \ -D GEMM_DATATYPE="fp8;fp16" \ + -D GEMM_LAYOUT="rcr;rrr;crr;ccr" \ -DCMAKE_CXX_FLAGS=" -O3 " .. && \ - ninja -j64 benchmark_gemm_fp8 && \ - ./bin/benchmark_gemm_fp8 && \ - ninja -j64 benchmark_gemm_fp16 && \ - ./bin/benchmark_gemm_fp16 """ + ninja -j64 benchmark_gemm_fp8_rcr && \ + ./bin/benchmark_gemm_fp8_rcr && \ + ninja -j64 benchmark_gemm_fp16_rcr && \ + ./bin/benchmark_gemm_fp16_rcr && \ + ninja -j64 benchmark_gemm_fp8_crr && \ + ./bin/benchmark_gemm_fp8_crr && \ + ninja -j64 benchmark_gemm_fp16_crr && \ + ./bin/benchmark_gemm_fp16_crr && \ + ninja -j64 benchmark_gemm_fp8_ccr && \ + ./bin/benchmark_gemm_fp8_ccr && \ + ninja -j64 benchmark_gemm_fp16_ccr && \ + ./bin/benchmark_gemm_fp16_ccr && \ + ninja -j64 benchmark_gemm_fp8_rrr && \ + ./bin/benchmark_gemm_fp8_rrr && \ + ninja -j64 benchmark_gemm_fp16_rrr && \ + ./bin/benchmark_gemm_fp16_rrr """ } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) @@ -1254,11 +1267,24 @@ pipeline { -D CMAKE_BUILD_TYPE=Release \ -D GPU_TARGETS="gfx942" \ -D GEMM_DATATYPE="fp8;fp16" \ + -D GEMM_LAYOUT="rcr;rrr;crr;ccr" \ -DCMAKE_CXX_FLAGS=" -O3 " .. && \ - ninja -j128 benchmark_gemm_fp8 && \ - ./bin/benchmark_gemm_fp8 && \ - ninja -j128 benchmark_gemm_fp16 && \ - ./bin/benchmark_gemm_fp16 """ + ninja -j64 benchmark_gemm_fp8_rcr && \ + ./bin/benchmark_gemm_fp8_rcr && \ + ninja -j64 benchmark_gemm_fp16_rcr && \ + ./bin/benchmark_gemm_fp16_rcr && \ + ninja -j64 benchmark_gemm_fp8_crr && \ + ./bin/benchmark_gemm_fp8_crr && \ + ninja -j64 benchmark_gemm_fp16_crr && \ + ./bin/benchmark_gemm_fp16_crr && \ + ninja -j64 benchmark_gemm_fp8_ccr && \ + ./bin/benchmark_gemm_fp8_ccr && \ + ninja -j64 benchmark_gemm_fp16_ccr && \ + ./bin/benchmark_gemm_fp16_ccr && \ + ninja -j64 benchmark_gemm_fp8_rrr && \ + ./bin/benchmark_gemm_fp8_rrr && \ + ninja -j64 benchmark_gemm_fp16_rrr && \ + ./bin/benchmark_gemm_fp16_rrr """ } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index 4d0836af39..839b6c4f08 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -5,13 +5,16 @@ rm -rf CMakeFiles MY_PROJECT_SOURCE=$1 -if [ $# -ge 2 ] ; then +if [ $# -ge 2 ] && [[ "$2" =~ ^gfx ]]; then GPU_TARGETS=$2 shift 2 + echo "GPU targets provided: $GPU_TARGETS" REST_ARGS=$@ else + echo "No GPU targets provided, using default targets: gfx908;gfx90a;gfx942" GPU_TARGETS="gfx908;gfx90a;gfx942" - REST_ARGS= + shift 1 + REST_ARGS=$@ fi cmake \ diff --git a/script/cmake-ck-release.sh b/script/cmake-ck-release.sh index acb04ac75f..311ea91822 100755 --- a/script/cmake-ck-release.sh +++ b/script/cmake-ck-release.sh @@ -5,13 +5,16 @@ rm -rf CMakeFiles MY_PROJECT_SOURCE=$1 -if [ $# -ge 2 ] ; then +if [ $# -ge 2 ] && [[ "$2" =~ ^gfx ]]; then GPU_TARGETS=$2 shift 2 + echo "GPU targets provided: $GPU_TARGETS" REST_ARGS=$@ else + echo "No GPU targets provided, using default targets: gfx908;gfx90a;gfx942" GPU_TARGETS="gfx908;gfx90a;gfx942" - REST_ARGS= + shift 1 + REST_ARGS=$@ fi cmake \ diff --git a/tile_engine/ops/gemm/CMakeLists.txt b/tile_engine/ops/gemm/CMakeLists.txt index 5db55f02d5..fe9b7802a7 100644 --- a/tile_engine/ops/gemm/CMakeLists.txt +++ b/tile_engine/ops/gemm/CMakeLists.txt @@ -1,21 +1,32 @@ set(GEMM_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM (semicolon-separated)") +set(GEMM_LAYOUT "rcr" CACHE STRING "List of layout for GEMM (semicolon-separated)") -function(build_gemm_for_datatype datatype) - set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/") - set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json") - #set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json") +function(build_gemm_for_datatype datatype layout) + set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}") + + # Comment this if-else block when using user_provided_config + if(layout STREQUAL "rcr") + set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json") + else() + set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/custom_ci_config.json") + endif() + + # uncomment this if you want to use user_provided_config.json + # set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json") + # Generate kernel list execute_process( COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py --working_path ${working_path} --datatype ${datatype} + --layout ${layout} --config_json ${json_blob} --list_blobs RESULT_VARIABLE ret ) if(NOT ret EQUAL 0) - message(FATAL_ERROR "Failed to list kernels for ${datatype}: ${ret}") + message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${ret}") endif() file(STRINGS "${working_path}/gemm_instance_blobs.txt" codegen_blobs) @@ -27,11 +38,12 @@ function(build_gemm_for_datatype datatype) COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py --working_path "${working_path}" --datatype ${datatype} + --layout ${layout} --config_json "${json_blob}" --gen_blobs - COMMENT "Generating GEMM instance sources for ${datatype}" + COMMENT "Generating GEMM instance sources for ${datatype} ${layout}" ) - add_custom_target(gemm_gen_${datatype} DEPENDS ${codegen_blobs}) + add_custom_target(gemm_gen_${datatype}_${layout} DEPENDS ${codegen_blobs}) set(intermediate_libs) list(LENGTH codegen_blobs codegen_blobs_len) @@ -69,7 +81,7 @@ function(build_gemm_for_datatype datatype) #list(LENGTH chunk_files chunk_files_len) #if(chunk_files_len AND chunk_files_len GREATER 1) if(chunk_files) - set(sub_intermediate_lib_name "gemm_objlib_${name}_${i}_${datatype}") + set(sub_intermediate_lib_name "gemm_objlib_${name}_${i}_${datatype}_${layout}") add_library(${sub_intermediate_lib_name} OBJECT ${chunk_files}) list(APPEND sub_intermediate_libs ${sub_intermediate_lib_name}) endif() @@ -80,7 +92,7 @@ function(build_gemm_for_datatype datatype) #list(LENGTH sub_intermediate_libs sub_intermediate_libs_len) #if(sub_intermediate_libs AND sub_intermediate_libs_len GREATER 1) if(sub_intermediate_libs) - set(intermediate_lib_name "gemm_staticlib_${name}_${datatype}") + set(intermediate_lib_name "gemm_staticlib_${name}_${datatype}_${layout}") # Collect the $ expressions set(obj_exprs) @@ -89,7 +101,7 @@ function(build_gemm_for_datatype datatype) endforeach() add_library(${intermediate_lib_name} STATIC ${obj_exprs}) - add_dependencies(${intermediate_lib_name} gemm_gen_${datatype}) + add_dependencies(${intermediate_lib_name} gemm_gen_${datatype}_${layout}) #foreach(objlib IN LISTS sub_intermediate_libs) # target_sources(${intermediate_lib_name} PRIVATE $) #endforeach() @@ -99,28 +111,28 @@ function(build_gemm_for_datatype datatype) endforeach() # Interface library for instances - add_library(gemm_template_instances_${datatype} INTERFACE) - add_dependencies(gemm_template_instances_${datatype} gemm_gen_${datatype}) - target_link_libraries(gemm_template_instances_${datatype} INTERFACE ${intermediate_libs}) - target_include_directories(gemm_template_instances_${datatype} INTERFACE + add_library(gemm_template_instances_${datatype}_${layout} INTERFACE) + add_dependencies(gemm_template_instances_${datatype}_${layout} gemm_gen_${datatype}_${layout}) + target_link_libraries(gemm_template_instances_${datatype}_${layout} INTERFACE ${intermediate_libs}) + target_include_directories(gemm_template_instances_${datatype}_${layout} INTERFACE ${CMAKE_CURRENT_LIST_DIR} "${working_path}" ) - set_target_properties(gemm_template_instances_${datatype} PROPERTIES LINKER_LANGUAGE CXX) + set_target_properties(gemm_template_instances_${datatype}_${layout} PROPERTIES LINKER_LANGUAGE CXX) # Host API interface library - add_library(gemm_host_api_${datatype} INTERFACE) - target_link_libraries(gemm_host_api_${datatype} INTERFACE gemm_template_instances_${datatype}) - target_include_directories(gemm_host_api_${datatype} INTERFACE + add_library(gemm_host_api_${datatype}_${layout} INTERFACE) + target_link_libraries(gemm_host_api_${datatype}_${layout} INTERFACE gemm_template_instances_${datatype}_${layout}) + target_include_directories(gemm_host_api_${datatype}_${layout} INTERFACE ${CMAKE_CURRENT_LIST_DIR} "${working_path}" ) # Executable per datatype - set(exec_name "benchmark_gemm_${datatype}") + set(exec_name "benchmark_gemm_${datatype}_${layout}") add_executable(${exec_name} benchmark_gemm.cpp) - target_link_libraries(${exec_name} PRIVATE gemm_host_api_${datatype}) + target_link_libraries(${exec_name} PRIVATE gemm_host_api_${datatype}_${layout}) target_compile_options(${exec_name} PRIVATE -Wno-undefined-func-template -Wno-float-equal @@ -130,5 +142,7 @@ endfunction() # Process each datatype in isolation foreach(dt IN LISTS GEMM_DATATYPE) - build_gemm_for_datatype(${dt}) + foreach(l IN LISTS GEMM_LAYOUT) + build_gemm_for_datatype(${dt} ${l}) + endforeach() endforeach() diff --git a/tile_engine/ops/gemm/README.md b/tile_engine/ops/gemm/README.md index e74da4b958..a16b74d297 100644 --- a/tile_engine/ops/gemm/README.md +++ b/tile_engine/ops/gemm/README.md @@ -7,6 +7,7 @@ CK Tile Engine GEMM is used to generate and run GEMM kernels with different comb Users can specify custom kernel configurations such as tile size, warp size, padding, pipeline, scheduler, and epilogue in the config file. This allows building only for selected configurations, significantly reducing build time. For reference please see `./configs/user_provided_config.json`. + The Tile engine also has a default kernel configuration for providing range of configuration parameter values, which helps users who lack kernel development experience to benchmark. For reference please see in `./configs/default_config.json` If user does not provide kernel configuration, the tile engine uses default kernel configuration to generate kernel instances and benchmark. @@ -18,25 +19,28 @@ mkdir build && cd build # build composable kernel # replace [Arch] with the appropriate architecture or leave blank and # replace [Datatype1;Datatype2;...] in comma separated datatypes string (possible datatypes are [fp8, bf8, int8, fp16, bf16]) -sh ../script/cmake-ck-dev.sh ../ [Arch] -DGEMM_DATATYPE="[Datatype1;Datatype2]" +# replace [Layout1;Layout2;...] in comma separated datatypes string (possible layouts are [rcr, rrr, crr, ccr]) +sh ../script/cmake-ck-dev.sh ../ [Arch] -DGEMM_DATATYPE="[Datatype1;Datatype2]" -DGEMM_LAYOUT="[Layout1;Layout2]" # generate different executable for each passed datatype -make benchmark_gemm_[Datatype1] -j -make benchmark_gemm_[Datatype2] -j +make benchmark_gemm_[Datatype1]_[Layout1] -j +make benchmark_gemm_[Datatype1]_[Layout2] -j +make benchmark_gemm_[Datatype2]_[Layout1] -j +make benchmark_gemm_[Datatype2]_[Layout2] -j ``` -`benchmark_gemm_[Datatypes]` will be located in the `./bin/` directory. +`benchmark_gemm_[Datatype]_[Layout]` will be located in the `./bin/` directory. -`benchmark_gemm_[Datatypes]` must be rebuilt everytime if configuration file is modified. +`benchmark_gemm_[Datatype]_[Layout]` must be rebuilt everytime if configuration file is modified. ``` bash -rm -rf tile_engine/ && make benchmark_gemm_[Datatypes] -j # rebuild +rm -rf tile_engine/ && make benchmark_gemm_[Datatypes]_[Layout] -j # rebuild ``` -## For eaxmple build for gfx942 for fp8 and fp16 datatypes +## For eaxmple build for gfx942 for fp8 and fp16 datatypes with rcr layout ``` bash mkdir build && cd build -sh ../script/cmake-ck-dev.sh ../ gfx942 -DGEMM_DATATYPE="fp8;fp16" -make benchmark_gemm_fp8 -j -make benchmark_gemm_fp16 -j +sh ../script/cmake-ck-dev.sh ../ gfx942 -DGEMM_DATATYPE="fp8;fp16" -DGEMM_LAYOUT="rcr" +make benchmark_gemm_fp8_rcr -j +make benchmark_gemm_fp16_rcr -j ``` ## benchmark_gemm inputs @@ -103,7 +107,7 @@ The following JSON file specifies parameters used to generate and build GEMM ker At runtime, a specific subset of the generated kernels can be selected using command-line arguments. ``` bash -./bin/benchmark_gemm -pipeline=compv3 -scheduler=intrawave -epilogue=default +./bin/benchmark_gemm_[Datatype]_[Layout] -pipeline=compv3 -scheduler=intrawave -epilogue=default ``` The above command runs kernels configured with the compv3 pipeline, intrawave scheduler, and default epilogue, while sweeping over different BlockTile sizes, WarpTile sizes, and WarpTile mappings. diff --git a/tile_engine/ops/gemm/configs/benchmark.json b/tile_engine/ops/gemm/configs/benchmark.json index 601784049b..1560698b77 100644 --- a/tile_engine/ops/gemm/configs/benchmark.json +++ b/tile_engine/ops/gemm/configs/benchmark.json @@ -1,20 +1,5 @@ { "problem": { - "layout_a": { - "values": [ - "r" - ] - }, - "layout_b": { - "values": [ - "c" - ] - }, - "layout_c": { - "values": [ - "r" - ] - } }, "tile_config": { "tile_m": { diff --git a/tile_engine/ops/gemm/configs/custom_ci_config.json b/tile_engine/ops/gemm/configs/custom_ci_config.json new file mode 100644 index 0000000000..9187fb01eb --- /dev/null +++ b/tile_engine/ops/gemm/configs/custom_ci_config.json @@ -0,0 +1,82 @@ +{ + "problem": { + }, + "tile_config": { + "tile_m": { + "values": [ + 128 ] + }, + "tile_n": { + "values": [ + 128 + ] + }, + "tile_k": { + "values": [ + 128 + ] + }, + "warp_m": { + "values": [ + 2 + ] + }, + "warp_n": { + "values": [ + 2 + ] + }, + "warp_k": { + "values": [ + 1 + ] + }, + "warp_tile_m": { + "values": [ + 32 + ] + }, + "warp_tile_n": { + "values": [ + 32 + ] + }, + "warp_tile_k": { + "values": [ + 16 + ] + } + }, + "trait_config": { + "pipeline": { + "values": [ + "compv3" + ] + }, + "scheduler": { + "values": [ + "intrawave" + ] + }, + "epilogue": { + "values": [ + "default" + ] + }, + "pad_m": { + "values": [ + false + ] + }, + "pad_n": { + "values": [ + false + ] + }, + "pad_k": { + "values": [ + false + ] + } + } +} \ No newline at end of file diff --git a/tile_engine/ops/gemm/configs/default_config.json b/tile_engine/ops/gemm/configs/default_config.json index 069a3b080c..12a8ddd4b7 100644 --- a/tile_engine/ops/gemm/configs/default_config.json +++ b/tile_engine/ops/gemm/configs/default_config.json @@ -1,20 +1,5 @@ { "problem": { - "layout_a": { - "values": [ - "r" - ] - }, - "layout_b": { - "values": [ - "c" - ] - }, - "layout_c": { - "values": [ - "r" - ] - } }, "tile_config": { "tile_m": { diff --git a/tile_engine/ops/gemm/configs/user_provided_config.json b/tile_engine/ops/gemm/configs/user_provided_config.json index 79bcced82a..5761b39ada 100644 --- a/tile_engine/ops/gemm/configs/user_provided_config.json +++ b/tile_engine/ops/gemm/configs/user_provided_config.json @@ -1,20 +1,5 @@ { "problem": { - "layout_a": { - "values": [ - "r" - ] - }, - "layout_b": { - "values": [ - "c" - ] - }, - "layout_c": { - "values": [ - "r" - ] - } }, "tile_config": { "tile_m": { diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index de1fd0bb62..0b38c44a1a 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -98,19 +98,19 @@ class GemmCodeGenerator: _, ) in tile: instance_name = f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}.cpp" - + if instance_name not in file_name: file_name.add(instance_name) f.write(str(w_p / instance_name) + "\n") files_listed += 1 file_range_map[trait] = (start_idx, files_listed) - - file_path = w_p / 'gemm_instance_blobs_range.txt' - with file_path.open('w') as f: + + file_path = w_p / "gemm_instance_blobs_range.txt" + with file_path.open("w") as f: for name, ranges in file_range_map.items(): s, l = ranges - f.write(name + " " + f"{s}" + " " + f"{l}"+ "\n") + f.write(name + " " + f"{s}" + " " + f"{l}" + "\n") def _generate_all_traits(self): """Generate all possible kernel traits names.""" @@ -563,7 +563,7 @@ struct GemmKernel {{ self.valid_trait_tile_combinations[trait].append(tile_valid_params) def _generate_instantiation_source_files(self): - """Generate kernel instance instantiation source files """ + """Generate kernel instance instantiation source files""" tile_map = {} for trait, tile_valid_params in self.valid_trait_tile_combinations.items(): for tile in tile_valid_params: @@ -583,11 +583,13 @@ struct GemmKernel {{ if key not in tile_map: tile_map[key] = set() tile_map[key].add(value) - + files_listed = 0 for trait, _ in self.valid_trait_tile_combinations.items(): for block_tile, warp_tiles in tile_map.items(): - tile_m, tile_n, tile_k, warp_m, warp_n, warp_k = map(int, block_tile.split('x')) + tile_m, tile_n, tile_k, warp_m, warp_n, warp_k = map( + int, block_tile.split("x") + ) content = f""" // SPDX-License-Identifier: MIT @@ -598,8 +600,10 @@ struct GemmKernel {{ """ for warp_tile in warp_tiles: - warp_tile_m, warp_tile_n, warp_tile_k = map(int, warp_tile.split("x")) - + warp_tile_m, warp_tile_n, warp_tile_k = map( + int, warp_tile.split("x") + ) + sparse = ( self.config.problem.datatype_map["matrix_a"] == "fp16" and self.config.problem.datatype_map["matrix_b"] == "fp16" @@ -619,15 +623,23 @@ struct GemmKernel {{ ) if sparse: files_listed = files_listed + 1 - content = content + f""" + content = ( + content + + f""" template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, true>;""" + ) files_listed = files_listed + 1 - content = content + f""" + content = ( + content + + f""" template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, false>;""" + ) content += f""" """ - (self.output_dir / - f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}.cpp").write_text(content) + ( + self.output_dir + / f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}.cpp" + ).write_text(content) print(f"Generated {files_listed} kernel instances in total.") def _generate_dispatcher_file(self): @@ -785,7 +797,7 @@ def do_gen_blobs( def main(args): gemm_config = ( - GemmConfig.from_json(args.config_json, args.datatype) + GemmConfig.from_json(args.config_json, args.datatype, args.layout) if args.config_json is not None else args.config_json ) @@ -823,7 +835,13 @@ if __name__ == "__main__": "-d", "--datatype", required=True, - help="Specify what datatype to use for the kernel generation, e.g. fp16, bf16, int8, fp8, bf8" + help="Specify what datatype to use for the kernel generation, e.g. fp16, bf16, int8, fp8, bf8", + ) + parser.add_argument( + "-ly", + "--layout", + required=True, + help="Specify what layout to use for the kernel generation, e.g. rcr, rrr", ) parser.add_argument( "-l", diff --git a/tile_engine/ops/gemm/json_config.py b/tile_engine/ops/gemm/json_config.py index 8b83977dd3..675a2052ef 100644 --- a/tile_engine/ops/gemm/json_config.py +++ b/tile_engine/ops/gemm/json_config.py @@ -118,7 +118,9 @@ class GemmConfig: trait_config: TraitConfig @classmethod - def from_json(cls: Type["GemmConfig"], filepath: str, datatype: str) -> "GemmConfig": + def from_json( + cls: Type["GemmConfig"], filepath: str, datatype: str, layout: str + ) -> "GemmConfig": """JSON configuration loader with validation controls""" config_path = Path(filepath) @@ -132,32 +134,40 @@ class GemmConfig: a_type = datatype b_type = datatype c_type = datatype - if b_type == 'int4': + if b_type == "int4": a_type = "fp16" - if b_type in ['bf8', 'fp8', 'int4']: + if b_type in ["bf8", "fp8", "int4"]: c_type = "fp16" + layout_parts = layout.lower() + assert len(layout_parts) == 3, ( + f"Invalid layout string: {layout} (must be 3 characters like 'rcr' where r stands for row major and c stands for column major)" + ) + assert layout_parts[0] in ("r", "c"), ( + f"Invalid matrix_a layout: {layout_parts[0]} (must be 'r' for row major or or 'c' for column major)" + ) + assert layout_parts[1] in ("r", "c"), ( + f"Invalid matrix_a layout: {layout_parts[1]} (must be 'r' for row major or or 'c' for column major)" + ) + assert layout_parts[2] == "r", ( + f"Invalid matrix_c layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)" + ) + a_layout = layout_parts[0] + b_layout = layout_parts[1] + c_layout = layout_parts[2] + # Parse problem config - #TODO: Not reading datatype information from json file. + # TODO: Not reading datatype information from json file. problem = ProblemConfig( datatypes=( - EnumConfigParam( - values=[a_type]), - EnumConfigParam( - values=[b_type]), - EnumConfigParam( - values=[c_type]) + EnumConfigParam(values=[a_type]), + EnumConfigParam(values=[b_type]), + EnumConfigParam(values=[c_type]), ), layouts=( - EnumConfigParam( - values=config_dict["problem"]["layout_a"]["values"] - ), - EnumConfigParam( - values=config_dict["problem"]["layout_b"]["values"] - ), - EnumConfigParam( - values=config_dict["problem"]["layout_c"]["values"] - ), + EnumConfigParam(values=[a_layout]), + EnumConfigParam(values=[b_layout]), + EnumConfigParam(values=[c_layout]), ), )