merge M grouped flatmm

This commit is contained in:
lalala-sh
2025-07-30 07:55:09 +00:00
parent 1b493fac62
commit c4aa2fef46
9 changed files with 1797 additions and 4 deletions

View File

@@ -1,4 +1,5 @@
add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp)
add_executable(tile_example_grouped_flatmm EXCLUDE_FROM_ALL grouped_flatmm.cpp)
set(EXAMPLE_FLATMM_COMPILE_OPTIONS)
@@ -11,3 +12,4 @@ list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS --save-temps -Wno-nrvo)
#list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm -enable-noalias-to-md-conversion=0")
#list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --disable-schedmodel-in-sched-mi=1 -mllvm --amdgpu-sched-strategy=gcn-iterative-max-occupancy-experimental -mllvm --misched-bottomup=1")
target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
target_compile_options(tile_example_grouped_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})

View File

@@ -0,0 +1,382 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include "flatmm_basic.hpp"
#include "ck_tile/host.hpp"
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("Ms", "512,256,1024", "m dimension")
.insert("Ns", "512,512,512", "n dimension")
.insert("Ks", "1024,1024,512", "k dimension")
.insert("group_count", "3", "group count")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Row by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("stride_a", "0", "Tensor A stride")
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8")
.insert("mode", "general", "grouped gemm mode: [general | contiguous], general by default")
.insert("wave_tile", "16", "only support 16(16x16) or 32(32x32)")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("scale", "0", "0:without scale, 1:per-token/channel scale, only for fp8/bf8")
.insert("warp_tile",
"0",
"0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
bool persistent,
typename CDEElementWise,
typename KernelArguments>
float grouped_flatmm(const KernelArguments& args, const ck_tile::stream_config& s)
{
using CodegenFlatmmShape = ck_tile::TileGemmShape<
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile>>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<CodegenFlatmmShape,
FlatmmConfig::TileParitionerGroupNum,
FlatmmConfig::TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
FlatmmConfig::kPadN,
FlatmmConfig::kPadK,
ALayout,
BLayout,
ELayout,
FlatmmConfig::NumWaveGroups>;
using CodegenGemmTraits = ck_tile::TileGemmUniversalTraits<FlatmmConfig::kPadM,
FlatmmConfig::kPadN,
FlatmmConfig::kPadK,
FlatmmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
ELayout,
FlatmmConfig::TransposeC,
FlatmmConfig::UseStructuredSparsity,
persistent,
FlatmmConfig::NumWaveGroups,
true>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenFlatmmShape, Traits>;
using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;
const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = FlatmmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using CodegenFlatmmPipeline =
ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDatatype,
AccDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
FlatmmConfig::M_Warp,
FlatmmConfig::N_Warp,
FlatmmConfig::M_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile,
CodegenPipelineProblem::TransposeC,
memory_operation,
FlatmmConfig::NumWaveGroups>>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel =
ck_tile::GroupedFlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(kargs);
constexpr dim3 blocks = Kernel::BlockSize();
// if(!Kernel::IsSupportedArgument(kargs))
// {
// throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
// }
// if(s.flush_cache_)
// {
// std::cout << "Flushing cache..." << std::endl;
// static constexpr ck_tile::index_t APackedSize =
// std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
// static constexpr ck_tile::index_t BPackedSize =
// std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
// ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
// args.group_count * args.M, args.K, args.stride_A, is_row_major(ALayout{})));
// ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
// args.K, args.group_count * args.N, args.stride_B, is_row_major(BLayout{})));
// auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
// auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
// ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
// kargs.a_ptr, kargs.b_shuffle_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
// rotating_mem.Print();
// auto run_flush_cache = [&]() {
// // flush icache
// ck_tile::flush_icache();
// // rotating mem
// rotating_mem.Next();
// // clear c mem
// if(args.k_batch > 1)
// hipGetErrorString(hipMemsetAsync(
// args.e_ptr, 0, args.group_count * args.M * args.N * sizeof(CDataType), s.stream_id_));
// };
// ave_time = ck_tile::launch_kernel_preprocess(
// s,
// run_flush_cache,
// ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
// Kernel{}, grids, blocks, 0, kargs));
// }
// else
// {
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
// }
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
}
#include "run_grouped_flatmm_example.inc"
template<template <typename PreType> typename FlatmmConfig>
int run_grouped_flatmm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
std::string data_type = arg_parser.get_str("prec");
std::string mode = arg_parser.get_str("mode");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
int scale_opt = arg_parser.get_int("scale");
if(a_layout == "R" && b_layout == "C")
{
if(mode == "general")
{
// if(data_type == "fp16")
// {
// run_grouped_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(
// argc, argv, Row{}, Col{}, Row{});
// }
// else if(data_type == "bf16")
// {
// run_grouped_flatmm_example_with_layouts<ck_tile::bf16_t, FlatmmConfig<ck_tile::bf16_t>>(
// argc, argv, Row{}, Col{}, Row{});
// }
// else if(data_type == "fp8")
// {
// run_grouped_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>>(
// argc, argv, Row{}, Col{}, Row{});
// }
// else if(data_type == "bf8")
// {
// run_grouped_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
// argc, argv, Row{}, Col{}, Row{});
// }
// else
// {
// throw std::runtime_error("Unsupported data_type!");
// }
}
else if(mode == "contiguous")
{
if(data_type == "fp16")
{
run_contiguous_grouped_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf16")
{
run_contiguous_grouped_flatmm_example_with_layouts<ck_tile::bf16_t, FlatmmConfig<ck_tile::bf16_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "fp8")
{
run_contiguous_grouped_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf8")
{
run_contiguous_grouped_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data_type!");
}
}
else if(mode == "masked")
{
if(data_type == "fp16")
{
run_masked_grouped_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf16")
{
run_masked_grouped_flatmm_example_with_layouts<ck_tile::bf16_t, FlatmmConfig<ck_tile::bf16_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "fp8")
{
run_masked_grouped_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf8")
{
run_masked_grouped_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data_type!");
}
}
else
{
throw std::runtime_error("Unsupported mode!");
}
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
return -1;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return EXIT_FAILURE;
try
{
int warp_tile = arg_parser.get_int("warp_tile");
if(warp_tile == 0)
{
return !run_grouped_flatmm_example<FlatmmConfig16>(argc, argv);
}
// else if(warp_tile == 1)
// {
// return !run_grouped_flatmm_example<FlatmmConfig32>(argc, argv);
// }
// else if(warp_tile == 2)
// {
// return !run_grouped_flatmm_example<FlatmmConfig16_950>(argc, argv);
// }
// else
// {
// return !run_grouped_flatmm_example<FlatmmConfig32_950>(argc, argv);
// }
}
catch(const std::runtime_error& e)
{
std::cerr << "Runtime error: " << e.what() << '\n';
return EXIT_FAILURE;
}
}

View File

@@ -0,0 +1,935 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
// mfma_type, 0:32x32, 1:16x16
template <typename FlatmmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4;
ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
k_ / FlatmmConfig::K_Warp_Tile,
divisor,
FlatmmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
// template <typename FlatmmConfig,
// typename ADataType,
// typename BDataType,
// typename DsDatatype,
// typename AccDataType,
// typename CDataType,
// typename ALayout,
// typename BLayout,
// typename DsLayout,
// typename CLayout,
// typename ScaleM,
// typename ScaleN,
// typename CDEElementWise = ck_tile::element_wise::PassThrough>
// float invoke_gemm(int n_warmup, int n_repeat, const ck_tile::GroupedFlatmmHostArgs<ScaleM, ScaleN>& args)
// {
// float ave_time =
// grouped_flatmm<FlatmmConfig,
// ADataType,
// BDataType,
// DsDatatype,
// AccDataType,
// CDataType,
// ALayout,
// BLayout,
// DsLayout,
// CLayout,
// false,
// CDEElementWise>(
// args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
// std::string op_name{"Grouped Gemm"};
// std::size_t flop = 0, num_btype = 0;
// for(int j = 0; j < args.group_count; ++j)
// {
// flop += std::size_t(2) * args.M[j] * args.N[j] * args.K[j];
// num_btype += sizeof(ADataType) * args.M[j] * args.K[j] +
// sizeof(BDataType) * args.K[j] * args.N[j] +
// sizeof(CDataType) * args.M[j] * args.N[j];
// }
// float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
// float gb_per_sec = num_btype / 1.E6 / ave_time;
// std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
// << gb_per_sec << " GB/s, " << op_name << std::endl;
// return ave_time;
// }
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename ScaleM,
typename ScaleN,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_gemm(int n_warmup, int n_repeat, const ck_tile::ContiguousGroupedFlatmmHostArgs<ScaleM, ScaleN>& args)
{
float ave_time =
grouped_flatmm<FlatmmConfig,
ADataType,
BDataType,
DsDatatype,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
false,
CDEElementWise>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::string op_name{"Grouped Gemm"};
std::size_t flop = std::size_t(2) * args.M * args.N * args.K;
std::size_t num_byte = sizeof(ADataType) * args.M * args.K +
sizeof(BDataType) * args.N * args.K +
sizeof(CDataType) * args.M * args.N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
return ave_time;
}
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename ScaleM,
typename ScaleN,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_gemm(int n_warmup, int n_repeat, int val_m, const ck_tile::MaskedGroupedFlatmmHostArgs<ScaleM, ScaleN>& args)
{
float ave_time =
grouped_flatmm<FlatmmConfig,
ADataType,
BDataType,
DsDatatype,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
false,
CDEElementWise>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::string op_name{"Grouped Gemm"};
std::size_t flop = std::size_t(2) * val_m * args.N * args.K;
std::size_t num_byte = sizeof(ADataType) * val_m * args.K +
sizeof(BDataType) * args.N * args.K * args.group_count +
sizeof(CDataType) * val_m * args.N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
return ave_time;
}
// template <typename PrecType,
// typename FlatmmConfig,
// int ScaleGranularityM = -1,
// int ScaleGranularityN = -1,
// typename ALayout,
// typename BLayout,
// typename CLayout>
// int run_grouped_flatmm_example_with_layouts(int argc,
// char* argv[],
// const ALayout a_layout = ALayout{},
// const BLayout b_layout = BLayout{},
// [[maybe_unused]] const CLayout c_layout = CLayout{})
// {
// auto [result, arg_parser] = create_args(argc, argv);
// if(!result)
// {
// return -1;
// };
// using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
// using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
// using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
// using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
// const int group_count = arg_parser.get_int("group_count");
// const int repeat = arg_parser.get_int("repeat");
// const int warmup = arg_parser.get_int("warmup");
// std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
// std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
// std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
// std::vector<ck_tile::index_t> stride_As;
// std::vector<ck_tile::index_t> stride_Bs;
// std::vector<ck_tile::index_t> stride_Cs;
// ck_tile::index_t kbatch = arg_parser.get_int("split_k");
// std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
// std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors;
// std::vector<ck_tile::HostTensor<CDataType>> c_m_n_tensors;
// std::vector<ck_tile::HostTensor<AccDataType>> per_token_scales;
// std::vector<ck_tile::HostTensor<AccDataType>> per_channel_scales;
// std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_m_k_dev_buf;
// std::vector<std::unique_ptr<ck_tile::DeviceMem>> b_shfl_dev_buf;
// std::vector<std::unique_ptr<ck_tile::DeviceMem>> c_m_n_dev_buf;
// std::vector<std::unique_ptr<ck_tile::DeviceMem>> per_token_scales_dev_buf;
// std::vector<std::unique_ptr<ck_tile::DeviceMem>> per_channel_scales_dev_buf;
// std::vector<void*> group_a_ptrs;
// std::vector<void*> group_b_ptrs;
// std::vector<void*> group_c_ptrs;
// std::vector<ck_tile::FlatmmScalePointer<ScaleGranularityM>> group_scale_a_ptrs;
// std::vector<ck_tile::FlatmmScalePointer<ScaleGranularityN>> group_scale_b_ptrs;
// if(!(int(Ms.size()) == group_count && int(Ns.size()) == group_count &&
// int(Ks.size()) == group_count))
// {
// std::cout << "Please check the input data." << std::endl;
// for(int i = 0; i < group_count; i++)
// {
// Ms.push_back(256 + 256 * i);
// Ns.push_back(128 + 128 * i);
// Ks.push_back(512 + 512 * i);
// }
// }
// for(int i = 0; i < group_count; ++i)
// {
// const ck_tile::index_t M = Ms[i];
// const ck_tile::index_t N = Ns[i];
// const ck_tile::index_t K = Ks[i];
// stride_As.push_back(ck_tile::get_default_stride(M, K, 0, is_row_major(a_layout)));
// stride_Bs.push_back(ck_tile::get_default_stride(K, N, 0, is_row_major(b_layout)));
// stride_Cs.push_back(ck_tile::get_default_stride(M, N, 0, is_row_major(c_layout)));
// a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
// ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout))));
// b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
// ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(b_layout))));
// c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
// ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(c_layout))));
// per_token_scales.push_back(ck_tile::HostTensorDescriptor({M}, {1}));
// per_channel_scales.push_back(ck_tile::HostTensorDescriptor({N}, {1}));
// std::cout << "gemm[" << i << "]"
// << " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc
// << " c_m_n: " << c_m_n_tensors[i].mDesc << std::endl;
// ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
// ck_tile::FillUniformDistribution<BDataType>{-4.f, 4.f}(b_k_n_tensors[i]);
// ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(per_token_scales[i]);
// ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(per_channel_scales[i]);
// ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<FlatmmConfig, BDataType>(b_k_n_tensors[i]);
// a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
// a_m_k_tensors[i].get_element_space_size_in_bytes()));
// b_shfl_dev_buf.push_back(
// std::make_unique<ck_tile::DeviceMem>(b_shuffle_host.get_element_space_size_in_bytes()));
// c_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
// c_m_n_tensors[i].get_element_space_size_in_bytes()));
// a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
// b_shfl_dev_buf[i]->ToDevice(b_shuffle_host.data());
// c_m_n_tensors[i].SetZero();
// c_m_n_dev_buf[i]->SetZero();
// per_token_scales_dev_buf[i]->ToDevice(per_token_scales[i].data());
// per_channel_scales_dev_buf[i]->ToDevice(per_channel_scales[i].data());
// auto per_token_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM>{
// static_cast<float*>(per_token_scales_dev_buf[i]->GetDeviceBuffer())};
// auto per_channel_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN>{
// static_cast<float*>(per_channel_scales_dev_buf[i]->GetDeviceBuffer())};
// group_a_ptrs.push_back(a_m_k_dev_buf[i]->GetDeviceBuffer());
// group_b_ptrs.push_back(b_shfl_dev_buf[i]->GetDeviceBuffer());
// group_c_ptrs.push_back(c_m_n_dev_buf[i]->GetDeviceBuffer());
// group_scale_a_ptrs.push_back(per_token_scale_dev_ptr);
// group_scale_b_ptrs.push_back(per_channel_scale_dev_ptr);
// }
// ck_tile::DeviceMem group_m_dev_buf(group_count * sizeof(ck_tile::index_t));
// ck_tile::DeviceMem group_n_dev_buf(group_count * sizeof(ck_tile::index_t));
// ck_tile::DeviceMem group_k_dev_buf(group_count * sizeof(ck_tile::index_t));
// ck_tile::DeviceMem group_a_ptrs_dev_buf(group_count * sizeof(void*));
// ck_tile::DeviceMem group_b_ptrs_dev_buf(group_count * sizeof(void*));
// ck_tile::DeviceMem group_c_ptrs_dev_buf(group_count * sizeof(void*));
// ck_tile::DeviceMem group_stride_a_dev_buf(group_count * sizeof(ck_tile::index_t));
// ck_tile::DeviceMem group_stride_b_dev_buf(group_count * sizeof(ck_tile::index_t));
// ck_tile::DeviceMem group_stride_c_dev_buf(group_count * sizeof(ck_tile::index_t));
// ck_tile::DeviceMem group_scale_a_ptrs_dev_buf(group_count * sizeof(void*));
// ck_tile::DeviceMem group_scale_b_ptrs_dev_buf(group_count * sizeof(void*));
// group_m_dev_buf.ToDevice(Ms.data());
// group_n_dev_buf.ToDevice(Ns.data());
// group_k_dev_buf.ToDevice(Ks.data());
// group_a_ptrs_dev_buf.ToDevice(group_a_ptrs.data());
// group_b_ptrs_dev_buf.ToDevice(group_b_ptrs.data());
// group_c_ptrs_dev_buf.ToDevice(group_c_ptrs.data());
// group_stride_a_dev_buf.ToDevice(stride_As.data());
// group_stride_b_dev_buf.ToDevice(stride_Bs.data());
// group_stride_c_dev_buf.ToDevice(stride_Cs.data());
// group_scale_a_ptrs_dev_buf.ToDevice(group_scale_a_ptrs.data());
// group_scale_b_ptrs_dev_buf.ToDevice(group_scale_b_ptrs.data());
// using ScaleAPtr = decltype(group_scale_a_ptrs[0]);
// using ScaleBPtr = decltype(group_scale_b_ptrs[0]);
// ck_tile::GroupedFlatmmHostArgs<ScaleAPtr, ScaleBPtr> kernal_args{
// group_count,
// static_cast<ck_tile::index_t*>(group_m_dev_buf.GetDeviceBuffer()),
// static_cast<ck_tile::index_t*>(group_n_dev_buf.GetDeviceBuffer()),
// static_cast<ck_tile::index_t*>(group_k_dev_buf.GetDeviceBuffer()),
// static_cast<const void**>(group_a_ptrs_dev_buf.GetDeviceBuffer()),
// static_cast<ck_tile::index_t*>(group_stride_a_dev_buf.GetDeviceBuffer()),
// static_cast<const void**>(group_b_ptrs_dev_buf.GetDeviceBuffer()),
// static_cast<ck_tile::index_t*>(group_stride_b_dev_buf.GetDeviceBuffer()),
// static_cast<void**>(group_c_ptrs_dev_buf.GetDeviceBuffer()),
// static_cast<ck_tile::index_t*>(group_stride_c_dev_buf.GetDeviceBuffer()),
// kbatch,
// static_cast<ScaleAPtr*>(per_token_scale_dev_buf.GetDeviceBuffer()),
// static_cast<ScaleBPtr*>(per_channel_scale_dev_buf.GetDeviceBuffer())
// };
// invoke_gemm<FlatmmConfig,
// ADataType,
// BDataType,
// ck_tile::tuple<>,
// AccDataType,
// CDataType,
// ALayout,
// BLayout,
// ck_tile::tuple<>,
// CLayout,
// ScaleAPtr,
// ScaleBPtr>(
// warmup, repeat, kernal_args);
// for(int i = 0; i < group_count; i++)
// {
// c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data());
// }
// bool pass{true};
// if(arg_parser.get_int("v") == 1)
// {
// for(int i = 0; i < group_count; ++i)
// {
// ck_tile::HostTensor<CDataType> c_m_n_host_ref(ck_tile::host_tensor_descriptor(
// Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{})));
// c_m_n_host_ref.SetZero();
// ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
// a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
// const float max_accumulated_value =
// *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
// const auto rtol_atol =
// calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
// Ks[i], 1 /*kbatch*/, max_accumulated_value);
// pass &= ck_tile::check_err(c_m_n_tensors[i],
// c_m_n_host_ref,
// "Error: Incorrect results!",
// rtol_atol.at(ck_tile::number<0>{}),
// rtol_atol.at(ck_tile::number<1>{}));
// std::cout << "gemm[" << i
// << "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
// << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
// << std::endl;
// }
// std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
// }
// else if(arg_parser.get_int("v") == 2)
// {
// for(int i = 0; i < group_count; ++i)
// {
// ck_tile::index_t M = Ms[i];
// ck_tile::index_t N = Ns[i];
// ck_tile::index_t K = Ks[i];
// ck_tile::index_t stride_A = stride_As[i];
// ck_tile::index_t stride_B = stride_Bs[i];
// ck_tile::index_t stride_C = stride_Cs[i];
// ADataType* d_A;
// BDataType* d_B;
// CDataType* d_C;
// ck_tile::hip_check_error(hipMalloc(&d_A, M * K * sizeof(ADataType)));
// ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType)));
// ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType)));
// ck_tile::hip_check_error(hipMemcpy(
// d_A, a_m_k_tensors[i].data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice));
// ck_tile::hip_check_error(hipMemcpy(
// d_B, b_k_n_tensors[i].data(), N * K * sizeof(BDataType), hipMemcpyHostToDevice));
// ck_tile::reference_gemm_gpu<ADataType,
// BDataType,
// AccDataType,
// CDataType,
// ALayout,
// BLayout,
// CLayout>(
// d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
// ck_tile::HostTensor<CDataType> c_gpu_ref_host(
// ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
// ck_tile::hip_check_error(hipMemcpy(
// c_gpu_ref_host.data(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost));
// ck_tile::hip_check_error(hipFree(d_A));
// ck_tile::hip_check_error(hipFree(d_B));
// ck_tile::hip_check_error(hipFree(d_C));
// const float max_accumulated_value =
// *std::max_element(c_gpu_ref_host.mData.begin(), c_gpu_ref_host.mData.end());
// const auto rtol_atol =
// calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
// K, kbatch, max_accumulated_value);
// float rtol = 1e-3;
// float atol = 1e-3;
// pass = ck_tile::check_err(
// c_m_n_tensors[i], c_gpu_ref_host, "Error: Incorrect results!", rtol, atol);
// std::cout << "gemm[" << i << "]\nRelative error threshold: " << rtol
// << " Absolute error threshold: " << atol << std::endl;
// std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail")
// << std::endl;
// }
// }
// return pass;
// }
template <typename PrecType,
typename FlatmmConfig,
int ScaleGranularityM = -1,
int ScaleGranularityN = -1,
typename ALayout,
typename BLayout,
typename CLayout>
int run_contiguous_grouped_flatmm_example_with_layouts(
int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
};
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
constexpr int BlockM = FlatmmConfig::M_Tile;
const int group_count = arg_parser.get_int("group_count");
const int repeat = arg_parser.get_int("repeat");
const int warmup = arg_parser.get_int("warmup");
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
if(!(int(Ms.size()) == group_count))
{
std::cout << "Please check the input data." << std::endl;
// padding additional Ms if needed
for(int i = 0; i < group_count; i++)
{
Ms.push_back(256 + 64 * i);
}
}
ck_tile::index_t M =
std::reduce(Ms.begin(), Ms.begin() + group_count, 0, [](auto acc, auto group_m) {
// round up to the multiple of BlockM
return acc + (group_m + BlockM - 1) / BlockM * BlockM;
});
std::cout << "Total M: " << M << std::endl;
ck_tile::index_t N = Ns[0];
ck_tile::index_t K = Ks[0];
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
ck_tile::index_t stride_A = 0;
ck_tile::index_t stride_B = 0;
ck_tile::index_t stride_C = 0;
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N * group_count, stride_B, is_row_major(b_layout));
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(c_layout));
ck_tile::HostTensor<ADataType> a_m_k_tensor(
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<BDataType> b_k_n_tensor(ck_tile::HostTensor<BDataType>(
ck_tile::host_tensor_descriptor(K, N * group_count, stride_B, is_row_major(b_layout))));
ck_tile::HostTensor<CDataType> c_m_n_tensor(ck_tile::HostTensor<CDataType>(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(c_layout))));
ck_tile::HostTensor<AccDataType> per_token_scale(ck_tile::HostTensorDescriptor({M}, {1}));
ck_tile::HostTensor<AccDataType> per_channel_scale(ck_tile::HostTensorDescriptor({N}, {1}));
std::vector<ck_tile::index_t> m_indices(M);
int indices_fill_start = 0;
for(int i = 0; i < group_count; ++i)
{
int group_m = Ms[i];
int padded_group_m = (group_m + BlockM - 1) / BlockM * BlockM;
for(int j = 0; j < padded_group_m; j++)
{
m_indices[indices_fill_start + j] = j < group_m ? i : -1; // -1 for padding
}
indices_fill_start += padded_group_m;
}
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensor);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n_tensor);
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_token_scale);
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_channel_scale);
constexpr int N_Warp_Tile = FlatmmConfig::N_Warp_Tile;
assert(N % N_Warp_Tile == 0 &&
"N must be divisible by N_Warp_Tile for contiguous grouped gemm");
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<FlatmmConfig, BDataType>(b_k_n_tensor);
std::unique_ptr<ck_tile::DeviceMem> a_m_k_dev_buf(
std::make_unique<ck_tile::DeviceMem>(a_m_k_tensor.get_element_space_size_in_bytes()));
std::unique_ptr<ck_tile::DeviceMem> b_shfl_dev_buf(
std::make_unique<ck_tile::DeviceMem>(b_shuffle_host.get_element_space_size_in_bytes()));
std::unique_ptr<ck_tile::DeviceMem> c_m_n_dev_buf(
std::make_unique<ck_tile::DeviceMem>(c_m_n_tensor.get_element_space_size_in_bytes()));
ck_tile::DeviceMem per_token_scale_dev_buf(per_token_scale.get_element_space_size_in_bytes());
ck_tile::DeviceMem per_channel_scale_dev_buf(
per_channel_scale.get_element_space_size_in_bytes());
c_m_n_dev_buf->SetZero();
ck_tile::DeviceMem m_indices_dev_buf(M * sizeof(ck_tile::index_t));
m_indices_dev_buf.ToDevice(m_indices.data());
a_m_k_dev_buf->ToDevice(a_m_k_tensor.data());
b_shfl_dev_buf->ToDevice(b_shuffle_host.data());
per_token_scale_dev_buf.ToDevice(per_token_scale.data());
per_channel_scale_dev_buf.ToDevice(per_channel_scale.data());
auto per_token_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM>{
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer())};
auto per_channel_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN>{
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())};
ck_tile::ContiguousGroupedFlatmmHostArgs<decltype(per_token_scale_dev_ptr), decltype(per_channel_scale_dev_ptr)> kernal_args{
static_cast<ck_tile::index_t*>(m_indices_dev_buf.GetDeviceBuffer()),
M,
N,
K,
a_m_k_dev_buf->GetDeviceBuffer(),
stride_A,
b_shfl_dev_buf->GetDeviceBuffer(),
stride_B,
{},{},
c_m_n_dev_buf->GetDeviceBuffer(),
stride_C,
kbatch,
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer()),
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())
};
invoke_gemm<FlatmmConfig,
ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout,
decltype(per_token_scale_dev_ptr),
decltype(per_channel_scale_dev_ptr)>(
warmup, repeat, kernal_args);
c_m_n_dev_buf->FromDevice(c_m_n_tensor.data());
bool pass{true};
if(arg_parser.get_int("v") == 1)
{
throw std::runtime_error(
"Not support v=1 host verification in contiguous grouped gemm, use "
"v=2 device verification instead");
}
else if(arg_parser.get_int("v") == 2)
{
BDataType* d_B;
CDataType* d_C;
ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType)));
ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType)));
ck_tile::hip_check_error(hipMemset(d_C, 0, M * N * sizeof(CDataType)));
ck_tile::HostTensor<CDataType> c_gpu_ref_host(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
ck_tile::index_t acc_m = 0;
for(int i = 0; i < group_count; ++i)
{
ck_tile::index_t padded_M = (Ms[i] + BlockM - 1) / BlockM * BlockM;
ck_tile::hip_check_error(hipMemcpy(d_B,
b_k_n_tensor.data() + i * N * K,
N * K * sizeof(BDataType),
hipMemcpyHostToDevice));
ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(
static_cast<ADataType*>(a_m_k_dev_buf->GetDeviceBuffer()) + acc_m * K,
d_B,
d_C + acc_m * N,
padded_M,
N,
K,
stride_A,
stride_B,
stride_C);
acc_m += padded_M;
}
ck_tile::hip_check_error(hipMemcpy(
c_gpu_ref_host.data(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost));
ck_tile::hip_check_error(hipFree(d_B));
ck_tile::hip_check_error(hipFree(d_C));
float rtol = 1e-3;
float atol = 1e-3;
pass = ck_tile::check_err(
c_m_n_tensor, c_gpu_ref_host, "Error: Incorrect results!", rtol, atol);
std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
<< std::endl;
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
}
return pass;
}
template <typename PrecType,
typename FlatmmConfig,
int ScaleGranularityM = -1,
int ScaleGranularityN = -1,
typename ALayout,
typename BLayout,
typename CLayout>
int run_masked_grouped_flatmm_example_with_layouts(
int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
};
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
constexpr int BlockM = FlatmmConfig::M_Tile;
const int group_count = arg_parser.get_int("group_count");
const int repeat = arg_parser.get_int("repeat");
const int warmup = arg_parser.get_int("warmup");
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
if(!(int(Ms.size()) == group_count))
{
std::cout << "Please check the input data." << std::endl;
// padding additional Ms if needed
for(int i = 0; i < group_count; i++)
{
Ms.push_back(256 + 64 * i);
}
}
ck_tile::index_t M = 4096; // Ms[0];
ck_tile::index_t N = Ns[0];
ck_tile::index_t K = Ks[0];
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
ck_tile::index_t stride_A = K;
ck_tile::index_t stride_B = K;
ck_tile::index_t stride_C = N;
stride_A = ck_tile::get_default_stride(group_count * M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N * group_count, stride_B, is_row_major(b_layout));
stride_C = ck_tile::get_default_stride(group_count * M, N, stride_C, is_row_major(c_layout));
ck_tile::HostTensor<ADataType> a_m_k_tensor(
ck_tile::host_tensor_descriptor(group_count * M, K, stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<BDataType> b_k_n_tensor(ck_tile::HostTensor<BDataType>(
ck_tile::host_tensor_descriptor(K, N * group_count, stride_B, is_row_major(b_layout))));
ck_tile::HostTensor<CDataType> c_m_n_tensor(ck_tile::HostTensor<CDataType>(
ck_tile::host_tensor_descriptor(group_count * M, N, stride_C, is_row_major(c_layout))));
ck_tile::HostTensor<AccDataType> per_token_scale(ck_tile::HostTensorDescriptor({group_count * M}, {1}));
ck_tile::HostTensor<AccDataType> per_channel_scale(ck_tile::HostTensorDescriptor({group_count * N}, {1}));
std::vector<ck_tile::index_t> m_indices(group_count);
int indices_fill_start = 0;
for(int i = 0; i < group_count; ++i)
{
int group_m = Ms[i];
int padded_group_m = (group_m + BlockM - 1) / BlockM * BlockM;
for(int j = 0; j < padded_group_m; j++)
{
m_indices[i] = group_m;
// m_indices[i] = padded_group_m; // -1 for padding
}
}
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensor);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n_tensor);
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_token_scale);
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_channel_scale);
constexpr int N_Warp_Tile = FlatmmConfig::N_Warp_Tile;
assert(N % N_Warp_Tile == 0 &&
"N must be divisible by N_Warp_Tile for contiguous grouped gemm");
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<FlatmmConfig, BDataType>(b_k_n_tensor);
std::unique_ptr<ck_tile::DeviceMem> a_m_k_dev_buf(
std::make_unique<ck_tile::DeviceMem>(a_m_k_tensor.get_element_space_size_in_bytes()));
std::unique_ptr<ck_tile::DeviceMem> b_shfl_dev_buf(
std::make_unique<ck_tile::DeviceMem>(b_shuffle_host.get_element_space_size_in_bytes()));
std::unique_ptr<ck_tile::DeviceMem> c_m_n_dev_buf(
std::make_unique<ck_tile::DeviceMem>(c_m_n_tensor.get_element_space_size_in_bytes()));
ck_tile::DeviceMem per_token_scale_dev_buf(per_token_scale.get_element_space_size_in_bytes());
ck_tile::DeviceMem per_channel_scale_dev_buf(
per_channel_scale.get_element_space_size_in_bytes());
c_m_n_dev_buf->SetZero();
ck_tile::DeviceMem m_indices_dev_buf(group_count * sizeof(ck_tile::index_t));
m_indices_dev_buf.ToDevice(m_indices.data());
a_m_k_dev_buf->ToDevice(a_m_k_tensor.data());
b_shfl_dev_buf->ToDevice(b_shuffle_host.data());
per_token_scale_dev_buf.ToDevice(per_token_scale.data());
per_channel_scale_dev_buf.ToDevice(per_channel_scale.data());
auto per_token_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM>{
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer())};
auto per_channel_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN>{
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())};
ck_tile::MaskedGroupedFlatmmHostArgs<decltype(per_token_scale_dev_ptr), decltype(per_channel_scale_dev_ptr)> kernal_args{
static_cast<ck_tile::index_t*>(m_indices_dev_buf.GetDeviceBuffer()),
group_count,
M,
N,
K,
a_m_k_dev_buf->GetDeviceBuffer(),
stride_A,
b_shfl_dev_buf->GetDeviceBuffer(),
stride_B,
{},{},
c_m_n_dev_buf->GetDeviceBuffer(),
stride_C,
kbatch,
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer()),
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())
};
int sum_val_m = 0;
for (int gi = 0; gi < group_count; gi++)
{
sum_val_m += m_indices[gi];
}
invoke_gemm<FlatmmConfig,
ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout,
decltype(per_token_scale_dev_ptr),
decltype(per_channel_scale_dev_ptr)>(
warmup, repeat, sum_val_m, kernal_args);
c_m_n_dev_buf->FromDevice(c_m_n_tensor.data());
bool pass{true};
if(arg_parser.get_int("v") == 1)
{
throw std::runtime_error(
"Not support v=1 host verification in contiguous grouped gemm, use "
"v=2 device verification instead");
}
else if(arg_parser.get_int("v") == 2)
{
BDataType* d_B;
CDataType* d_C;
ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType)));
ck_tile::hip_check_error(hipMalloc(&d_C, group_count * M * N * sizeof(CDataType)));
ck_tile::hip_check_error(hipMemset(d_C, 0, group_count * M * N * sizeof(CDataType)));
ck_tile::HostTensor<CDataType> c_gpu_ref_host(
ck_tile::host_tensor_descriptor(group_count * M, N, stride_C, is_row_major(CLayout{})));
ck_tile::index_t acc_m = 0;
for(int i = 0; i < group_count; ++i)
{
ck_tile::hip_check_error(hipMemcpy(d_B,
b_k_n_tensor.data() + i * N * K,
N * K * sizeof(BDataType),
hipMemcpyHostToDevice));
if constexpr(ScaleGranularityM == -1 && ScaleGranularityN == -1)
{
ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(
static_cast<ADataType*>(a_m_k_dev_buf->GetDeviceBuffer()) + i * M * K,
d_B,
d_C + i * M * N,
m_indices[i],
N,
K,
stride_A,
stride_B,
stride_C);
}
else
{
ck_tile::reference_blockwise_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(
static_cast<ADataType*>(a_m_k_dev_buf->GetDeviceBuffer()) + i * M * K,
d_B,
d_C + i * M * N,
m_indices[i],
N,
K,
stride_A,
stride_B,
stride_C,
ScaleGranularityM,
ScaleGranularityN,
K,
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer()) + i * M,
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())) + i * N;
}
ck_tile::hip_check_error(hipMemcpy(c_gpu_ref_host.data() + i * M * N,
d_C + i * M * N,
M * N * sizeof(CDataType),
hipMemcpyDeviceToHost));
}
ck_tile::hip_check_error(hipFree(d_B));
ck_tile::hip_check_error(hipFree(d_C));
float rtol = 1e-3;
float atol = 1e-3;
pass = ck_tile::check_err(
c_m_n_tensor, c_gpu_ref_host, "Error: Incorrect results!", rtol, atol);
std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
<< std::endl;
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
}
return pass;
}

