Implement grouped gemm fastgelu for RDNA4 (#3303)

* Implement grouped gemm fastgelu for RDNA4

* chore: some cleanup and fixes for minor inconsistencies in the grouped gemm profiler

* chore: clarified logic and reporting of supported instance warnings

[ROCm/composable_kernel commit: f9c6ba0403]
Author: Erwin Terpstra (committed by GitHub)
Date: 2026-01-07 19:20:44 +01:00
parent 6f6256381a
commit d074af36c9
24 changed files with 665 additions and 399 deletions
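For orientation, a minimal sketch of how the new grouped GEMM + FastGelu instances are reached from client code, assuming the grouped_gemm_fastgelu factory header modified further down in this diff; the layouts and data types shown are one of the four supported combinations, chosen for illustration:

#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fastgelu.hpp"

using Row         = ck::tensor_layout::gemm::RowMajor;
using F16         = ck::half_t;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using FastGelu    = ck::tensor_operation::element_wise::FastGelu;

// DeviceGroupedGemm<ALayout, BLayout, DsLayout, ELayout,
//                   ADataType, BDataType, DsDataType, EDataType,
//                   AElementOp, BElementOp, CDEElementOp>
using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemm<
    Row, Row, ck::Tuple<>, Row, F16, F16, ck::Tuple<>, F16,
    PassThrough, PassThrough, FastGelu>;

// With this commit the factory returns WMMA instances on RDNA4 in addition to
// the existing XDL instances (guarded by CK_USE_WMMA / CK_USE_XDL).
const auto op_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<DeviceOp>::GetInstances();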

View File

@@ -7,6 +7,7 @@
#include <sstream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/utility/env.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
@@ -242,7 +243,6 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
static constexpr index_t B2E_M01 = 8;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
using KernelArgument = typename GridwiseGemm::Argument;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
template <typename KernelArgument_>
struct GemmTransKernelArgBase
{
@@ -274,23 +274,38 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
}
// Argument
// TODO: Add A/B/CDE element op?
struct Argument : public BaseArgument
{
Argument(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmDesc>& gemm_descs)
: Argument(p_As, p_Bs, p_Es, gemm_descs, DefaultKBatch)
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation c_element_op)
: Argument(p_As,
p_Bs,
p_Ds,
p_Es,
gemm_descs,
a_element_op,
b_element_op,
c_element_op,
DefaultKBatch)
{
// TODO: use occupancy api to calculate appropriate batch size.
}
Argument(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation c_element_op,
index_t kbatch)
: K_BATCH{kbatch}, gemm_kernel_host_args_{nullptr}
{
@@ -299,9 +314,11 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) &&
((NumDTensor == 0 && p_Ds.size() == 0) ||
group_count_ == ck::type_convert<ck::index_t>(p_Ds.size())) &&
group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
{
throw std::runtime_error("wrong! group_count_ != p_As/b/c.size");
throw std::runtime_error("wrong! group_count_ != p_As/b/d/e.size");
}
gemm_kernel_args_.reserve(group_count_);
@@ -320,9 +337,22 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
continue;
}
const index_t stride_a = gemm_descs[i].stride_A_;
const index_t stride_b = gemm_descs[i].stride_B_;
const index_t stride_c = gemm_descs[i].stride_C_;
const index_t stride_a = gemm_descs[i].stride_A_;
const index_t stride_b = gemm_descs[i].stride_B_;
const index_t stride_c = gemm_descs[i].stride_C_;
const auto& stride_d_vec = gemm_descs[i].stride_Ds_;
if(!(NumDTensor == ck::type_convert<ck::index_t>(stride_d_vec.size())))
{
throw std::runtime_error("wrong! stride D mismatch");
}
// Copy D stride vector to fixed-size array
std::array<index_t, NumDTensor> stride_ds;
if constexpr(NumDTensor > 0)
{
std::copy(stride_d_vec.begin(), stride_d_vec.end(), stride_ds.begin());
}
const index_t m_padded = GridwiseGemm::CalculateMPadded(M);
const index_t n_padded = GridwiseGemm::CalculateNPadded(N);
@@ -346,19 +376,19 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
auto karg = KernelArgument(std::array<const void*, 1>{p_As[i]},
std::array<const void*, 1>{p_Bs[i]},
std::array<const void*, 0>{}, // p_ds_grid_
p_Ds[i],
type_convert<EDataType*>(p_Es[i]),
M,
N,
K,
std::array<index_t, 1>{stride_a},
std::array<index_t, 1>{stride_b},
std::array<index_t, 0>{}, // StrideDs_
stride_ds,
stride_c,
K_BATCH,
PassThrough{},
PassThrough{},
PassThrough{},
a_element_op,
b_element_op,
c_element_op,
false);
gemm_kernel_args_.emplace_back(
@@ -632,6 +662,23 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
}
}
if constexpr(!std::is_same_v<CDEElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough>)
{
if(arg.K_BATCH > 1)
{
// Using SplitK and a C element op would require a two-stage kernel where the second
// stage applies the op to the accumulated results
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "C element operators are not supported when using SplitK. Set "
"K_BATCH to 1 or remove the operator."
<< std::endl;
}
return false;
}
}
if constexpr(std::is_same_v<ComputeTypeA, f8_t> || std::is_same_v<ComputeTypeA, bf8_t> ||
std::is_same_v<ComputeTypeB, f8_t> || std::is_same_v<ComputeTypeB, bf8_t>)
{
@@ -681,14 +728,15 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
static auto MakeArgument(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>&,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmDesc> gemm_descs,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation)
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation c_element_op)
{
return Argument{p_As, p_Bs, p_Es, gemm_descs};
return Argument{
p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
@@ -697,14 +745,15 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>&,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation) override
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(p_As, p_Bs, p_Es, gemm_descs);
return std::make_unique<Argument>(
p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op);
}
// polymorphic
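The IsSupportedArgument() addition above is worth a short illustration: split-K accumulates partial dot products of the same output element across workgroups, and a non-linear C element op such as FastGelu must be applied exactly once, on the fully accumulated sum; applying it per partial result gives a different answer, which is why K_BATCH > 1 is rejected. A host-side sketch using a common tanh-based FastGelu approximation (CK's device implementation may differ in rounding details):

#include <cmath>
#include <cstdio>

float fast_gelu(float x)
{
    // 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5f * x * (1.f + std::tanh(0.7978845608f * (x + 0.044715f * x * x * x)));
}

int main()
{
    // Two split-K partial sums of one output element.
    float partial0 = 1.25f, partial1 = -0.75f;
    // Correct: the element op is applied once, after full accumulation.
    float once = fast_gelu(partial0 + partial1); // ~0.346
    // Wrong: per-partial application, which is what naive split-K would do.
    float per_partial = fast_gelu(partial0) + fast_gelu(partial1); // ~0.948
    std::printf("once: %f  per-partial: %f\n", once, per_partial);
    return 0;
}

PassThrough is the identity and therefore linear, so accumulating passthrough partials is safe; that is why only non-PassThrough CDE ops are rejected.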

View File

@@ -31,6 +31,7 @@ using S = ck::Sequence<Is...>;
using Empty_Tuple = ck::Tuple<>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AccDataType = F32;
using DsDataType = Empty_Tuple;
@@ -38,10 +39,6 @@ using DsDataType = Empty_Tuple;
using DsLayout = Empty_Tuple;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
static constexpr auto PipelineV1 = BlockGemmPipelineVersion::v1;
static constexpr auto PipelineV3 = BlockGemmPipelineVersion::v3;
static constexpr auto IntrawaveScheduler = BlockGemmPipelineScheduler::Intrawave;
@@ -54,6 +51,9 @@ template <typename T,
device::GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename AElementOp,
typename BElementOp,
typename CDEElementOp,
enable_if_t<sizeof(T) == 2, bool> = false>
using device_grouped_gemm_wmma_universal_km_kn_mn_instances =
std::tuple<
@@ -73,6 +73,9 @@ template <typename T,
device::GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename AElementOp,
typename BElementOp,
typename CDEElementOp,
enable_if_t<sizeof(T) == 2, bool> = false>
using device_grouped_gemm_wmma_universal_km_nk_mn_instances = std::tuple<
// clang-format off
@@ -91,6 +94,9 @@ template <typename T,
device::GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename AElementOp,
typename BElementOp,
typename CDEElementOp,
enable_if_t<sizeof(T) == 2, bool> = false>
using device_grouped_gemm_wmma_universal_mk_kn_mn_instances =
std::tuple<
@@ -110,6 +116,9 @@ template <typename T,
device::GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename AElementOp,
typename BElementOp,
typename CDEElementOp,
enable_if_t<sizeof(T) == 2, bool> = false>
using device_grouped_gemm_wmma_universal_mk_nk_mn_instances =
std::tuple<
@@ -124,17 +133,38 @@ using device_grouped_gemm_wmma_universal_mk_nk_mn_instances =
// clang-format on
>;
// List of instance variants to add (pipeline/scheduler/padding combinations).
// Some are currently disabled and can be re-enabled if needed.
using InstanceVariant =
ck::Tuple<device::GemmSpecialization, BlockGemmPipelineScheduler, BlockGemmPipelineVersion>;
static constexpr InstanceVariant InstanceVariants[] = {
make_tuple(GemmDefault, IntrawaveScheduler, PipelineV1),
// make_tuple(GemmDefault, InterwaveScheduler, PipelineV1),
make_tuple(GemmDefault, IntrawaveScheduler, PipelineV3),
make_tuple(GemmMNKPadding, IntrawaveScheduler, PipelineV1),
// make_tuple(GemmMNKPadding, InterwaveScheduler, PipelineV1),
// make_tuple(GemmMNKPadding, IntrawaveScheduler, PipelineV3),
};
// Helper function to add a list of layout instances with specific A/B/E datatypes for all supported
// padding/scheduler/pipeline version combinations
template <typename ALayout,
typename BLayout,
template <device::GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer>
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
typename LayoutInstances,
typename ADataType, // NOTE: type parameters come last so that they can be inferred
typename BDataType, // from the vector argument
typename EDataType>
typename EDataType,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
void add_device_grouped_gemm_wmma_universal_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<ALayout,
BLayout,
@@ -148,18 +178,17 @@ void add_device_grouped_gemm_wmma_universal_instances(
BElementOp,
CDEElementOp>>>& instances)
{
add_device_operation_instances(instances,
LayoutInstances<GemmDefault, IntrawaveScheduler, PipelineV1>{});
add_device_operation_instances(instances,
LayoutInstances<GemmDefault, InterwaveScheduler, PipelineV1>{});
add_device_operation_instances(instances,
LayoutInstances<GemmDefault, IntrawaveScheduler, PipelineV3>{});
add_device_operation_instances(
instances, LayoutInstances<GemmMNKPadding, IntrawaveScheduler, PipelineV1>{});
add_device_operation_instances(
instances, LayoutInstances<GemmMNKPadding, InterwaveScheduler, PipelineV1>{});
add_device_operation_instances(
instances, LayoutInstances<GemmMNKPadding, IntrawaveScheduler, PipelineV3>{});
// Add all instances from our instance list
static_for<0, std::size(InstanceVariants), 1>{}([&](auto i) {
constexpr auto instance = InstanceVariants[i];
add_device_operation_instances(instances,
LayoutInstances<instance.At(Number<0>{}),
instance.At(Number<1>{}),
instance.At(Number<2>{}),
AElementOp,
BElementOp,
CDEElementOp>{});
});
}
// Helper function to add a list of layout instances for instances with matching A/B/E data types
@@ -170,8 +199,14 @@ template <typename T,
template <typename T2,
device::GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer>
typename LayoutInstances>
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
typename LayoutInstances,
typename AElementOp, // NOTE: element-wise op parameters come last so that they can
typename BElementOp, // be inferred from the vector argument
typename CDEElementOp>
void add_device_grouped_gemm_wmma_universal_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<ALayout,
BLayout,
@@ -185,18 +220,18 @@ void add_device_grouped_gemm_wmma_universal_instances(
BElementOp,
CDEElementOp>>>& instances)
{
add_device_operation_instances(
instances, LayoutInstances<T, GemmDefault, IntrawaveScheduler, PipelineV1>{});
add_device_operation_instances(
instances, LayoutInstances<T, GemmDefault, InterwaveScheduler, PipelineV1>{});
add_device_operation_instances(
instances, LayoutInstances<T, GemmDefault, IntrawaveScheduler, PipelineV3>{});
add_device_operation_instances(
instances, LayoutInstances<T, GemmMNKPadding, IntrawaveScheduler, PipelineV1>{});
add_device_operation_instances(
instances, LayoutInstances<T, GemmMNKPadding, InterwaveScheduler, PipelineV1>{});
add_device_operation_instances(
instances, LayoutInstances<T, GemmMNKPadding, IntrawaveScheduler, PipelineV3>{});
// Add all instances from our instance list
static_for<0, std::size(InstanceVariants), 1>{}([&](auto i) {
constexpr auto instance = InstanceVariants[i];
add_device_operation_instances(instances,
LayoutInstances<T,
instance.At(Number<0>{}),
instance.At(Number<1>{}),
instance.At(Number<2>{}),
AElementOp,
BElementOp,
CDEElementOp>{});
});
}
} // namespace instance
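The two static_for loops above stamp out one tuple of instances per entry of the constexpr InstanceVariants table, reading each field with Tuple::At(Number<I>{}). For readers unfamiliar with the pattern, a standard-library analogue of the same compile-time iteration (illustrative only, not CK code):

#include <array>
#include <cstdio>
#include <utility>

// Stand-in for an instance variant: (padding, scheduler, pipeline) as plain ints.
struct Variant { int padding, scheduler, pipeline; };

constexpr std::array<Variant, 3> kVariants{{{0, 0, 1}, {0, 0, 3}, {1, 0, 1}}};

// Equivalent of ck::static_for: the lambda body is instantiated once per index,
// with the index available as a compile-time constant.
template <typename F, std::size_t... Is>
constexpr void for_each_variant(F&& f, std::index_sequence<Is...>)
{
    (f(std::integral_constant<std::size_t, Is>{}), ...);
}

int main()
{
    for_each_variant(
        [](auto i) {
            constexpr Variant v = kVariants[i]; // usable as a constant expression
            std::printf("pad=%d sched=%d pipe=%d\n", v.padding, v.scheduler, v.pipeline);
        },
        std::make_index_sequence<kVariants.size()>{});
    return 0;
}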

View File

@@ -15,6 +15,64 @@ namespace tensor_operation {
namespace device {
namespace instance {
#if defined(CK_USE_WMMA)
#if defined(CK_ENABLE_FP16)
void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
FastGelu>>>& instances);
void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
FastGelu>>>& instances);
void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
FastGelu>>>& instances);
void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
FastGelu>>>& instances);
#endif // CK_ENABLE_FP16
#endif // CK_USE_WMMA
#if defined(CK_USE_XDL)
#if defined(CK_ENABLE_FP16)
void add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
@@ -66,6 +124,8 @@ void add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instances(
PassThrough,
PassThrough,
FastGelu>>>& instances);
#endif // CK_ENABLE_FP16
#endif // CK_USE_XDL
// GroupedGEMM + GELU
template <typename ALayout,
@@ -102,30 +162,52 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if defined(CK_ENABLE_FP16)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
#if defined(CK_USE_XDL)
add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
#endif
#if defined(CK_USE_WMMA)
add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
#endif
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
#if defined(CK_USE_XDL)
add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
#endif
#if defined(CK_USE_WMMA)
add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
#endif
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
#if defined(CK_USE_XDL)
add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
#endif
#if defined(CK_USE_WMMA)
add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_kn_mn_instances(op_ptrs);
#endif
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
#if defined(CK_USE_XDL)
add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
#endif
#if defined(CK_USE_WMMA)
add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_nk_mn_instances(op_ptrs);
#endif
}
}
#endif
return op_ptrs;
}
};
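Consuming the factory output is unchanged by this commit and mirrors the profiler further down in this diff; a condensed sketch, with p_a/p_b/p_ds/p_c and gemm_descs assumed to be populated the way the profiler populates them:

for(auto& gemm_ptr : op_ptrs)
{
    auto argument_ptr = gemm_ptr->MakeArgumentPointer(
        p_a, p_b, p_ds, p_c, gemm_descs, PassThrough{}, PassThrough{}, FastGelu{});
    auto invoker_ptr = gemm_ptr->MakeInvokerPointer();

    // Grouped GEMM needs a workspace for the per-group kernel arguments.
    DeviceMem workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get()));
    gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());

    if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
        invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, /*time_kernel=*/false});
}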

View File

@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instances(
BF16,
DsDataType,
BF16,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<
BF16,

View File

@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instances(
BF16,
DsDataType,
BF16,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<
BF16,

View File

@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instances(
BF16,
DsDataType,
BF16,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<
BF16,

View File

@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instances(
BF16,
DsDataType,
BF16,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<
BF16,

View File

@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instances(
F16,
DsDataType,
F16,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<
F16,

View File

@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instances(
F16,
DsDataType,
F16,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<
F16,

View File

@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instances(
F16,
DsDataType,
F16,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<

View File

@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instances(
F16,
DsDataType,
F16,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<

View File

@@ -17,7 +17,10 @@ using EDataType = F16;
template <device::GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer>
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
using device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances =
std::tuple<
// clang-format off
@@ -40,9 +43,9 @@ void add_device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances(
BDataType,
DsDataType,
EDataType,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<

View File

@@ -17,7 +17,10 @@ using EDataType = F16;
template <device::GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer>
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
using device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances =
std::tuple<
// clang-format off
@@ -40,9 +43,9 @@ void add_device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances(
BDataType,
DsDataType,
EDataType,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<

View File

@@ -1,10 +1,15 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
add_instance_library(device_grouped_gemm_fastgelu_instance
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instance.cpp
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instance.cpp
device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_kn_mn_instance.cpp
device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_nk_mn_instance.cpp
)

View File

@@ -0,0 +1,37 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Row,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
FastGelu>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<
F16,
Col,
Row,
device_grouped_gemm_wmma_universal_km_kn_mn_instances>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,37 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Col,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
FastGelu>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<
F16,
Col,
Col,
device_grouped_gemm_wmma_universal_km_nk_mn_instances>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,38 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
FastGelu>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<
F16,
Row,
Row,
device_grouped_gemm_wmma_universal_mk_kn_mn_instances>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,38 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
FastGelu>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<
F16,
Row,
Col,
device_grouped_gemm_wmma_universal_mk_nk_mn_instances>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -17,6 +17,8 @@
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "profile_grouped_gemm_impl.hpp"
namespace ck {
namespace profiler {
@@ -38,242 +40,30 @@ bool profile_grouped_gemm_fastgelu_impl(int do_verification,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs)
{
bool pass = true;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
std::size_t group_count = Ms.size();
if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() &&
group_count == StrideBs.size() && group_count == StrideCs.size()))
{
throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideA/B/Cs size\n");
}
std::vector<Tensor<ADataType>> a_m_k;
std::vector<Tensor<BDataType>> b_k_n;
std::vector<Tensor<CDataType>> c_m_n_device_results;
for(std::size_t i = 0; i < group_count; i++)
{
a_m_k.push_back(
Tensor<ADataType>(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{})));
b_k_n.push_back(
Tensor<BDataType>(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{})));
c_m_n_device_results.push_back(
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i
<< "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
<< "]:" << c_m_n_device_results[i].mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{}(a_m_k[i]);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{}(b_k_n[i]);
break;
default:
ck::utils::FillUniformDistribution<ADataType>{0.0, 1.0}(a_m_k[i]);
ck::utils::FillUniformDistribution<BDataType>{-0.5, 0.5}(b_k_n[i]);
}
ck::utils::FillConstant<CDataType>{}(c_m_n_device_results[i]);
}
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::FastGelu;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_device_buf, b_device_buf, c_device_buf;
a_device_buf.reserve(group_count);
b_device_buf.reserve(group_count);
c_device_buf.reserve(group_count);
std::vector<const void*> p_a, p_b;
std::vector<void*> p_c;
p_a.reserve(group_count);
p_b.reserve(group_count);
p_c.reserve(group_count);
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
gemm_descs.reserve(group_count);
for(std::size_t i = 0; i < group_count; i++)
{
a_device_buf.emplace_back(
std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpaceSize()));
b_device_buf.emplace_back(
std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize()));
c_device_buf.emplace_back(std::make_unique<DeviceMem>(
sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpaceSize()));
a_device_buf[i]->ToDevice(a_m_k[i].mData.data());
b_device_buf[i]->ToDevice(b_k_n[i].mData.data());
c_device_buf[i]->SetZero();
gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
p_a.push_back(a_device_buf[i]->GetDeviceBuffer());
p_b.push_back(b_device_buf[i]->GetDeviceBuffer());
p_c.push_back(c_device_buf[i]->GetDeviceBuffer());
}
using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemm<ALayout,
BLayout,
ck::Tuple<>,
CLayout,
ADataType,
BDataType,
ck::Tuple<>,
CDataType,
AElementOp,
BElementOp,
CElementOp>;
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
if(op_ptrs.size() <= 0)
{
throw std::runtime_error("wrong! no device GEMM instance found");
}
std::string best_gemm_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
auto p_ds = std::vector<std::array<const void*, 0>>{};
// profile device GEMM instances
for(auto& gemm_ptr : op_ptrs)
{
auto argument_ptr = gemm_ptr->MakeArgumentPointer(
p_a, p_b, p_ds, p_c, gemm_descs, a_element_op, b_element_op, c_element_op);
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get()));
gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{
std::string gemm_name = gemm_ptr->GetTypeString();
float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = 0, num_btype = 0;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i];
num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + sizeof(BDataType) * Ks[i] * Ns[i] +
sizeof(CDataType) * Ms[i] * Ns[i];
}
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << gemm_name << std::endl;
if(tflops > best_tflops)
{
best_gemm_name = gemm_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
if(do_verification)
{
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data());
Tensor<CDataType> c_m_n_host_result(
f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}));
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_m_k[i],
b_k_n[i],
c_m_n_host_result,
a_element_op,
b_element_op,
c_element_op);
ref_invoker.Run(ref_argument);
bool group_pass =
ck::utils::check_err(c_m_n_device_results[i], c_m_n_host_result);
pass = pass && group_pass;
std::cout << "group: " << i << " verification result: " << std::boolalpha
<< group_pass << std::endl;
if(do_log)
{
LogRangeAsType<float>(std::cout << "a : ", a_m_k[i].mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n[i].mData, ",") << std::endl;
LogRangeAsType<float>(
std::cout << "c_device: ", c_m_n_device_results[i].mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_host : ", c_m_n_host_result.mData, ",")
<< std::endl;
}
}
}
}
else
{
std::cout << "does not support this GEMM problem" << std::endl;
}
}
if(do_verification)
{
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
return pass;
return profile_grouped_gemm_impl<ADataType,
BDataType,
CDataType,
AccDataType,
ALayout,
BLayout,
CLayout,
AElementOp,
BElementOp,
CElementOp>(do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
{1});
}
} // namespace profiler
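The FastGelu profiler is now a thin wrapper over the shared grouped GEMM profiler, with the element ops carried as template parameters; the trailing {1} pins kbatch to 1 because, per the device-op change earlier in this diff, C element ops cannot be combined with split-K. A sketch of the equivalent direct call, with illustrative types:

using F16         = ck::half_t;
using Row         = ck::tensor_layout::gemm::RowMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using FastGelu    = ck::tensor_operation::element_wise::FastGelu;

bool pass = ck::profiler::profile_grouped_gemm_impl<F16, F16, F16, float,
                                                    Row, Row, Row,
                                                    PassThrough, PassThrough, FastGelu>(
    /*do_verification=*/1, /*init_method=*/2, /*do_log=*/false, /*time_kernel=*/true,
    Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, /*kbatches=*/{1});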

View File

@@ -13,6 +13,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fastgelu.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
@@ -25,13 +26,18 @@
namespace ck {
namespace profiler {
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
template <typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout>
typename CLayout,
typename AElementOp = PassThrough,
typename BElementOp = PassThrough,
typename CElementOp = PassThrough>
bool profile_grouped_gemm_impl(int do_verification,
int init_method,
bool do_log,
@@ -43,8 +49,8 @@ bool profile_grouped_gemm_impl(int do_verification,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
const std::vector<int>& kbatches = {},
int n_warmup = 1,
int n_iter = 10,
int n_warmup = -1,
int n_iter = -1,
int instance_index = -1,
bool fail_if_no_supported_instance = false)
{
@@ -93,7 +99,7 @@ bool profile_grouped_gemm_impl(int do_verification,
c_m_n_host_results.push_back(
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
if(do_log)
{
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n["
<< i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
@@ -103,21 +109,17 @@ bool profile_grouped_gemm_impl(int do_verification,
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-2.f, 2.f}(a_m_k[i]);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-2.f, 2.f}(b_k_n[i]);
max_abs_in_val = 2.f;
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k[i]);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n[i]);
max_abs_in_val = 5.f;
break;
default:
ck::utils::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k[i]);
ck::utils::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_m_k[i]);
ck::utils::FillUniformDistribution<BDataType>{-0.5f, 0.5f}(b_k_n[i]);
max_abs_in_val = 0.5f;
max_abs_in_val = 1.0f;
}
}
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};
@@ -200,6 +202,17 @@ bool profile_grouped_gemm_impl(int do_verification,
int num_kernel = 0;
auto p_ds = std::vector<std::array<const void*, 0>>{};
StreamConfig stream_config{nullptr, time_kernel};
if(n_warmup >= 0)
{
stream_config.cold_niters_ = n_warmup;
}
if(n_iter >= 0)
{
stream_config.nrepeat_ = n_iter;
}
if(do_verification)
{
for(std::size_t i = 0; i < gemm_descs.size(); i++)
@@ -225,19 +238,33 @@ bool profile_grouped_gemm_impl(int do_verification,
ref_invoker.Run(ref_argument);
}
}
// If the user provides a non-empty kbatches list, we test it; otherwise we test a
// predefined set of kbatch values.
std::vector<int> kbatch_list = {1, 2, 4, 8, 12, 16, 20, 24, 32, 48, 64};
if(!kbatches.empty())
{
kbatch_list = kbatches;
}
// Check if the operation requested any KBatch size > 1
bool operation_requires_splitk_support = false;
for(auto kbatch : kbatch_list)
{
if(kbatch > 1)
{
operation_requires_splitk_support = true;
break;
}
}
// profile device GEMM instances
int instances_supporting_all_batch_sizes = 0;
int instances_supported = 0;
int instances_supporting_splitk = 0;
for(auto& gemm_ptr : op_ptrs)
{
auto argument_ptr =
gemm_ptr->MakeArgumentPointer(p_a,
p_b,
p_ds,
p_c,
gemm_descs,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{});
auto argument_ptr = gemm_ptr->MakeArgumentPointer(
p_a, p_b, p_ds, p_c, gemm_descs, a_element_op, b_element_op, c_element_op);
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
@@ -261,16 +288,9 @@ bool profile_grouped_gemm_impl(int do_verification,
std::string gemm_name = gemm_ptr->GetTypeString();
std::vector<int> kbatch_list = {1, 2, 4, 8, 12, 16, 20, 24, 32, 48, 64};
// If the user provides a non-empty kbatches list, we test it; otherwise we test a
// predefined set of kbatch values.
if(!kbatches.empty())
{
kbatch_list = kbatches;
}
bool all_batch_sizes_supported = true;
// Keep track of whether we found any supported instance
bool any_supported_instance = false;
bool any_supported_nontrivial_kbatch = false;
for(std::size_t j = 0; j < kbatch_list.size(); j++)
{
auto kbatch_curr = kbatch_list[j];
@@ -290,11 +310,17 @@ bool profile_grouped_gemm_impl(int do_verification,
continue;
}
// Record that this instance supports the problem (and whether with kbatch > 1)
any_supported_instance = true;
if(kbatch_curr > 1)
{
any_supported_nontrivial_kbatch = true;
}
for(std::size_t i = 0; i < gemm_descs.size(); i++)
c_device_buf[i]->SetZero();
invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr, false, 0, n_warmup, n_iter});
float ave_time = invoker_ptr->Run(argument_ptr.get(), stream_config);
if(do_verification)
{
@@ -329,7 +355,7 @@ bool profile_grouped_gemm_impl(int do_verification,
}
}
std::cout << "Instance: " << gemm_name << " verification "
std::cout << "Instance: " << gemm_name << "; KBatch: " << kbatch_curr << " "
<< (instance_pass ? "SUCCEEDED" : "FAILED") << std::endl;
pass = pass && instance_pass;
@@ -337,10 +363,6 @@ bool profile_grouped_gemm_impl(int do_verification,
if(time_kernel)
{
float ave_time =
invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter});
std::size_t flop = 0, num_btype = 0;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
@@ -370,24 +392,38 @@ bool profile_grouped_gemm_impl(int do_verification,
}
else
{
all_batch_sizes_supported = false;
std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem"
std::cout << "Instance: " << gemm_name
<< ", does not support this GEMM problem (KBatch: " << kbatch_curr << ")"
<< std::endl;
}
}
// If all batch sizes were supported by this instance, the instance can be marked as
// If any kbatch size > 1 was supported by this instance, the instance can be marked as
// 'supported' for this problem
if(all_batch_sizes_supported)
if(any_supported_nontrivial_kbatch)
{
++instances_supporting_all_batch_sizes;
++instances_supporting_splitk;
}
if(any_supported_instance)
{
++instances_supported;
}
}
// Warn if not a single instance was supported
if(instances_supporting_all_batch_sizes == 0)
if(instances_supported == 0)
{
std::cout << "Warning! No instance found that supported all of the batch sizes."
std::cout << "Warning! No supported instance found." << std::endl;
if(fail_if_no_supported_instance)
{
return false;
}
}
else if(operation_requires_splitk_support && instances_supporting_splitk == 0)
{
std::cout << "Warning! No instance found that supported any of the kbatch sizes."
<< std::endl;
if(fail_if_no_supported_instance)
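Recapping the n_warmup/n_iter plumbing introduced in this hunk: the new defaults of -1 leave StreamConfig's own defaults untouched, and any non-negative value applies to every instance run in the loop:

// Built once, reused for every instance and kbatch.
StreamConfig stream_config{nullptr, time_kernel};
if(n_warmup >= 0)
    stream_config.cold_niters_ = n_warmup; // untimed warmup launches
if(n_iter >= 0)
    stream_config.nrepeat_ = n_iter; // timed repetitions folded into ave_time

float ave_time = invoker_ptr->Run(argument_ptr.get(), stream_config);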

View File

@@ -12,6 +12,12 @@ if (CK_USE_XDL OR CK_USE_WMMA)
target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance)
add_dependencies(test_grouped_gemm test_grouped_gemm_splitk)
endif()
add_gtest_executable(test_grouped_gemm_fastgelu test_grouped_gemm_fastgelu.cpp)
if(result EQUAL 0)
target_link_libraries(test_grouped_gemm_fastgelu PRIVATE utility device_grouped_gemm_fastgelu_instance)
add_dependencies(test_grouped_gemm test_grouped_gemm_fastgelu)
endif()
endif()
add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp)

View File

@@ -0,0 +1,62 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <tuple>
#include <vector>
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#include "gtest/gtest.h"
#include "test_grouped_gemm_util.hpp"
ck::index_t param_mask = 0xffffff;
ck::index_t instance_index = -1;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F8 = ck::f8_t;
using I8 = int8_t;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::FastGelu;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <typename Tuple>
class TestGroupedGemm : public ck::test::TestGroupedGemm<Tuple, true>
{
};
// clang-format off
using KernelTypes = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F16, AElementOp, BElementOp, CDEElementOp>,
std::tuple< Row, Col, Row, F16, F16, F16, AElementOp, BElementOp, CDEElementOp>,
std::tuple< Col, Row, Row, F16, F16, F16, AElementOp, BElementOp, CDEElementOp>,
std::tuple< Col, Col, Row, F16, F16, F16, AElementOp, BElementOp, CDEElementOp>
>;
// clang-format on
TYPED_TEST_SUITE(TestGroupedGemm, KernelTypes);
#include "test_grouped_gemm_ut_cases.inc"
int main(int argc, char** argv)
{
testing::InitGoogleTest(&argc, argv);
if(argc == 1) {} // no extra arguments: keep the defaults
else if(argc == 3)
{
param_mask = strtol(argv[1], nullptr, 0);
instance_index = atoi(argv[2]);
}
else
{
std::cout << "Usage of " << argv[0] << std::endl;
std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl;
}
return RUN_ALL_TESTS();
}

View File

@@ -65,12 +65,11 @@ TYPED_TEST(TestGroupedGemm, MNKPadded)
TYPED_TEST(TestGroupedGemm, TestLargeKBatch)
{
// gfx11 does not support split-K due to missing atomic add for fp16/bf16
// Technically, we could still run the tests for fp32, but we currently don't have instances for
// it so we disable it entirely
if(ck::is_gfx11_supported())
GTEST_SKIP() << "Split-K not supported for FP16/BF16 on GFX11 due to missing atomic add "
"instructions";
// In some configurations split-K is not supported. Running this test would fail because
// no instance would be supported, so we skip it
if(!this->IsSplitKSupported())
GTEST_SKIP() << "Split-K not supported for the current configuration (FP16/BF16 on "
"GFX11, or using CDE element-wise operation)";
const std::vector<int> Ms{188, 210};
constexpr int N = 768;

View File

@@ -7,11 +7,14 @@
#include <string>
#include <sstream>
#include <tuple>
#include <type_traits>
#include <vector>
#include <gtest/gtest.h>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "profiler/profile_grouped_gemm_impl.hpp"
extern ck::index_t param_mask;
@@ -32,16 +35,46 @@ std::string serialize_range(const Range& range)
return std::string(str.begin(), str.end() - 2);
}
// Helper primary template (will be specialized on the boolean)
template <std::size_t N,
typename Tuple,
typename Default,
bool InRange = (N < std::tuple_size_v<std::remove_reference_t<Tuple>>)>
struct tuple_element_or_impl;
// Specialization for the in-range case: use std::tuple_element_t
template <std::size_t N, typename Tuple, typename Default>
struct tuple_element_or_impl<N, Tuple, Default, true>
{
using type = std::tuple_element_t<N, std::remove_reference_t<Tuple>>;
};
// Specialization for the out-of-range case: use Default
template <std::size_t N, typename Tuple, typename Default>
struct tuple_element_or_impl<N, Tuple, Default, false>
{
using type = Default;
};
// User-facing alias
template <std::size_t N, typename Tuple, typename Default>
using tuple_element_or_t = typename tuple_element_or_impl<N, Tuple, Default>::type;
template <typename Tuple, bool FailIfNoSupportedInstances = false>
class TestGroupedGemm : public testing::Test
{
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using ELayout = std::tuple_element_t<2, Tuple>;
using ADataType = std::tuple_element_t<3, Tuple>;
using BDataType = std::tuple_element_t<4, Tuple>;
using EDataType = std::tuple_element_t<5, Tuple>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using ELayout = std::tuple_element_t<2, Tuple>;
using ADataType = std::tuple_element_t<3, Tuple>;
using BDataType = std::tuple_element_t<4, Tuple>;
using EDataType = std::tuple_element_t<5, Tuple>;
using AElementOp = tuple_element_or_t<6, Tuple, PassThrough>;
using BElementOp = tuple_element_or_t<7, Tuple, PassThrough>;
using CDEElementOp = tuple_element_or_t<8, Tuple, PassThrough>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
@@ -57,15 +90,25 @@ class TestGroupedGemm : public testing::Test
bool fail_if_no_supported_instances_ = FailIfNoSupportedInstances;
std::vector<int> k_batches_;
void SetUp() override
bool IsSplitKSupported()
{
// gfx11 does not support split-K due to missing atomic add for fp16/bf16
// Technically, we could still use split-K for fp32, but we currently don't have
// instances for it so we disable it entirely
constexpr bool require_16bit_atomic_add =
std::is_same_v<EDataType, ck::half_t> || std::is_same_v<EDataType, ck::bhalf_t>;
if(require_16bit_atomic_add && ck::is_gfx11_supported())
bool missing_atomic_add = require_16bit_atomic_add && ck::is_gfx11_supported();
// CDE element operators are not supported in combination with split K
constexpr bool has_cde_element_operator = !std::is_same_v<CDEElementOp, PassThrough>;
return !missing_atomic_add && !has_cde_element_operator;
}
void SetUp() override
{
if(!IsSplitKSupported())
{
// gfx11 does not support split-K due to missing atomic add for fp16/bf16
// Technically, we could still use split-K for fp32, but we currently don't have
// instances for it so we disable it entirely
k_batches_ = {1};
}
else
@@ -147,21 +190,24 @@ class TestGroupedGemm : public testing::Test
float,
ALayout,
BLayout,
ELayout>(verify_,
init_method_,
log_,
bench_,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatches,
n_warmup_,
n_iter_,
instance_index,
fail_if_no_supported_instances_);
ELayout,
AElementOp,
BElementOp,
CDEElementOp>(verify_,
init_method_,
log_,
bench_,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatches,
n_warmup_,
n_iter_,
instance_index,
fail_if_no_supported_instances_);
EXPECT_TRUE(pass);
}
};
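A closing note on the tuple_element_or_t helper introduced in this header: it lets the pre-existing six-element test tuples keep working while the new nine-element tuples supply the element ops explicitly. Illustrative compile-time checks, assuming the helper above is in scope:

#include <tuple>
#include <type_traits>

using ShortTuple = std::tuple<int, float>;

// In-range indices resolve to the tuple element...
static_assert(std::is_same_v<tuple_element_or_t<0, ShortTuple, double>, int>);
// ...while out-of-range indices fall back to the Default (here: double), the
// same way elements 6-8 default to PassThrough for six-element test tuples.
static_assert(std::is_same_v<tuple_element_or_t<7, ShortTuple, double>, double>);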