Implement grouped gemm fastgelu for RDNA4 (#3303)

* Implement grouped gemm fastgelu for RDNA4

* chore: some cleanup and minor inconsistencies in grouped gemm profiler

* chore: clarified logic and reporting of supported instance warnings
This commit is contained in:
Erwin Terpstra
2026-01-07 19:20:44 +01:00
committed by GitHub
parent a7d6b1e700
commit f9c6ba0403
24 changed files with 665 additions and 399 deletions

View File

@@ -7,6 +7,7 @@
#include <sstream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/utility/env.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
@@ -242,7 +243,6 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
static constexpr index_t B2E_M01 = 8;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
using KernelArgument = typename GridwiseGemm::Argument;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
template <typename KernelArgument_>
struct GemmTransKernelArgBase
{
@@ -274,23 +274,38 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
}
// Argument
// TODO: Add A/B/CDE element op?
struct Argument : public BaseArgument
{
Argument(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmDesc>& gemm_descs)
: Argument(p_As, p_Bs, p_Es, gemm_descs, DefaultKBatch)
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation c_element_op)
: Argument(p_As,
p_Bs,
p_Ds,
p_Es,
gemm_descs,
a_element_op,
b_element_op,
c_element_op,
DefaultKBatch)
{
// TODO: use occupancy api to calculate appropriate batch size.
}
Argument(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation c_element_op,
index_t kbatch)
: K_BATCH{kbatch}, gemm_kernel_host_args_{nullptr}
{
@@ -299,9 +314,11 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) &&
((NumDTensor == 0 && p_Ds.size() == 0) ||
group_count_ == ck::type_convert<ck::index_t>(p_Ds.size())) &&
group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
{
throw std::runtime_error("wrong! group_count_ != p_As/b/c.size");
throw std::runtime_error("wrong! group_count_ != p_As/b/d/e.size");
}
gemm_kernel_args_.reserve(group_count_);
@@ -320,9 +337,22 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
continue;
}
const index_t stride_a = gemm_descs[i].stride_A_;
const index_t stride_b = gemm_descs[i].stride_B_;
const index_t stride_c = gemm_descs[i].stride_C_;
const index_t stride_a = gemm_descs[i].stride_A_;
const index_t stride_b = gemm_descs[i].stride_B_;
const index_t stride_c = gemm_descs[i].stride_C_;
const auto& stride_d_vec = gemm_descs[i].stride_Ds_;
if(!(NumDTensor == ck::type_convert<ck::index_t>(stride_d_vec.size())))
{
throw std::runtime_error("wrong! stride D mismatch");
}
// Copy D stride vector to fixed-size array
std::array<index_t, NumDTensor> stride_ds;
if constexpr(NumDTensor > 0)
{
std::copy(stride_d_vec.begin(), stride_d_vec.end(), stride_ds.begin());
}
const index_t m_padded = GridwiseGemm::CalculateMPadded(M);
const index_t n_padded = GridwiseGemm::CalculateNPadded(N);
@@ -346,19 +376,19 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
auto karg = KernelArgument(std::array<const void*, 1>{p_As[i]},
std::array<const void*, 1>{p_Bs[i]},
std::array<const void*, 0>{}, // p_ds_grid_
p_Ds[i],
type_convert<EDataType*>(p_Es[i]),
M,
N,
K,
std::array<index_t, 1>{stride_a},
std::array<index_t, 1>{stride_b},
std::array<index_t, 0>{}, // StrideDs_
stride_ds,
stride_c,
K_BATCH,
PassThrough{},
PassThrough{},
PassThrough{},
a_element_op,
b_element_op,
c_element_op,
false);
gemm_kernel_args_.emplace_back(
@@ -632,6 +662,23 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
}
}
if constexpr(!std::is_same_v<CDEElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough>)
{
if(arg.K_BATCH > 1)
{
// Using SplitK and a C element op would require a two stage kernel where the second
// stage applies the op on the accumulated results
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "C element operators are not supported when using SplitK. Set "
"K_BATCH to 1 or remove the operator."
<< std::endl;
}
return false;
}
}
if constexpr(std::is_same_v<ComputeTypeA, f8_t> || std::is_same_v<ComputeTypeA, bf8_t> ||
std::is_same_v<ComputeTypeB, f8_t> || std::is_same_v<ComputeTypeB, bf8_t>)
{
@@ -681,14 +728,15 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
static auto MakeArgument(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>&,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmDesc> gemm_descs,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation)
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation c_element_op)
{
return Argument{p_As, p_Bs, p_Es, gemm_descs};
return Argument{
p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
@@ -697,14 +745,15 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>&,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation) override
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(p_As, p_Bs, p_Es, gemm_descs);
return std::make_unique<Argument>(
p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op);
}
// polymorphic