diff --git a/tile_engine/ops/gemm/README.md b/tile_engine/ops/gemm/README.md index ff92cef6cf..0a919a7773 100644 --- a/tile_engine/ops/gemm/README.md +++ b/tile_engine/ops/gemm/README.md @@ -37,7 +37,7 @@ rm -rf tile_engine/ && make tile_engine_gemm -j # rebuild -timer The type of timer. Possible values are gpu timer or cpu timer. Default is gpu timer. -init The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 for constant(1). Default is 0, random. -metric Metric with which to measure kernel performance. Set to 0 for latency, 1 for tflops, or 2 for bandwidth. Default is 0, latency. - -structured_sparsity whether use sparsity kernel or not. "Whether use sparsity kernel or not. Possible values are true or false. Default is false. + -structured_sparsity Whether to use the sparsity kernel or not. Possible values are true or false. Default is false. -pipeline The type of pipeline. Possible values are compv3, compv4 or mem. Default is compv3. -epilogue The type of epilogue. Possible values are cshuffle or default. Default is cshuffle. -pad_m Whether pad or not in m direction. Possible values are true or false. Default is false. 
diff --git a/tile_engine/ops/gemm/codegen_utils.py b/tile_engine/ops/gemm/codegen_utils.py index 95863bf63b..bb8085baef 100644 --- a/tile_engine/ops/gemm/codegen_utils.py +++ b/tile_engine/ops/gemm/codegen_utils.py @@ -16,19 +16,23 @@ DATA_TYPE_MAP = {'fp32' : 'float', } LAYOUT_MAP = {'r' : 'ck_tile::tensor_layout::gemm::RowMajor', - 'c' : 'ck_tile::tensor_layout::gemm::ColumnMajor'} + 'c' : 'ck_tile::tensor_layout::gemm::ColumnMajor'} DEFAULT_EPILOGUE = """ using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue< - ck_tile::DefaultGemm2DEpilogueProblem>; + UniversalGemmProblem::TransposeC, + true, + memory_operation>>; """ CSHUFFLE_EPILOGUE = """ @@ -46,94 +50,68 @@ CSHUFFLE_EPILOGUE = """ WarpTileM, WarpTileN, WarpTileK, - UniversalGemmProblem::TransposeC>>; + UniversalGemmProblem::TransposeC, + memory_operation>>; """ - HOT_LOOP_FALSE = """ if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, + RunSplitk(ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Odd) { - Run(ck_tile::bool_constant{}, + RunSplitk(ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Even) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else { throw std::runtime_error("Num K loop must be larger than number of prefetech stages."); } """ - RUN_MEM = """ - if(tail_num == ck_tile::TailNumber::One) - { - Run(ck_tile::bool_constant{}, + // Handle One and Full cases directly + if (tail_num == ck_tile::TailNumber::One) { + RunSplitk(ck_tile::bool_constant{}, ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Full) - { - Run(ck_tile::bool_constant{}, + } else if (tail_num == ck_tile::TailNumber::Full) { + RunSplitk(ck_tile::bool_constant{}, ck_tile::integral_constant{}); } + // Variadic call using fold expression + auto check_tail = [&](auto... 
TNs) { + (try_run< BaseGemmPipeline, decltype(TNs)::value>(tail_num), ...); + }; - if constexpr(BaseGemmPipeline::PrefetchStages > 2) - { - if(tail_num == ck_tile::TailNumber::Two) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - - if(tail_num == ck_tile::TailNumber::Three) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - if(tail_num == ck_tile::TailNumber::Four) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - if(tail_num == ck_tile::TailNumber::Five) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - if(tail_num == ck_tile::TailNumber::Six) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - if(tail_num == ck_tile::TailNumber::Seven) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - throw std::runtime_error("The tile number is wrong! It should not exceed the prefetch stage numbers"); - } + check_tail( + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{} + ); """ RUN_COMPV3 = """ if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, + RunSplitk(ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Odd) { - Run(ck_tile::bool_constant{}, + RunSplitk(ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Even) { - Run(ck_tile::bool_constant{}, + RunSplitk(ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else @@ -145,16 +123,17 @@ RUN_COMPV3 = """ RUN_COMPV4 = """ if(tail_num == ck_tile::TailNumber::Three) { - Run(ck_tile::bool_constant{}, + RunSplitk(ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else { - Run(ck_tile::bool_constant{}, + RunSplitk(ck_tile::bool_constant{}, ck_tile::integral_constant{}); } """ + PIPELINE_MAP = {'mem' : 
['ck_tile::BaseGemmPipelineAgBgCrMem', 'ck_tile::GemmPipelineAgBgCrMem'], 'compv3' : ['ck_tile::BaseGemmPipelineAgBgCrCompV3', 'ck_tile::GemmPipelineAgBgCrCompV3'], 'compv4' : ['ck_tile::BaseGemmPipelineAgBgCrCompV4', 'ck_tile::GemmPipelineAgBgCrCompV4']} @@ -163,10 +142,10 @@ SCHEDULER_MAP = {'interwave' : 'ck_tile::GemmPipelineScheduler::Interwave', 'intrawave' : 'ck_tile::GemmPipelineScheduler::Intrawave'} EPILOGUE_MAP = {'default' :DEFAULT_EPILOGUE, - 'cshuffle' : CSHUFFLE_EPILOGUE} + 'cshuffle' : CSHUFFLE_EPILOGUE} HOT_LOOP_TRUE = {'mem' : RUN_MEM, 'compv3' : RUN_COMPV3, - 'compv4' : RUN_COMPV4} + 'compv4' : RUN_COMPV4} BOOL_MAP = lambda b_: {True: 'true', False: 'false'}[bool(b_)] \ No newline at end of file diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index b21aac3008..e3e36a1da0 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -11,187 +11,9 @@ import os import sys import itertools import copy -<<<<<<< HEAD import logging from json_utils import * from codegen_utils import * -======= -import json -from dataclasses import dataclass - -DATA_TYPE_MAP = {'fp32' : 'float', - 'fp16' : 'ck_tile::half_t', - 'bf16' : 'ck_tile::bf16_t', - 'int8' : 'ck_tile::int8_t', - 'fp8' : 'ck_tile::fp8_t', - 'bf8' : 'ck_tile::bf8_t', - 'int4' : 'ck_tile::pk_int4_t' - } - -LAYOUT_MAP = {'r' : 'ck_tile::tensor_layout::gemm::RowMajor', - 'c' : 'ck_tile::tensor_layout::gemm::ColumnMajor'} - -DEFAULT_EPILOGUE = """ - using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue< - ck_tile::DefaultGemm2DEpilogueProblem>; -""" - -CSHUFFLE_EPILOGUE = """ - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; -""" -HOT_LOOP_FALSE = """ - if(tail_num == ck_tile::TailNumber::Full) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Odd) - { - RunSplitk(ck_tile::bool_constant{}, 
- ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Even) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - throw std::runtime_error("Num K loop must be larger than number of prefetech stages."); - } -""" -RUN_MEM = """ - // Handle One and Full cases directly - if (tail_num == ck_tile::TailNumber::One) { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } else if (tail_num == ck_tile::TailNumber::Full) { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - // Variadic call using fold expression - auto check_tail = [&](auto... TNs) { - (try_run< BaseGemmPipeline, decltype(TNs)::value>(tail_num), ...); - }; - - check_tail( - ck_tile::integral_constant{}, - ck_tile::integral_constant{}, - ck_tile::integral_constant{}, - ck_tile::integral_constant{}, - ck_tile::integral_constant{}, - ck_tile::integral_constant{} - ); -""" - -RUN_COMPV3 = """ - if(tail_num == ck_tile::TailNumber::Full) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Odd) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Even) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - throw std::runtime_error("The tail number is wrong. 
It should be Full, Odd, or Even."); - } -""" - -RUN_COMPV4 = """ - if(tail_num == ck_tile::TailNumber::Three) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } -""" - - -PIPELINE_MAP = {'mem' : ['ck_tile::BaseGemmPipelineAgBgCrMem', 'ck_tile::GemmPipelineAgBgCrMem'], - 'compv3' : ['ck_tile::BaseGemmPipelineAgBgCrCompV3', 'ck_tile::GemmPipelineAgBgCrCompV3'], - 'compv4' : ['ck_tile::BaseGemmPipelineAgBgCrCompV4', 'ck_tile::GemmPipelineAgBgCrCompV4']} - -SCHEDULER_MAP = {'interwave' : 'ck_tile::GemmPipelineScheduler::Interwave', - 'intrawave' : 'ck_tile::GemmPipelineScheduler::Intrawave'} - -EPILOGUE_MAP = {'default' :DEFAULT_EPILOGUE, - 'cshuffle' : CSHUFFLE_EPILOGUE} - -HOT_LOOP_TRUE = {'mem' : RUN_MEM, - 'compv3' : RUN_COMPV3, - 'compv4' : RUN_COMPV4} - - -def BOOL_MAP(b_) -> str: - if b_: - return 'true' - else: - return 'false' - -@dataclass -class GemmConfig: - def __init__(self, config_data): - self.matrix_cfg : Dict[str, Any] = {} - self.impl_cfg : Dict[str, Any] = {} - for key, value in config_data.items(): - if key in ["datatype", "layout_a", "layout_b", "layout_c"]: - self.matrix_cfg[key] = value - else: - self.impl_cfg[key] = value - - @property - def datatype(self) -> str: - return self.matrix_cfg["datatype"]["values"][0] - - @property - def layouts(self) -> List[str]: - return [ - self.matrix_cfg["layout_a"]["values"][0], - self.matrix_cfg["layout_b"]["values"][0], - self.matrix_cfg["layout_c"]["values"][0] - ] - ->>>>>>> origin class GemmCodeGenerator: def __init__(self, output_dir: str, user_provided_config: Optional[GemmConfig] = None): @@ -222,8 +44,7 @@ class GemmCodeGenerator: list_f.write(str(w_p / f"gemm_{trait}.hpp") + "\n") def _generate_all_traits(self): - params = ["pipeline", "epilogue", "scheduler", "pad_m", "pad_n", "pad_k" - ] + params = ["pipeline", "epilogue", "scheduler", "pad_m", "pad_n", "pad_k"] # To remove some 
unsupported combinations unsupported_combinations = { @@ -341,9 +162,9 @@ template struct GemmKernel {{ - static constexpr bool pad_m = {BOOL_MAP(pad_m)}; - static constexpr bool pad_n = {BOOL_MAP(pad_n)}; - static constexpr bool pad_k = {BOOL_MAP(pad_k)}; + static constexpr bool kPadM = {BOOL_MAP(pad_m)}; + static constexpr bool kPadN = {BOOL_MAP(pad_n)}; + static constexpr bool kPadK = {BOOL_MAP(pad_k)}; static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) {{ static constexpr bool permuteA = false; @@ -369,7 +190,7 @@ struct GemmKernel {{ TileParitionerM01>; using Traits = - ck_tile::TileGemmTraits; + ck_tile::TileGemmTraits; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits& c_m_n_host_result, ck_tile::HostTensor& c_m_n_dev_result, int verify, ck_tile::GemmHostArgs& args, - const ck_tile::stream_config& s) {{ + const ck_tile::stream_config& stream) {{ if(structured_sparsity){{ // SMFMA""" for tile in tile_params: # Check if we have valid tile/warp combinations @@ -557,7 +378,7 @@ struct GemmDispatcher { ((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or (tile[6] == 16 and tile[7] == 16 and tile[8] == 32)) content += f""" - profiler.run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);""" + profiler.benchmark_kernel<{trait}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);""" content += f""" }} else {{""" for tile in tile_params: @@ -567,7 +388,7 @@ struct GemmDispatcher { ((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]): continue content += f""" - profiler.run_kernel<{trait}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, 
{tile[8]}, {BOOL_MAP(False)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);""" + profiler.benchmark_kernel<{trait}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(False)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);""" content += f""" }} }};\n""" @@ -579,16 +400,17 @@ struct GemmDispatcher { ck_tile::HostTensor& c_m_n_dev_result, int verify, int metric, + bool structured_sparsity, const KernelTraits& trait, ck_tile::GemmHostArgs& gemm_args, - const ck_tile::stream_config& s) { - init(); + const ck_tile::stream_config& stream) { + init(structured_sparsity); const std::string key = assemble_key(trait); auto& kernel_map = get_kernel_map(); auto& profiler = GemmProfiler::instance(); if(auto it = kernel_map.find(key); it != kernel_map.end()) { it->second( - profiler, c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, gemm_args, s); + profiler, c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, gemm_args, stream); profiler.select_best_instance(static_cast(metric)); return; } @@ -662,7 +484,7 @@ if __name__ == "__main__": "-l", "--list_blobs", action = 'store_true', help="List all kernel instances to file" ) parser.add_argument( - "-g", "--gen_blobs", action = 'store_true', help="Generate all kernels instances into different files" + "-g", "--gen_blobs", action = 'store_true', help="Generate all kernel instances into different files" ) args = parser.parse_args()