This commit is contained in:
Yanxing-Shi
2025-05-13 07:39:51 +00:00
parent 2d3dc763f8
commit a8a19be1b0
3 changed files with 55 additions and 254 deletions

View File

@@ -37,7 +37,7 @@ rm -rf tile_engine/ && make tile_engine_gemm -j # rebuild
-timer The type of timer. Possible values are gpu timer or cpu timer. Default is gpu timer.
-init The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 for constant(1). Default is 0, random.
-metric Metric with which to measure kernel performance. Set to 0 for latency, 1 for tflops, or 2 for bandwidth. Default is 0, latency.
-structured_sparsity whether use sparsity kernel or not. "Whether use sparsity kernel or not. Possible values are true or false. Default is false.
-structured_sparsity whether use sparsity kernel or not. Possible values are true or false. Default is false.
-pipeline The type of pipeline. Possible values are compv3, compv4 or mem. Default is compv3.
-epilogue The type of epilogue. Possible values are cshuffle or default. Default is cshuffle.
-pad_m Whether pad or not in m direction. Possible values are true or false. Default is false.

View File

@@ -16,19 +16,23 @@ DATA_TYPE_MAP = {'fp32' : 'float',
}
LAYOUT_MAP = {'r' : 'ck_tile::tensor_layout::gemm::RowMajor',
'c' : 'ck_tile::tensor_layout::gemm::ColumnMajor'}
'c' : 'ck_tile::tensor_layout::gemm::ColumnMajor'}
DEFAULT_EPILOGUE = """
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
ck_tile::DefaultGemm2DEpilogueProblem<AccDataType,
ck_tile::DefaultGemm2DEpilogueProblem<ADataType,
BDataType,
AccDataType,
CDataType,
CLayout,
pad_m,
pad_n,
kPadM,
kPadN,
WarpTileM,
WarpTileN,
WarpTileK,
UniversalGemmProblem::TransposeC>>;
UniversalGemmProblem::TransposeC,
true,
memory_operation>>;
"""
CSHUFFLE_EPILOGUE = """
@@ -46,94 +50,68 @@ CSHUFFLE_EPILOGUE = """
WarpTileM,
WarpTileN,
WarpTileK,
UniversalGemmProblem::TransposeC>>;
UniversalGemmProblem::TransposeC,
memory_operation>>;
"""
HOT_LOOP_FALSE = """
if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<false>{},
RunSplitk(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else if(tail_num == ck_tile::TailNumber::Odd)
{
Run(ck_tile::bool_constant<false>{},
RunSplitk(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_num == ck_tile::TailNumber::Even)
{
Run(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
RunSplitk(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
{
throw std::runtime_error("Num K loop must be larger than number of prefetech stages.");
}
"""
RUN_MEM = """
if(tail_num == ck_tile::TailNumber::One)
{
Run(ck_tile::bool_constant<true>{},
// Handle One and Full cases directly
if (tail_num == ck_tile::TailNumber::One) {
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
}
else if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<true>{},
} else if (tail_num == ck_tile::TailNumber::Full) {
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
// Variadic call using fold expression
auto check_tail = [&](auto... TNs) {
(try_run< BaseGemmPipeline, decltype(TNs)::value>(tail_num), ...);
};
if constexpr(BaseGemmPipeline::PrefetchStages > 2)
{
if(tail_num == ck_tile::TailNumber::Two)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
}
if(tail_num == ck_tile::TailNumber::Three)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
}
if(tail_num == ck_tile::TailNumber::Four)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{});
}
if(tail_num == ck_tile::TailNumber::Five)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{});
}
if(tail_num == ck_tile::TailNumber::Six)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{});
}
if(tail_num == ck_tile::TailNumber::Seven)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
}
throw std::runtime_error("The tile number is wrong! It should not exceed the prefetch stage numbers");
}
check_tail(
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{}
);
"""
RUN_COMPV3 = """
if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<true>{},
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else if(tail_num == ck_tile::TailNumber::Odd)
{
Run(ck_tile::bool_constant<true>{},
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_num == ck_tile::TailNumber::Even)
{
Run(ck_tile::bool_constant<true>{},
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
@@ -145,16 +123,17 @@ RUN_COMPV3 = """
RUN_COMPV4 = """
if(tail_num == ck_tile::TailNumber::Three)
{
Run(ck_tile::bool_constant<true>{},
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
}
else
{
Run(ck_tile::bool_constant<true>{},
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
}
"""
PIPELINE_MAP = {'mem' : ['ck_tile::BaseGemmPipelineAgBgCrMem', 'ck_tile::GemmPipelineAgBgCrMem'],
'compv3' : ['ck_tile::BaseGemmPipelineAgBgCrCompV3', 'ck_tile::GemmPipelineAgBgCrCompV3'],
'compv4' : ['ck_tile::BaseGemmPipelineAgBgCrCompV4', 'ck_tile::GemmPipelineAgBgCrCompV4']}
@@ -163,10 +142,10 @@ SCHEDULER_MAP = {'interwave' : 'ck_tile::GemmPipelineScheduler::Interwave',
'intrawave' : 'ck_tile::GemmPipelineScheduler::Intrawave'}
EPILOGUE_MAP = {'default' :DEFAULT_EPILOGUE,
'cshuffle' : CSHUFFLE_EPILOGUE}
'cshuffle' : CSHUFFLE_EPILOGUE}
HOT_LOOP_TRUE = {'mem' : RUN_MEM,
'compv3' : RUN_COMPV3,
'compv4' : RUN_COMPV4}
'compv4' : RUN_COMPV4}
BOOL_MAP = lambda b_: {True: 'true', False: 'false'}[bool(b_)]

View File

@@ -11,187 +11,9 @@ import os
import sys
import itertools
import copy
<<<<<<< HEAD
import logging
from json_utils import *
from codegen_utils import *
=======
import json
from dataclasses import dataclass
DATA_TYPE_MAP = {'fp32' : 'float',
'fp16' : 'ck_tile::half_t',
'bf16' : 'ck_tile::bf16_t',
'int8' : 'ck_tile::int8_t',
'fp8' : 'ck_tile::fp8_t',
'bf8' : 'ck_tile::bf8_t',
'int4' : 'ck_tile::pk_int4_t'
}
LAYOUT_MAP = {'r' : 'ck_tile::tensor_layout::gemm::RowMajor',
'c' : 'ck_tile::tensor_layout::gemm::ColumnMajor'}
DEFAULT_EPILOGUE = """
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
ck_tile::DefaultGemm2DEpilogueProblem<ADataType,
BDataType,
AccDataType,
CDataType,
CLayout,
kPadM,
kPadN,
WarpTileM,
WarpTileN,
WarpTileK,
UniversalGemmProblem::TransposeC,
true,
memory_operation>>;
"""
CSHUFFLE_EPILOGUE = """
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
AccDataType,
CDataType,
CLayout,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
WarpM,
WarpN,
WarpTileM,
WarpTileN,
WarpTileK,
UniversalGemmProblem::TransposeC,
memory_operation>>;
"""
HOT_LOOP_FALSE = """
if(tail_num == ck_tile::TailNumber::Full)
{
RunSplitk(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else if(tail_num == ck_tile::TailNumber::Odd)
{
RunSplitk(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_num == ck_tile::TailNumber::Even)
{
RunSplitk(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
{
throw std::runtime_error("Num K loop must be larger than number of prefetech stages.");
}
"""
RUN_MEM = """
// Handle One and Full cases directly
if (tail_num == ck_tile::TailNumber::One) {
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
} else if (tail_num == ck_tile::TailNumber::Full) {
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
// Variadic call using fold expression
auto check_tail = [&](auto... TNs) {
(try_run< BaseGemmPipeline, decltype(TNs)::value>(tail_num), ...);
};
check_tail(
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{}
);
"""
RUN_COMPV3 = """
if(tail_num == ck_tile::TailNumber::Full)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else if(tail_num == ck_tile::TailNumber::Odd)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_num == ck_tile::TailNumber::Even)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
{
throw std::runtime_error("The tail number is wrong. It should be Full, Odd, or Even.");
}
"""
RUN_COMPV4 = """
if(tail_num == ck_tile::TailNumber::Three)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
}
else
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
}
"""
PIPELINE_MAP = {'mem' : ['ck_tile::BaseGemmPipelineAgBgCrMem', 'ck_tile::GemmPipelineAgBgCrMem'],
'compv3' : ['ck_tile::BaseGemmPipelineAgBgCrCompV3', 'ck_tile::GemmPipelineAgBgCrCompV3'],
'compv4' : ['ck_tile::BaseGemmPipelineAgBgCrCompV4', 'ck_tile::GemmPipelineAgBgCrCompV4']}
SCHEDULER_MAP = {'interwave' : 'ck_tile::GemmPipelineScheduler::Interwave',
'intrawave' : 'ck_tile::GemmPipelineScheduler::Intrawave'}
EPILOGUE_MAP = {'default' :DEFAULT_EPILOGUE,
'cshuffle' : CSHUFFLE_EPILOGUE}
HOT_LOOP_TRUE = {'mem' : RUN_MEM,
'compv3' : RUN_COMPV3,
'compv4' : RUN_COMPV4}
def BOOL_MAP(b_) -> str:
if b_:
return 'true'
else:
return 'false'
@dataclass
class GemmConfig:
def __init__(self, config_data):
self.matrix_cfg : Dict[str, Any] = {}
self.impl_cfg : Dict[str, Any] = {}
for key, value in config_data.items():
if key in ["datatype", "layout_a", "layout_b", "layout_c"]:
self.matrix_cfg[key] = value
else:
self.impl_cfg[key] = value
@property
def datatype(self) -> str:
return self.matrix_cfg["datatype"]["values"][0]
@property
def layouts(self) -> List[str]:
return [
self.matrix_cfg["layout_a"]["values"][0],
self.matrix_cfg["layout_b"]["values"][0],
self.matrix_cfg["layout_c"]["values"][0]
]
>>>>>>> origin
class GemmCodeGenerator:
def __init__(self, output_dir: str, user_provided_config: Optional[GemmConfig] = None):
@@ -222,8 +44,7 @@ class GemmCodeGenerator:
list_f.write(str(w_p / f"gemm_{trait}.hpp") + "\n")
def _generate_all_traits(self):
params = ["pipeline", "epilogue", "scheduler", "pad_m", "pad_n", "pad_k"
]
params = ["pipeline", "epilogue", "scheduler", "pad_m", "pad_n", "pad_k"]
# To remove some unsupported combinations
unsupported_combinations = {
@@ -341,9 +162,9 @@ template <int TileM, int TileN, int TileK,
int WarpTileM, int WarpTileN, int WarpTileK,
bool structured_sparsity>
struct GemmKernel {{
static constexpr bool pad_m = {BOOL_MAP(pad_m)};
static constexpr bool pad_n = {BOOL_MAP(pad_n)};
static constexpr bool pad_k = {BOOL_MAP(pad_k)};
static constexpr bool kPadM = {BOOL_MAP(pad_m)};
static constexpr bool kPadN = {BOOL_MAP(pad_n)};
static constexpr bool kPadK = {BOOL_MAP(pad_k)};
static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) {{
static constexpr bool permuteA = false;
@@ -369,7 +190,7 @@ struct GemmKernel {{
TileParitionerM01>;
using Traits =
ck_tile::TileGemmTraits<pad_m, pad_n, pad_k, ALayout, BLayout, CLayout>;
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using GemmUniversalTraits =
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
@@ -539,13 +360,13 @@ struct GemmDispatcher {
for trait in self.all_trait_names:
content += f""" kernel_map["{trait}"] = [=]((Profiler& profiler,
content += f""" kernel_map["{trait}"] = [=]( GemmProfiler& profiler,
ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
int verify,
ck_tile::GemmHostArgs& args,
const ck_tile::stream_config& s) {{
const ck_tile::stream_config& stream) {{
if(structured_sparsity){{ // SMFMA"""
for tile in tile_params:
# Check if we have valid tile/warp combinations
@@ -557,7 +378,7 @@ struct GemmDispatcher {
((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or
(tile[6] == 16 and tile[7] == 16 and tile[8] == 32))
content += f"""
profiler.run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);"""
profiler.benchmark_kernel<{trait}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);"""
content += f"""
}} else {{"""
for tile in tile_params:
@@ -567,7 +388,7 @@ struct GemmDispatcher {
((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]):
continue
content += f"""
profiler.run_kernel<{trait}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(False)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);"""
profiler.benchmark_kernel<{trait}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(False)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);"""
content += f"""
}}
}};\n"""
@@ -579,16 +400,17 @@ struct GemmDispatcher {
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
int verify,
int metric,
bool structured_sparsity,
const KernelTraits& trait,
ck_tile::GemmHostArgs& gemm_args,
const ck_tile::stream_config& s) {
init();
const ck_tile::stream_config& stream) {
init(structured_sparsity);
const std::string key = assemble_key(trait);
auto& kernel_map = get_kernel_map();
auto& profiler = GemmProfiler::instance();
if(auto it = kernel_map.find(key); it != kernel_map.end()) {
it->second(
profiler, c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, gemm_args, s);
profiler, c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, gemm_args, stream);
profiler.select_best_instance(static_cast<Metric>(metric));
return;
}
@@ -662,7 +484,7 @@ if __name__ == "__main__":
"-l", "--list_blobs", action = 'store_true', help="List all kernel instances to file"
)
parser.add_argument(
"-g", "--gen_blobs", action = 'store_true', help="Generate all kernels instances into different files"
"-g", "--gen_blobs", action = 'store_true', help="Generate all kernel instances into different files"
)
args = parser.parse_args()