View File

@@ -32,6 +32,12 @@ __launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
#endif
}
template <int MaxThreadPerBlock, typename Kernel, typename... Args>
__launch_bounds__(MaxThreadPerBlock) __global__ void kentry2(Args... args)
{
Kernel{}(args...);
}
//
// return a anonymous functor(lambda) to be called later
// the KernelImpl should be a class without non-static data member, or let's say

View File

@@ -10,6 +10,7 @@
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp"
#include "ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp"
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v0.hpp"
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"

6
include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp Executable file → Normal file
View File

@@ -244,7 +244,9 @@ struct FlatmmKernel
static_assert(DsLayout::size() == DsDataType::size(),
"The size of DsLayout and DsDataType should be the same");
// using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;
template<class ScaleM, class ScaleN>
using KernelArgs = FlatmmKernelArgs<ScaleM, ScaleN, DsLayout::size()>;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
@@ -751,7 +753,7 @@ struct FlatmmKernel
CK_TILE_DEVICE void operator()(FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()> kargs,
int partition_idx = blockIdx.x) const
{
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);

View File

@@ -0,0 +1,465 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp"
namespace ck_tile {
template <class ScaleM = FlatmmScalePointer<-1>,
class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
struct GroupedFlatmmHostArgs
{
CK_TILE_HOST GroupedFlatmmHostArgs() = default;
CK_TILE_HOST GroupedFlatmmHostArgs(index_t group_count_,
index_t* M_,
index_t* N_,
index_t* K_,
const void** a_ptr_,
index_t* stride_A_,
const void** b_shuffle_ptr_,
index_t* stride_B_,
const std::array<const void*, NumDTensor>& ds_ptr_,
const std::array<index_t, NumDTensor>& stride_Ds_,
void** c_ptr_,
index_t* stride_C_,
index_t k_batch_,
ScaleM scale_m_ = nullptr,
ScaleN scale_n_ = nullptr)
: group_count(group_count_),
M(M_),
N(N_),
K(K_),
a_ptr(a_ptr_),
stride_A(stride_A_),
b_shuffle_ptr(b_shuffle_ptr_),
stride_B(stride_B_),
ds_ptr(ds_ptr_),
stride_Ds(stride_Ds_),
c_ptr(c_ptr_),
stride_C(stride_C_),
k_batch(k_batch_),
scale_m(scale_m_),
scale_n(scale_n_)
{
}
index_t group_count;
index_t* M;
index_t* N;
index_t* K;
const void** a_ptr;
index_t* stride_A;
const void** b_shuffle_ptr;
index_t* stride_B;
const std::array<const void*, NumDTensor> ds_ptr;
const std::array<index_t, NumDTensor> stride_Ds;
union
{
void** e_ptr;
void** c_ptr;
};
index_t* stride_C;
index_t k_batch;
ScaleM scale_m = nullptr;
ScaleN scale_n = nullptr;
};
template <class ScaleM = FlatmmScalePointer<-1>,
class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
struct ContiguousGroupedFlatmmHostArgs
{
CK_TILE_HOST ContiguousGroupedFlatmmHostArgs() = default;
CK_TILE_HOST ContiguousGroupedFlatmmHostArgs(index_t* M_indices_,
index_t M_,
index_t N_,
index_t K_,
const void* a_ptr_,
index_t stride_A_,
const void* b_shuffle_ptr_,
index_t stride_B_,
const std::array<const void*, NumDTensor>& ds_ptr_,
const std::array<index_t, NumDTensor>& stride_Ds_,
void* c_ptr_,
index_t stride_C_,
index_t k_batch_,
ScaleM scale_m_ = nullptr,
ScaleN scale_n_ = nullptr)
: M_indices(M_indices_),
M(M_),
N(N_),
K(K_),
a_ptr(a_ptr_),
stride_A(stride_A_),
b_shuffle_ptr(b_shuffle_ptr_),
stride_B(stride_B_),
ds_ptr(ds_ptr_),
stride_Ds(stride_Ds_),
c_ptr(c_ptr_),
stride_C(stride_C_),
k_batch(k_batch_),
scale_m(scale_m_),
scale_n(scale_n_)
{
}
index_t* M_indices;
index_t M;
index_t N;
index_t K;
const void* a_ptr;
index_t stride_A;
const void* b_shuffle_ptr;
index_t stride_B;
const std::array<const void*, NumDTensor> ds_ptr;
const std::array<index_t, NumDTensor> stride_Ds;
union
{
void* e_ptr;
void* c_ptr;
};
index_t stride_C;
index_t k_batch;
ScaleM scale_m = nullptr;
ScaleN scale_n = nullptr;
};
template <class ScaleM = FlatmmScalePointer<-1>,
class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
struct MaskedGroupedFlatmmHostArgs
{
CK_TILE_HOST MaskedGroupedFlatmmHostArgs() = default;
CK_TILE_HOST MaskedGroupedFlatmmHostArgs(index_t* M_indices_,
index_t group_count_,
index_t Max_M_,
index_t N_,
index_t K_,
const void* a_ptr_,
index_t stride_A_,
const void* b_shuffle_ptr_,
index_t stride_B_,
const std::array<const void*, NumDTensor>& ds_ptr_,
const std::array<index_t, NumDTensor>& stride_Ds_,
void* c_ptr_,
index_t stride_C_,
index_t k_batch_,
ScaleM scale_m_ = nullptr,
ScaleN scale_n_ = nullptr)
: M_indices(M_indices_),
group_count(group_count_),
M(Max_M_),
N(N_),
K(K_),
a_ptr(a_ptr_),
stride_A(stride_A_),
b_shuffle_ptr(b_shuffle_ptr_),
stride_B(stride_B_),
ds_ptr(ds_ptr_),
stride_Ds(stride_Ds_),
c_ptr(c_ptr_),
stride_C(stride_C_),
k_batch(k_batch_),
scale_m(scale_m_),
scale_n(scale_n_)
{
}
index_t* M_indices;
index_t group_count;
index_t M;
index_t N;
index_t K;
const void* a_ptr;
index_t stride_A;
const void* b_shuffle_ptr;
index_t stride_B;
const std::array<const void*, NumDTensor> ds_ptr;
const std::array<index_t, NumDTensor> stride_Ds;
union
{
void* e_ptr;
void* c_ptr;
};
index_t stride_C;
index_t k_batch;
ScaleM scale_m = nullptr;
ScaleN scale_n = nullptr;
};
template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
struct GroupedFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>
{
using UnderlyingGemmKernel = FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>;
using BlockGemmShape = typename UnderlyingGemmKernel::BlockGemmShape;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using FlatmmPipeline = remove_cvref_t<FlatmmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
// Below type is actually accumulation data type - the output of block GEMM.
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
static constexpr index_t NumDTensor = DsDataType::size();
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
static constexpr auto I2 = number<2>();
static constexpr auto I3 = number<3>();
static_assert(DsLayout::size() == DsDataType::size(),
"The size of DsLayout and DsDataType should be the same");
CK_TILE_HOST static const std::string GetName()
{
return concat(
'_', "grouped_flatmm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
}
template <class ScaleM = FlatmmScalePointer<-1>,
class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
CK_TILE_HOST_DEVICE static auto
GridSize([[maybe_unused]] const GroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor>& kernelArgs)
{
hipDeviceProp_t prop;
int deviceId = 0; // default device
constexpr int block_size = UnderlyingGemmKernel::BlockSize().x;
int dync_smem_size = 0;
int maxActiveBlocksPerCU;
[[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
&maxActiveBlocksPerCU,
reinterpret_cast<void*>(
kentry2<block_size, GroupedFlatmmKernel, GroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor>>),
block_size,
dync_smem_size);
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
<< ", persistent_block_size: " << persistent_block_size << std::endl;
assert(kernelArgs.k_batch == 1);
return dim3(persistent_block_size, 1, kernelArgs.k_batch);
}
template <class ScaleM = FlatmmScalePointer<-1>,
class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
CK_TILE_HOST_DEVICE static auto
GridSize([[maybe_unused]] const ContiguousGroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor>& kernelArgs)
{
hipDeviceProp_t prop;
int deviceId = 0; // default device
constexpr int block_size = UnderlyingGemmKernel::BlockSize().x;
int dync_smem_size = 0;
int maxActiveBlocksPerCU;
[[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
&maxActiveBlocksPerCU,
reinterpret_cast<void*>(
kentry2<block_size, GroupedFlatmmKernel, ContiguousGroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor>>),
block_size,
dync_smem_size);
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
const int total_work_tile_cnt = TilePartitioner::GridSize(kernelArgs.M, kernelArgs.N);
std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
<< ", persistent_block_size: " << persistent_block_size
<< ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;
assert(kernelArgs.k_batch == 1);
return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kernelArgs.k_batch);
}
template <class ScaleM = FlatmmScalePointer<-1>,
class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
CK_TILE_HOST_DEVICE static auto
GridSize([[maybe_unused]] const MaskedGroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor>& kernelArgs)
{
hipDeviceProp_t prop;
int deviceId = 0; // default device
constexpr int block_size = UnderlyingGemmKernel::BlockSize().x;
int dync_smem_size = 0;
int maxActiveBlocksPerCU;
[[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
&maxActiveBlocksPerCU,
reinterpret_cast<void*>(
kentry2<block_size, GroupedFlatmmKernel, MaskedGroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor>>),
block_size,
dync_smem_size);
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
// const int total_work_tile_cnt = TilePartitioner::GridSize(kernelArgs.M, kernelArgs.N);
std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
<< ", persistent_block_size: " << persistent_block_size << std::endl;
assert(kernelArgs.k_batch == 1);
return dim3(persistent_block_size, 1, kernelArgs.k_batch);
}
template<typename HostArgs>
CK_TILE_HOST static constexpr auto MakeKernelArgs(const HostArgs& hostArgs)
{
return hostArgs;
}
// CK_TILE_HOST static constexpr auto
// MakeKernelArgs(const ContiguousGroupedFlatmmHostArgs& hostArgs)
// {
// return hostArgs;
// }
// CK_TILE_HOST static constexpr auto
// MakeKernelArgs(const MaskedGroupedFlatmmHostArgs& hostArgs)
// {
// return hostArgs;
// }
template <class ScaleM = FlatmmScalePointer<-1>,
class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
CK_TILE_DEVICE void operator()(GroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor> kargs) const
{
int group_idx = 0;
int block_linear_idx = blockIdx.x;
int total_block_cnt = gridDim.x;
UnderlyingGemmKernel underlying_kernel{};
for(; group_idx < kargs.group_count; ++group_idx)
{
const index_t M = kargs.M[group_idx];
const index_t N = kargs.N[group_idx];
const index_t group_block_cnt = TilePartitioner::GridSize(M, N);
while(block_linear_idx < group_block_cnt)
{
// Found the group this block belongs to
// create the kernel args for the underlying flatmm kernel
typename UnderlyingGemmKernel::template KernelArgs<ScaleM, ScaleN> impl_kargs{
kargs.a_ptr[group_idx],
kargs.b_shuffle_ptr[group_idx],
kargs.ds_ptr,
kargs.c_ptr[group_idx],
kargs.M[group_idx],
kargs.N[group_idx],
kargs.K[group_idx],
kargs.stride_A[group_idx],
kargs.stride_B[group_idx],
kargs.stride_Ds,
kargs.stride_C[group_idx],
kargs.k_batch,
};
// call the underlying flatmm kernel
underlying_kernel(impl_kargs, block_linear_idx);
block_linear_idx += total_block_cnt;
}
block_linear_idx -= group_block_cnt;
}
}
template <class ScaleM = FlatmmScalePointer<-1>,
class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
CK_TILE_DEVICE void operator()(ContiguousGroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor> kargs) const
{
int block_linear_idx = blockIdx.x;
int total_block_cnt = gridDim.x;
int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
UnderlyingGemmKernel underlying_kernel{};
for(; block_linear_idx < total_work_tile_cnt; block_linear_idx += total_block_cnt)
{
auto [block_m_idx, block_n_idx] =
TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(block_linear_idx);
// get the group index from the M_indices
int group_idx = kargs.M_indices[block_m_idx * BlockGemmShape::kM];
typename UnderlyingGemmKernel::template KernelArgs<ScaleM, ScaleN> impl_kargs{
kargs.a_ptr,
static_cast<const BDataType*>(kargs.b_shuffle_ptr) + group_idx * kargs.N * kargs.K,
kargs.ds_ptr,
kargs.c_ptr,
kargs.M,
kargs.N,
kargs.K,
kargs.stride_A,
kargs.stride_B,
kargs.stride_Ds,
kargs.stride_C,
kargs.k_batch,
};
// call the underlying flatmm kernel
underlying_kernel(impl_kargs, block_linear_idx);
}
}
template <class ScaleM = FlatmmScalePointer<-1>,
class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
CK_TILE_DEVICE void operator()(MaskedGroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor> kargs) const
{
int group_idx = 0;
int block_linear_idx = blockIdx.x;
int total_block_cnt = gridDim.x;
UnderlyingGemmKernel underlying_kernel{};
for(; group_idx < kargs.group_count; ++group_idx)
{
const index_t valid_M = kargs.M_indices[group_idx];
const index_t N = kargs.N;
const index_t group_block_cnt = TilePartitioner::GridSize(valid_M, N);
while(block_linear_idx < group_block_cnt)
{
// Found the group this block belongs to
// create the kernel args for the underlying flatmm kernel
typename UnderlyingGemmKernel::template KernelArgs<ScaleM, ScaleN> impl_kargs{
static_cast<const ADataType*>(kargs.a_ptr) + group_idx * kargs.M * kargs.K,
static_cast<const BDataType*>(kargs.b_shuffle_ptr) + group_idx * kargs.N * kargs.K,
kargs.ds_ptr,
static_cast<CDataType*>(kargs.c_ptr) + group_idx * kargs.M * kargs.N,
valid_M,
kargs.N,
kargs.K,
kargs.stride_A,
kargs.stride_B,
kargs.stride_Ds,
kargs.stride_C,
kargs.k_batch,
};
// call the underlying flatmm kernel
underlying_kernel(impl_kargs, block_linear_idx);
block_linear_idx += total_block_cnt;
}
block_linear_idx -= group_block_cnt;
}
}
};
} // namespace ck_tile

View File

@@ -112,7 +112,7 @@ struct GemmTile1DPartitioner
* @param N GEMM's N dimension.
* @return dim3 Structure holding grid's X,Y and Z dimensions.
*/
CK_TILE_HOST static auto
CK_TILE_HOST_DEVICE static auto
GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t
{
const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;

View File

@@ -16,7 +16,7 @@ fi
cmake \
-D CMAKE_PREFIX_PATH=/opt/rocm/ \
-D CMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
-D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \