mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
merge
This commit is contained in:
@@ -37,7 +37,7 @@ rm -rf tile_engine/ && make tile_engine_gemm -j # rebuild
|
||||
-timer The type of timer. Possible values are gpu timer or cpu timer. Default is gpu timer.
|
||||
-init The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 for constant(1). Default is 0, random.
|
||||
-metric Metric with which to measure kernel performance. Set to 0 for latency, 1 for tflops, or 2 for bandwidth. Default is 0, latency.
|
||||
-structured_sparsity whether use sparsity kernel or not. "Whether use sparsity kernel or not. Possible values are true or false. Default is false.
|
||||
-structured_sparsity whether use sparsity kernel or not. Possible values are true or false. Default is false.
|
||||
-pipeline The type of pipeline. Possible values are compv3, compv4 or mem. Default is compv3.
|
||||
-epilogue The type of epilogue. Possible values are cshuffle or default. Default is cshuffle.
|
||||
-pad_m Whether pad or not in m direction. Possible values are true or false. Default is false.
|
||||
|
||||
@@ -16,19 +16,23 @@ DATA_TYPE_MAP = {'fp32' : 'float',
|
||||
}
|
||||
|
||||
LAYOUT_MAP = {'r' : 'ck_tile::tensor_layout::gemm::RowMajor',
|
||||
'c' : 'ck_tile::tensor_layout::gemm::ColumnMajor'}
|
||||
'c' : 'ck_tile::tensor_layout::gemm::ColumnMajor'}
|
||||
|
||||
DEFAULT_EPILOGUE = """
|
||||
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
|
||||
ck_tile::DefaultGemm2DEpilogueProblem<AccDataType,
|
||||
ck_tile::DefaultGemm2DEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
pad_m,
|
||||
pad_n,
|
||||
kPadM,
|
||||
kPadN,
|
||||
WarpTileM,
|
||||
WarpTileN,
|
||||
WarpTileK,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
UniversalGemmProblem::TransposeC,
|
||||
true,
|
||||
memory_operation>>;
|
||||
"""
|
||||
|
||||
CSHUFFLE_EPILOGUE = """
|
||||
@@ -46,94 +50,68 @@ CSHUFFLE_EPILOGUE = """
|
||||
WarpTileM,
|
||||
WarpTileN,
|
||||
WarpTileK,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
"""
|
||||
|
||||
HOT_LOOP_FALSE = """
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
Run(ck_tile::bool_constant<false>{},
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
Run(ck_tile::bool_constant<false>{},
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
Run(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Num K loop must be larger than number of prefetech stages.");
|
||||
}
|
||||
"""
|
||||
|
||||
RUN_MEM = """
|
||||
if(tail_num == ck_tile::TailNumber::One)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
// Handle One and Full cases directly
|
||||
if (tail_num == ck_tile::TailNumber::One) {
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
} else if (tail_num == ck_tile::TailNumber::Full) {
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
// Variadic call using fold expression
|
||||
auto check_tail = [&](auto... TNs) {
|
||||
(try_run< BaseGemmPipeline, decltype(TNs)::value>(tail_num), ...);
|
||||
};
|
||||
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 2)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Two)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
|
||||
}
|
||||
|
||||
if(tail_num == ck_tile::TailNumber::Three)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
|
||||
}
|
||||
if(tail_num == ck_tile::TailNumber::Four)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{});
|
||||
}
|
||||
if(tail_num == ck_tile::TailNumber::Five)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{});
|
||||
}
|
||||
if(tail_num == ck_tile::TailNumber::Six)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{});
|
||||
}
|
||||
if(tail_num == ck_tile::TailNumber::Seven)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
|
||||
}
|
||||
throw std::runtime_error("The tile number is wrong! It should not exceed the prefetch stage numbers");
|
||||
}
|
||||
check_tail(
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{}
|
||||
);
|
||||
"""
|
||||
|
||||
RUN_COMPV3 = """
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
@@ -145,16 +123,17 @@ RUN_COMPV3 = """
|
||||
RUN_COMPV4 = """
|
||||
if(tail_num == ck_tile::TailNumber::Three)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
PIPELINE_MAP = {'mem' : ['ck_tile::BaseGemmPipelineAgBgCrMem', 'ck_tile::GemmPipelineAgBgCrMem'],
|
||||
'compv3' : ['ck_tile::BaseGemmPipelineAgBgCrCompV3', 'ck_tile::GemmPipelineAgBgCrCompV3'],
|
||||
'compv4' : ['ck_tile::BaseGemmPipelineAgBgCrCompV4', 'ck_tile::GemmPipelineAgBgCrCompV4']}
|
||||
@@ -163,10 +142,10 @@ SCHEDULER_MAP = {'interwave' : 'ck_tile::GemmPipelineScheduler::Interwave',
|
||||
'intrawave' : 'ck_tile::GemmPipelineScheduler::Intrawave'}
|
||||
|
||||
EPILOGUE_MAP = {'default' :DEFAULT_EPILOGUE,
|
||||
'cshuffle' : CSHUFFLE_EPILOGUE}
|
||||
'cshuffle' : CSHUFFLE_EPILOGUE}
|
||||
|
||||
HOT_LOOP_TRUE = {'mem' : RUN_MEM,
|
||||
'compv3' : RUN_COMPV3,
|
||||
'compv4' : RUN_COMPV4}
|
||||
'compv4' : RUN_COMPV4}
|
||||
|
||||
BOOL_MAP = lambda b_: {True: 'true', False: 'false'}[bool(b_)]
|
||||
@@ -11,187 +11,9 @@ import os
|
||||
import sys
|
||||
import itertools
|
||||
import copy
|
||||
<<<<<<< HEAD
|
||||
import logging
|
||||
from json_utils import *
|
||||
from codegen_utils import *
|
||||
=======
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
|
||||
DATA_TYPE_MAP = {'fp32' : 'float',
|
||||
'fp16' : 'ck_tile::half_t',
|
||||
'bf16' : 'ck_tile::bf16_t',
|
||||
'int8' : 'ck_tile::int8_t',
|
||||
'fp8' : 'ck_tile::fp8_t',
|
||||
'bf8' : 'ck_tile::bf8_t',
|
||||
'int4' : 'ck_tile::pk_int4_t'
|
||||
}
|
||||
|
||||
LAYOUT_MAP = {'r' : 'ck_tile::tensor_layout::gemm::RowMajor',
|
||||
'c' : 'ck_tile::tensor_layout::gemm::ColumnMajor'}
|
||||
|
||||
DEFAULT_EPILOGUE = """
|
||||
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
|
||||
ck_tile::DefaultGemm2DEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
kPadM,
|
||||
kPadN,
|
||||
WarpTileM,
|
||||
WarpTileN,
|
||||
WarpTileK,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
true,
|
||||
memory_operation>>;
|
||||
"""
|
||||
|
||||
CSHUFFLE_EPILOGUE = """
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
WarpM,
|
||||
WarpN,
|
||||
WarpTileM,
|
||||
WarpTileN,
|
||||
WarpTileK,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
"""
|
||||
HOT_LOOP_FALSE = """
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Num K loop must be larger than number of prefetech stages.");
|
||||
}
|
||||
"""
|
||||
RUN_MEM = """
|
||||
// Handle One and Full cases directly
|
||||
if (tail_num == ck_tile::TailNumber::One) {
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
|
||||
} else if (tail_num == ck_tile::TailNumber::Full) {
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
// Variadic call using fold expression
|
||||
auto check_tail = [&](auto... TNs) {
|
||||
(try_run< BaseGemmPipeline, decltype(TNs)::value>(tail_num), ...);
|
||||
};
|
||||
|
||||
check_tail(
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{}
|
||||
);
|
||||
"""
|
||||
|
||||
RUN_COMPV3 = """
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("The tail number is wrong. It should be Full, Odd, or Even.");
|
||||
}
|
||||
"""
|
||||
|
||||
RUN_COMPV4 = """
|
||||
if(tail_num == ck_tile::TailNumber::Three)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
PIPELINE_MAP = {'mem' : ['ck_tile::BaseGemmPipelineAgBgCrMem', 'ck_tile::GemmPipelineAgBgCrMem'],
|
||||
'compv3' : ['ck_tile::BaseGemmPipelineAgBgCrCompV3', 'ck_tile::GemmPipelineAgBgCrCompV3'],
|
||||
'compv4' : ['ck_tile::BaseGemmPipelineAgBgCrCompV4', 'ck_tile::GemmPipelineAgBgCrCompV4']}
|
||||
|
||||
SCHEDULER_MAP = {'interwave' : 'ck_tile::GemmPipelineScheduler::Interwave',
|
||||
'intrawave' : 'ck_tile::GemmPipelineScheduler::Intrawave'}
|
||||
|
||||
EPILOGUE_MAP = {'default' :DEFAULT_EPILOGUE,
|
||||
'cshuffle' : CSHUFFLE_EPILOGUE}
|
||||
|
||||
HOT_LOOP_TRUE = {'mem' : RUN_MEM,
|
||||
'compv3' : RUN_COMPV3,
|
||||
'compv4' : RUN_COMPV4}
|
||||
|
||||
|
||||
def BOOL_MAP(b_) -> str:
|
||||
if b_:
|
||||
return 'true'
|
||||
else:
|
||||
return 'false'
|
||||
|
||||
@dataclass
|
||||
class GemmConfig:
|
||||
def __init__(self, config_data):
|
||||
self.matrix_cfg : Dict[str, Any] = {}
|
||||
self.impl_cfg : Dict[str, Any] = {}
|
||||
for key, value in config_data.items():
|
||||
if key in ["datatype", "layout_a", "layout_b", "layout_c"]:
|
||||
self.matrix_cfg[key] = value
|
||||
else:
|
||||
self.impl_cfg[key] = value
|
||||
|
||||
@property
|
||||
def datatype(self) -> str:
|
||||
return self.matrix_cfg["datatype"]["values"][0]
|
||||
|
||||
@property
|
||||
def layouts(self) -> List[str]:
|
||||
return [
|
||||
self.matrix_cfg["layout_a"]["values"][0],
|
||||
self.matrix_cfg["layout_b"]["values"][0],
|
||||
self.matrix_cfg["layout_c"]["values"][0]
|
||||
]
|
||||
|
||||
>>>>>>> origin
|
||||
|
||||
class GemmCodeGenerator:
|
||||
def __init__(self, output_dir: str, user_provided_config: Optional[GemmConfig] = None):
|
||||
@@ -222,8 +44,7 @@ class GemmCodeGenerator:
|
||||
list_f.write(str(w_p / f"gemm_{trait}.hpp") + "\n")
|
||||
|
||||
def _generate_all_traits(self):
|
||||
params = ["pipeline", "epilogue", "scheduler", "pad_m", "pad_n", "pad_k"
|
||||
]
|
||||
params = ["pipeline", "epilogue", "scheduler", "pad_m", "pad_n", "pad_k"]
|
||||
|
||||
# To remove some unsupported combinations
|
||||
unsupported_combinations = {
|
||||
@@ -341,9 +162,9 @@ template <int TileM, int TileN, int TileK,
|
||||
int WarpTileM, int WarpTileN, int WarpTileK,
|
||||
bool structured_sparsity>
|
||||
struct GemmKernel {{
|
||||
static constexpr bool pad_m = {BOOL_MAP(pad_m)};
|
||||
static constexpr bool pad_n = {BOOL_MAP(pad_n)};
|
||||
static constexpr bool pad_k = {BOOL_MAP(pad_k)};
|
||||
static constexpr bool kPadM = {BOOL_MAP(pad_m)};
|
||||
static constexpr bool kPadN = {BOOL_MAP(pad_n)};
|
||||
static constexpr bool kPadK = {BOOL_MAP(pad_k)};
|
||||
|
||||
static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) {{
|
||||
static constexpr bool permuteA = false;
|
||||
@@ -369,7 +190,7 @@ struct GemmKernel {{
|
||||
TileParitionerM01>;
|
||||
|
||||
using Traits =
|
||||
ck_tile::TileGemmTraits<pad_m, pad_n, pad_k, ALayout, BLayout, CLayout>;
|
||||
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
|
||||
@@ -539,13 +360,13 @@ struct GemmDispatcher {
|
||||
|
||||
|
||||
for trait in self.all_trait_names:
|
||||
content += f""" kernel_map["{trait}"] = [=]((Profiler& profiler,
|
||||
content += f""" kernel_map["{trait}"] = [=]( GemmProfiler& profiler,
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
int verify,
|
||||
ck_tile::GemmHostArgs& args,
|
||||
const ck_tile::stream_config& s) {{
|
||||
const ck_tile::stream_config& stream) {{
|
||||
if(structured_sparsity){{ // SMFMA"""
|
||||
for tile in tile_params:
|
||||
# Check if we have valid tile/warp combinations
|
||||
@@ -557,7 +378,7 @@ struct GemmDispatcher {
|
||||
((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or
|
||||
(tile[6] == 16 and tile[7] == 16 and tile[8] == 32))
|
||||
content += f"""
|
||||
profiler.run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);"""
|
||||
profiler.benchmark_kernel<{trait}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);"""
|
||||
content += f"""
|
||||
}} else {{"""
|
||||
for tile in tile_params:
|
||||
@@ -567,7 +388,7 @@ struct GemmDispatcher {
|
||||
((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]):
|
||||
continue
|
||||
content += f"""
|
||||
profiler.run_kernel<{trait}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(False)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);"""
|
||||
profiler.benchmark_kernel<{trait}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(False)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);"""
|
||||
content += f"""
|
||||
}}
|
||||
}};\n"""
|
||||
@@ -579,16 +400,17 @@ struct GemmDispatcher {
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
int verify,
|
||||
int metric,
|
||||
bool structured_sparsity,
|
||||
const KernelTraits& trait,
|
||||
ck_tile::GemmHostArgs& gemm_args,
|
||||
const ck_tile::stream_config& s) {
|
||||
init();
|
||||
const ck_tile::stream_config& stream) {
|
||||
init(structured_sparsity);
|
||||
const std::string key = assemble_key(trait);
|
||||
auto& kernel_map = get_kernel_map();
|
||||
auto& profiler = GemmProfiler::instance();
|
||||
if(auto it = kernel_map.find(key); it != kernel_map.end()) {
|
||||
it->second(
|
||||
profiler, c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, gemm_args, s);
|
||||
profiler, c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, gemm_args, stream);
|
||||
profiler.select_best_instance(static_cast<Metric>(metric));
|
||||
return;
|
||||
}
|
||||
@@ -662,7 +484,7 @@ if __name__ == "__main__":
|
||||
"-l", "--list_blobs", action = 'store_true', help="List all kernel instances to file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-g", "--gen_blobs", action = 'store_true', help="Generate all kernels instances into different files"
|
||||
"-g", "--gen_blobs", action = 'store_true', help="Generate all kernel instances into different files"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
Reference in New Issue
Block a user