mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
[CK_TILE] Refine fp8 support in flatmm (#2239)
* [CK_TILE] Refine fp8 in flatmm
1. Replace USING_MFMA_16x16x32 & USING_MFMA_16x16x32 with constexpr
2. Add an additional const check to avoid build error in HotLoopScheduler
3. Refine shuffleb to support both tile 32x32 and 16x16
4. Support command option -init
5. Move Gemm warp defintion to a separate struct
* fix clang format
* fix clang format
* keep default bhavior unchanged (warp tile = 16x16)
* fix tile engine build error
* fix a typo in codegen_utils.py
* address review comments
* address review comments
---------
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
[ROCm/composable_kernel commit: 37e1a27537]
This commit is contained in:
@@ -44,9 +44,12 @@ CSHUFFLE_EPILOGUE = """
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
|
||||
@@ -193,7 +193,7 @@ struct GemmKernel {{
|
||||
static constexpr bool kPadN = {pad_n};
|
||||
static constexpr bool kPadK = {pad_k};
|
||||
|
||||
static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{
|
||||
static float launch(ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream) {{
|
||||
static constexpr bool permuteA = false;
|
||||
static constexpr bool permuteB = false;
|
||||
static constexpr bool DoubleSmemBuffer ={"true" if pipeline == "compv4" else "false"};
|
||||
@@ -306,7 +306,7 @@ struct GemmKernel {{
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.c_ptr, 0, args.M * args.N * sizeof(CDataType), stream.stream_id_));
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream.stream_id_));
|
||||
}};
|
||||
ave_time = ck_tile::launch_kernel_preprocess(
|
||||
stream,
|
||||
@@ -570,12 +570,13 @@ struct GemmDispatcher {
|
||||
// Use a static local variable
|
||||
static std::unordered_map<
|
||||
std::string,
|
||||
std::vector<std::function<std::tuple<std::string, float>(ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>>
|
||||
std::vector<std::function<std::tuple<std::string, float>(ck_tile::GemmHostArgs<>&, const ck_tile::stream_config&)>>>
|
||||
kernel_map;
|
||||
return kernel_map;
|
||||
}
|
||||
|
||||
static void init(bool structured_sparsity) {
|
||||
ck_tile::ignore = structured_sparsity;
|
||||
auto& kernel_map = get_kernel_map();
|
||||
if(!kernel_map.empty()) return;
|
||||
\n"""
|
||||
@@ -586,7 +587,7 @@ struct GemmDispatcher {
|
||||
for j in range(len(tile)):
|
||||
tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k = tile[
|
||||
j]
|
||||
content += f"""[=](ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ """
|
||||
content += f"""[=](ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream) {{ """
|
||||
content += f"""
|
||||
if(structured_sparsity){{ // SMFMA"""
|
||||
sparse = self.config.problem.datatype_map['matrix_a'] == 'fp16' and \
|
||||
@@ -615,7 +616,7 @@ struct GemmDispatcher {
|
||||
content += """ }
|
||||
|
||||
template <typename Kernel>
|
||||
static std::tuple<std::string, float> run_kernel(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream)
|
||||
static std::tuple<std::string, float> run_kernel(ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream)
|
||||
{
|
||||
std::string name = Kernel::get_name();
|
||||
float avg_time = Kernel::launch(args, stream);
|
||||
|
||||
@@ -22,7 +22,7 @@ class GemmProfiler
|
||||
|
||||
void benchmark(GemmProblem& gemm_problem,
|
||||
std::vector<std::function<std::tuple<std::string, float>(
|
||||
ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>& callables)
|
||||
ck_tile::GemmHostArgs<>&, const ck_tile::stream_config&)>>& callables)
|
||||
{
|
||||
const ALayout layout_a = ALayout{};
|
||||
const BLayout layout_b = BLayout{};
|
||||
@@ -89,7 +89,7 @@ class GemmProfiler
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
ck_tile::GemmHostArgs gemm_args;
|
||||
ck_tile::GemmHostArgs<> gemm_args;
|
||||
gemm_args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
gemm_args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
gemm_args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
|
||||
Reference in New Issue
Block a user