mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Implement grouped gemm fastgelu for RDNA4 (#3303)
* Implement grouped gemm fastgelu for RDNA4 * chore: some cleanup and minor inconsistencies in grouped gemm profiler * chore: clarified logic and reporting of supported instance warnings
This commit is contained in:
@@ -7,6 +7,7 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/utility/env.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
@@ -242,7 +243,6 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
|
||||
static constexpr index_t B2E_M01 = 8;
|
||||
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
|
||||
using KernelArgument = typename GridwiseGemm::Argument;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
template <typename KernelArgument_>
|
||||
struct GemmTransKernelArgBase
|
||||
{
|
||||
@@ -274,23 +274,38 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
|
||||
}
|
||||
|
||||
// Argument
|
||||
// TODO: Add A/B/CDE element op?
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
|
||||
Argument(std::vector<const void*>& p_As,
|
||||
std::vector<const void*>& p_Bs,
|
||||
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
|
||||
std::vector<void*>& p_Es,
|
||||
std::vector<GemmDesc>& gemm_descs)
|
||||
: Argument(p_As, p_Bs, p_Es, gemm_descs, DefaultKBatch)
|
||||
std::vector<GemmDesc>& gemm_descs,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation c_element_op)
|
||||
: Argument(p_As,
|
||||
p_Bs,
|
||||
p_Ds,
|
||||
p_Es,
|
||||
gemm_descs,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
DefaultKBatch)
|
||||
{
|
||||
// TODO: use occupancy api to calculate appropriate batch size.
|
||||
}
|
||||
|
||||
Argument(std::vector<const void*>& p_As,
|
||||
std::vector<const void*>& p_Bs,
|
||||
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
|
||||
std::vector<void*>& p_Es,
|
||||
std::vector<GemmDesc>& gemm_descs,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation c_element_op,
|
||||
index_t kbatch)
|
||||
: K_BATCH{kbatch}, gemm_kernel_host_args_{nullptr}
|
||||
{
|
||||
@@ -299,9 +314,11 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
|
||||
|
||||
if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
|
||||
group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) &&
|
||||
((NumDTensor == 0 && p_Ds.size() == 0) ||
|
||||
group_count_ == ck::type_convert<ck::index_t>(p_Ds.size())) &&
|
||||
group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
|
||||
{
|
||||
throw std::runtime_error("wrong! group_count_ != p_As/b/c.size");
|
||||
throw std::runtime_error("wrong! group_count_ != p_As/b/d/e.size");
|
||||
}
|
||||
|
||||
gemm_kernel_args_.reserve(group_count_);
|
||||
@@ -320,9 +337,22 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
|
||||
continue;
|
||||
}
|
||||
|
||||
const index_t stride_a = gemm_descs[i].stride_A_;
|
||||
const index_t stride_b = gemm_descs[i].stride_B_;
|
||||
const index_t stride_c = gemm_descs[i].stride_C_;
|
||||
const index_t stride_a = gemm_descs[i].stride_A_;
|
||||
const index_t stride_b = gemm_descs[i].stride_B_;
|
||||
const index_t stride_c = gemm_descs[i].stride_C_;
|
||||
const auto& stride_d_vec = gemm_descs[i].stride_Ds_;
|
||||
|
||||
if(!(NumDTensor == ck::type_convert<ck::index_t>(stride_d_vec.size())))
|
||||
{
|
||||
throw std::runtime_error("wrong! stride D mismatch");
|
||||
}
|
||||
|
||||
// Copy D stride vector to fixed-size array
|
||||
std::array<index_t, NumDTensor> stride_ds;
|
||||
if constexpr(NumDTensor > 0)
|
||||
{
|
||||
std::copy(stride_d_vec.begin(), stride_d_vec.end(), stride_ds);
|
||||
}
|
||||
|
||||
const index_t m_padded = GridwiseGemm::CalculateMPadded(M);
|
||||
const index_t n_padded = GridwiseGemm::CalculateNPadded(N);
|
||||
@@ -346,19 +376,19 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
|
||||
|
||||
auto karg = KernelArgument(std::array<const void*, 1>{p_As[i]},
|
||||
std::array<const void*, 1>{p_Bs[i]},
|
||||
std::array<const void*, 0>{}, // p_ds_grid_
|
||||
p_Ds[i],
|
||||
type_convert<EDataType*>(p_Es[i]),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
std::array<index_t, 1>{stride_a},
|
||||
std::array<index_t, 1>{stride_b},
|
||||
std::array<index_t, 0>{}, // StrideDs_
|
||||
stride_ds,
|
||||
stride_c,
|
||||
K_BATCH,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
false);
|
||||
|
||||
gemm_kernel_args_.emplace_back(
|
||||
@@ -632,6 +662,23 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(!std::is_same_v<CDEElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough>)
|
||||
{
|
||||
if(arg.K_BATCH > 1)
|
||||
{
|
||||
// Using SplitK and a C element op would require a two stage kernel where the second
|
||||
// stage applies the op on the accumulated results
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "C element operators are not supported when using SplitK. Set "
|
||||
"K_BATCH to 1 or remove the operator."
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<ComputeTypeA, f8_t> || std::is_same_v<ComputeTypeA, bf8_t> ||
|
||||
std::is_same_v<ComputeTypeB, f8_t> || std::is_same_v<ComputeTypeB, bf8_t>)
|
||||
{
|
||||
@@ -681,14 +728,15 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
|
||||
|
||||
static auto MakeArgument(std::vector<const void*>& p_As,
|
||||
std::vector<const void*>& p_Bs,
|
||||
std::vector<std::array<const void*, NumDTensor>>&,
|
||||
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
|
||||
std::vector<void*>& p_Es,
|
||||
std::vector<GemmDesc> gemm_descs,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation)
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation c_element_op)
|
||||
{
|
||||
return Argument{p_As, p_Bs, p_Es, gemm_descs};
|
||||
return Argument{
|
||||
p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
@@ -697,14 +745,15 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(std::vector<const void*>& p_As,
|
||||
std::vector<const void*>& p_Bs,
|
||||
std::vector<std::array<const void*, NumDTensor>>&,
|
||||
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
|
||||
std::vector<void*>& p_Es,
|
||||
std::vector<GemmDesc>& gemm_descs,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation) override
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation c_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(p_As, p_Bs, p_Es, gemm_descs);
|
||||
return std::make_unique<Argument>(
|
||||
p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
|
||||
Reference in New Issue
Block a user