Add support for contiguous grouped gemm

This commit is contained in:
Feng Shijie
2025-07-09 06:37:29 +00:00
parent 0deeba90e6
commit fae4ebac66
4 changed files with 387 additions and 61 deletions

View File

@@ -19,8 +19,9 @@ template <typename ADataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
float grouped_flatmm(const ck_tile::GroupedFlatmmHostArgs& args, const ck_tile::stream_config& s)
typename CLayout,
typename KernelArguments>
float grouped_flatmm(const KernelArguments& args, const ck_tile::stream_config& s)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false;
@@ -76,12 +77,12 @@ float grouped_flatmm(const ck_tile::GroupedFlatmmHostArgs& args, const ck_tile::
constexpr auto tail_number_v = tail_number_.value;
constexpr auto memory_operation = memory_operation_.value;
using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits,
has_hot_loop_v,
tail_number_v>;
BDataType,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits,
has_hot_loop_v,
tail_number_v>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
@@ -184,34 +185,69 @@ int run_grouped_flatmm_example(int argc, char* argv[])
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
std::string data_type = arg_parser.get_str("prec");
std::string mode = arg_parser.get_str("mode");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "C")
{
if(data_type == "fp16")
if(mode == "general")
{
run_grouped_flatmm_example_with_layouts<ck_tile::half_t>(
argc, argv, Row{}, Col{}, Row{});
if(data_type == "fp16")
{
run_grouped_flatmm_example_with_layouts<ck_tile::half_t>(
argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf16")
{
run_grouped_flatmm_example_with_layouts<ck_tile::bf16_t>(
argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "fp8")
{
run_grouped_flatmm_example_with_layouts<ck_tile::fp8_t>(
argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf8")
{
run_grouped_flatmm_example_with_layouts<ck_tile::bf8_t>(
argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data_type!");
}
}
else if(data_type == "bf16")
else if(mode == "contiguous")
{
run_grouped_flatmm_example_with_layouts<ck_tile::bf16_t>(
argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "fp8")
{
run_grouped_flatmm_example_with_layouts<ck_tile::fp8_t>(
argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf8")
{
run_grouped_flatmm_example_with_layouts<ck_tile::bf8_t>(
argc, argv, Row{}, Col{}, Row{});
if(data_type == "fp16")
{
run_contiguous_grouped_flatmm_example_with_layouts<ck_tile::half_t>(
argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf16")
{
run_contiguous_grouped_flatmm_example_with_layouts<ck_tile::bf16_t>(
argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "fp8")
{
run_contiguous_grouped_flatmm_example_with_layouts<ck_tile::fp8_t>(
argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf8")
{
run_contiguous_grouped_flatmm_example_with_layouts<ck_tile::bf8_t>(
argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data_type!");
}
}
else
{
throw std::runtime_error("Unsupported data_type!");
throw std::runtime_error("Unsupported mode!");
}
}
else

View File

@@ -29,11 +29,11 @@
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
#if (CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
#elif (CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
@@ -153,9 +153,9 @@ struct GemmConfig<ck_tile::fp8_t>
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;
#elif defined(USING_MFMA_32x32x64_F8) // MI350 FP8 32X32 (need tune)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
@@ -165,9 +165,9 @@ struct GemmConfig<ck_tile::fp8_t>
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 64;
#elif defined(USING_MFMA_16x16x32_F8) // MI300 FP8 16X16
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
@@ -177,9 +177,9 @@ struct GemmConfig<ck_tile::fp8_t>
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 64;
#elif defined(USING_MFMA_32x32x16_F8) // MI300 FP8 32X32 (need tune)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 8;
@@ -222,9 +222,9 @@ struct GemmConfig
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;
#elif defined(USING_MFMA_32x32x64_F8) // MI350 FP8 32X32 (need tune)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
@@ -234,9 +234,9 @@ struct GemmConfig
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 64;
#elif defined(USING_MFMA_16x16x32_F16) // MI350 FP16 16X16 (need tune)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
@@ -246,9 +246,9 @@ struct GemmConfig
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 32;
#elif defined(USING_MFMA_32x32x16_F16) // MI350 FP16 32X32 (need tune)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
@@ -323,15 +323,16 @@ struct GemmConfig
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("Ms", "512,128,1024", "m dimension")
.insert("Ns", "256,512,1024", "n dimension")
.insert("Ks", "512,1024,1024", "k dimension")
arg_parser.insert("Ms", "512,256,1024", "m dimension")
.insert("Ns", "1024,512,256", "n dimension")
.insert("Ks", "1024,1024,512", "k dimension")
.insert("group_count", "3", "group count")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Row by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8")
.insert("mode", "general", "grouped gemm mode: [general | contiguous], general by default")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "10", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")

View File

@@ -89,6 +89,35 @@ float invoke_gemm(int n_warmup, int n_repeat, const ck_tile::GroupedFlatmmHostAr
return ave_time;
}
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
float invoke_gemm(int n_warmup, int n_repeat, const ck_tile::ContiguousGroupedFlatmmHostArgs& args)
{
float ave_time =
grouped_flatmm<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::string op_name{"Grouped Gemm"};
std::size_t flop = std::size_t(2) * args.M * args.N * args.K;
std::size_t num_byte = sizeof(ADataType) * args.M * args.K +
sizeof(BDataType) * args.N * args.K +
sizeof(CDataType) * args.M * args.N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
return ave_time;
}
template <typename PrecType, typename ALayout, typename BLayout, typename CLayout>
int run_grouped_flatmm_example_with_layouts(int argc,
char* argv[],
@@ -318,3 +347,181 @@ int run_grouped_flatmm_example_with_layouts(int argc,
return pass;
}
template <typename PrecType, typename ALayout, typename BLayout, typename CLayout>
int run_contiguous_grouped_flatmm_example_with_layouts(
int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
};
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
constexpr int BlockM = GemmConfig<BDataType>::M_Tile;
const int group_count = arg_parser.get_int("group_count");
const int repeat = arg_parser.get_int("repeat");
const int warmup = arg_parser.get_int("warmup");
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
if(!(int(Ms.size()) == group_count))
{
std::cout << "Please check the input data." << std::endl;
// padding additional Ms if needed
for(int i = 0; i < group_count; i++)
{
Ms.push_back(256 + 64 * i);
}
}
ck_tile::index_t M =
std::reduce(Ms.begin(), Ms.begin() + group_count, 0, [](auto acc, auto group_m) {
// round up to the multiple of BlockM
return acc + (group_m + BlockM - 1) / BlockM * BlockM;
});
std::cout << "Total M: " << M << std::endl;
ck_tile::index_t N = Ns[0];
ck_tile::index_t K = Ks[0];
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
ck_tile::index_t stride_A = 0;
ck_tile::index_t stride_B = 0;
ck_tile::index_t stride_C = 0;
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N * group_count, stride_B, is_row_major(b_layout));
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(c_layout));
ck_tile::HostTensor<ADataType> a_m_k_tensor(
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<BDataType> b_k_n_tensor(ck_tile::HostTensor<BDataType>(
ck_tile::host_tensor_descriptor(K, N * group_count, stride_B, is_row_major(b_layout))));
ck_tile::HostTensor<CDataType> c_m_n_tensor(ck_tile::HostTensor<CDataType>(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(c_layout))));
std::vector<ck_tile::index_t> m_indices(std::size_t(M), -1);
int indices_fill_start = 0;
for(int i = 0; i < group_count; ++i)
{
int group_m = Ms[i];
int padded_group_m = (group_m + BlockM - 1) / BlockM * BlockM;
for(int j = 0; j < padded_group_m; j++)
{
m_indices[indices_fill_start + j] = j < group_m ? i : -1; // -1 for padding
}
indices_fill_start += padded_group_m;
}
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensor);
ck_tile::FillUniformDistribution<BDataType>{-4.f, 4.f}(b_k_n_tensor);
c_m_n_tensor.SetZero();
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<BDataType>(b_k_n_tensor);
std::unique_ptr<ck_tile::DeviceMem> a_m_k_dev_buf(
std::make_unique<ck_tile::DeviceMem>(a_m_k_tensor.get_element_space_size_in_bytes()));
std::unique_ptr<ck_tile::DeviceMem> b_shfl_dev_buf(
std::make_unique<ck_tile::DeviceMem>(b_shuffle_host.get_element_space_size_in_bytes()));
std::unique_ptr<ck_tile::DeviceMem> c_m_n_dev_buf(
std::make_unique<ck_tile::DeviceMem>(c_m_n_tensor.get_element_space_size_in_bytes()));
c_m_n_dev_buf->SetZero();
ck_tile::DeviceMem m_indices_dev_buf(M * sizeof(ck_tile::index_t));
m_indices_dev_buf.ToDevice(m_indices.data());
ck_tile::ContiguousGroupedFlatmmHostArgs kernal_args{
static_cast<ck_tile::index_t*>(m_indices_dev_buf.GetDeviceBuffer()),
M,
N,
K,
a_m_k_dev_buf->GetDeviceBuffer(),
stride_A,
b_shfl_dev_buf->GetDeviceBuffer(),
stride_B,
c_m_n_dev_buf->GetDeviceBuffer(),
stride_C,
kbatch,
};
invoke_gemm<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
warmup, repeat, kernal_args);
c_m_n_dev_buf->FromDevice(c_m_n_tensor.data());
bool pass{true};
if(arg_parser.get_int("v") == 1)
{
throw std::runtime_error(
"Not support v=1 host verification in contiguous grouped gemm, use "
"v=2 device verification instead");
}
else if(arg_parser.get_int("v") == 2)
{
BDataType* d_B;
CDataType* d_C;
ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType)));
ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType)));
ck_tile::HostTensor<CDataType> c_gpu_ref_host(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
ck_tile::index_t acc_m = 0;
for(int i = 0; i < group_count; ++i)
{
ck_tile::index_t padded_M = (Ms[i] + BlockM - 1) / BlockM * BlockM;
ck_tile::hip_check_error(hipMemcpy(d_B,
b_k_n_tensor.data() + group_count * N * K,
N * K * sizeof(BDataType),
hipMemcpyHostToDevice));
ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(
static_cast<ADataType*>(a_m_k_dev_buf->GetDeviceBuffer()) + acc_m * K,
d_B,
d_C + acc_m * N,
padded_M,
N,
K,
stride_A,
stride_B,
stride_C);
acc_m += padded_M;
}
ck_tile::hip_check_error(hipMemcpy(
c_gpu_ref_host.data(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost));
ck_tile::hip_check_error(hipFree(d_B));
ck_tile::hip_check_error(hipFree(d_C));
float rtol = 1e-3;
float atol = 1e-3;
pass = ck_tile::check_err(
c_m_n_tensor, c_gpu_ref_host, "Error: Incorrect results!", rtol, atol);
std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
<< std::endl;
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
}
return pass;
}

View File

@@ -53,10 +53,52 @@ struct GroupedFlatmmHostArgs
index_t k_batch;
};
struct ContiguousGroupedFlatmmHostArgs
{
CK_TILE_HOST ContiguousGroupedFlatmmHostArgs() = default;
CK_TILE_HOST ContiguousGroupedFlatmmHostArgs(index_t* M_indices_,
index_t M_,
index_t N_,
index_t K_,
const void* a_ptr_,
index_t stride_A_,
const void* b_shuffle_ptr_,
index_t stride_B_,
void* c_ptr_,
index_t stride_C_,
index_t k_batch_)
: M_indices(M_indices_),
M(M_),
N(N_),
K(K_),
a_ptr(a_ptr_),
stride_A(stride_A_),
b_shuffle_ptr(b_shuffle_ptr_),
stride_B(stride_B_),
c_ptr(c_ptr_),
stride_C(stride_C_),
k_batch(k_batch_)
{
}
index_t* M_indices;
index_t M;
index_t N;
index_t K;
const void* a_ptr;
index_t stride_A;
const void* b_shuffle_ptr;
index_t stride_B;
void* c_ptr;
index_t stride_C;
index_t k_batch;
};
template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
struct GroupedFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>
{
using UnderlyingGemmKernel = FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>;
using BlockGemmShape = typename UnderlyingGemmKernel::BlockGemmShape;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using FlatmmPipeline = remove_cvref_t<FlatmmPipeline_>;
@@ -68,15 +110,13 @@ struct GroupedFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Epi
// Below type is actually accumulation data type - the output of block GEMM.
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using GroupedFlatmmKernelArgs = GroupedFlatmmHostArgs;
CK_TILE_HOST static const std::string GetName()
{
return concat(
'_', "grouped_flatmm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
}
CK_TILE_HOST_DEVICE static auto GridSize(const GroupedFlatmmKernelArgs& kernelArgs)
template <class KernelArgs>
CK_TILE_HOST_DEVICE static auto GridSizeImpl(const KernelArgs& kernelArgs)
{
hipDeviceProp_t prop;
int deviceId = 0; // default device
@@ -89,29 +129,41 @@ struct GroupedFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Epi
e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
&maxActiveBlocksPerCU,
// reinterpret_cast<void*>(GroupedFlatmmKernel::Kernel),
reinterpret_cast<void*>(
kentry2<block_size, GroupedFlatmmKernel, GroupedFlatmmKernelArgs>),
reinterpret_cast<void*>(kentry2<block_size, GroupedFlatmmKernel, KernelArgs>),
block_size,
dync_smem_size);
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
// print maxActiveBlocksPerCU and persistent_block_size
// std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
// << ", persistent_block_size: " << persistent_block_size << std::endl;
std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
<< ", persistent_block_size: " << persistent_block_size << std::endl;
assert(kernelArgs.k_batch == 1);
return dim3(persistent_block_size, 1, kernelArgs.k_batch);
}
CK_TILE_HOST static constexpr GroupedFlatmmKernelArgs
MakeKernelArgs(const GroupedFlatmmHostArgs& hostArgs)
CK_TILE_HOST_DEVICE static auto
GridSize([[maybe_unused]] const GroupedFlatmmHostArgs& kernelArgs)
{
return GridSizeImpl<GroupedFlatmmHostArgs>(kernelArgs);
}
CK_TILE_HOST_DEVICE static auto
GridSize([[maybe_unused]] const ContiguousGroupedFlatmmHostArgs& kernelArgs)
{
return GridSizeImpl<ContiguousGroupedFlatmmHostArgs>(kernelArgs);
}
CK_TILE_HOST static constexpr auto MakeKernelArgs(const GroupedFlatmmHostArgs& hostArgs)
{
return hostArgs;
}
CK_TILE_HOST static constexpr auto
MakeKernelArgs(const ContiguousGroupedFlatmmHostArgs& hostArgs)
{
return hostArgs;
}
CK_TILE_DEVICE void operator()(GroupedFlatmmKernelArgs kargs) const
CK_TILE_DEVICE void operator()(GroupedFlatmmHostArgs kargs) const
{
int group_idx = 0;
int block_linear_idx = blockIdx.x;
@@ -147,6 +199,36 @@ struct GroupedFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Epi
block_linear_idx -= group_block_cnt;
}
}
CK_TILE_DEVICE void operator()(ContiguousGroupedFlatmmHostArgs kargs) const
{
int block_linear_idx = blockIdx.x;
int total_block_cnt = gridDim.x;
int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
UnderlyingGemmKernel underlying_kernel{};
for(; block_linear_idx < total_work_tile_cnt; block_linear_idx += total_block_cnt)
{
auto [block_m_idx, block_n_idx] = TilePartitioner::GetOutputTileIndex(block_linear_idx);
// get the group index from the M_indices
int group_idx = kargs.M_indices[block_m_idx * BlockGemmShape::kM];
typename UnderlyingGemmKernel::FlatmmKernelArgs impl_kargs{
kargs.a_ptr,
static_cast<const BDataType*>(kargs.b_shuffle_ptr) + group_idx * kargs.N * kargs.K,
kargs.c_ptr,
kargs.M,
kargs.N,
kargs.K,
kargs.stride_A,
kargs.stride_B,
kargs.stride_C,
kargs.k_batch,
};
// call the underlying flatmm kernel
underlying_kernel(impl_kargs, block_linear_idx);
}
}
};
} // namespace ck_tile