mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Implement grouped gemm fastgelu for RDNA4 (#3303)
* Implement grouped gemm fastgelu for RDNA4
* chore: some cleanup and minor inconsistencies in grouped gemm profiler
* chore: clarified logic and reporting of supported instance warnings
[ROCm/composable_kernel commit: f9c6ba0403]
This commit is contained in:
@@ -7,6 +7,7 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/utility/env.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
@@ -242,7 +243,6 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
|
||||
static constexpr index_t B2E_M01 = 8;
|
||||
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
|
||||
using KernelArgument = typename GridwiseGemm::Argument;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
template <typename KernelArgument_>
|
||||
struct GemmTransKernelArgBase
|
||||
{
|
||||
@@ -274,23 +274,38 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
|
||||
}
|
||||
|
||||
// Argument
|
||||
// TODO: Add A/B/CDE element op?
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
|
||||
Argument(std::vector<const void*>& p_As,
|
||||
std::vector<const void*>& p_Bs,
|
||||
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
|
||||
std::vector<void*>& p_Es,
|
||||
std::vector<GemmDesc>& gemm_descs)
|
||||
: Argument(p_As, p_Bs, p_Es, gemm_descs, DefaultKBatch)
|
||||
std::vector<GemmDesc>& gemm_descs,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation c_element_op)
|
||||
: Argument(p_As,
|
||||
p_Bs,
|
||||
p_Ds,
|
||||
p_Es,
|
||||
gemm_descs,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
DefaultKBatch)
|
||||
{
|
||||
// TODO: use occupancy api to calculate appropriate batch size.
|
||||
}
|
||||
|
||||
Argument(std::vector<const void*>& p_As,
|
||||
std::vector<const void*>& p_Bs,
|
||||
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
|
||||
std::vector<void*>& p_Es,
|
||||
std::vector<GemmDesc>& gemm_descs,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation c_element_op,
|
||||
index_t kbatch)
|
||||
: K_BATCH{kbatch}, gemm_kernel_host_args_{nullptr}
|
||||
{
|
||||
@@ -299,9 +314,11 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
|
||||
|
||||
if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
|
||||
group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) &&
|
||||
((NumDTensor == 0 && p_Ds.size() == 0) ||
|
||||
group_count_ == ck::type_convert<ck::index_t>(p_Ds.size())) &&
|
||||
group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
|
||||
{
|
||||
throw std::runtime_error("wrong! group_count_ != p_As/b/c.size");
|
||||
throw std::runtime_error("wrong! group_count_ != p_As/b/d/e.size");
|
||||
}
|
||||
|
||||
gemm_kernel_args_.reserve(group_count_);
|
||||
@@ -320,9 +337,22 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
|
||||
continue;
|
||||
}
|
||||
|
||||
const index_t stride_a = gemm_descs[i].stride_A_;
|
||||
const index_t stride_b = gemm_descs[i].stride_B_;
|
||||
const index_t stride_c = gemm_descs[i].stride_C_;
|
||||
const index_t stride_a = gemm_descs[i].stride_A_;
|
||||
const index_t stride_b = gemm_descs[i].stride_B_;
|
||||
const index_t stride_c = gemm_descs[i].stride_C_;
|
||||
const auto& stride_d_vec = gemm_descs[i].stride_Ds_;
|
||||
|
||||
if(!(NumDTensor == ck::type_convert<ck::index_t>(stride_d_vec.size())))
|
||||
{
|
||||
throw std::runtime_error("wrong! stride D mismatch");
|
||||
}
|
||||
|
||||
// Copy D stride vector to fixed-size array
|
||||
std::array<index_t, NumDTensor> stride_ds;
|
||||
if constexpr(NumDTensor > 0)
|
||||
{
|
||||
std::copy(stride_d_vec.begin(), stride_d_vec.end(), stride_ds);
|
||||
}
|
||||
|
||||
const index_t m_padded = GridwiseGemm::CalculateMPadded(M);
|
||||
const index_t n_padded = GridwiseGemm::CalculateNPadded(N);
|
||||
@@ -346,19 +376,19 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
|
||||
|
||||
auto karg = KernelArgument(std::array<const void*, 1>{p_As[i]},
|
||||
std::array<const void*, 1>{p_Bs[i]},
|
||||
std::array<const void*, 0>{}, // p_ds_grid_
|
||||
p_Ds[i],
|
||||
type_convert<EDataType*>(p_Es[i]),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
std::array<index_t, 1>{stride_a},
|
||||
std::array<index_t, 1>{stride_b},
|
||||
std::array<index_t, 0>{}, // StrideDs_
|
||||
stride_ds,
|
||||
stride_c,
|
||||
K_BATCH,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
false);
|
||||
|
||||
gemm_kernel_args_.emplace_back(
|
||||
@@ -632,6 +662,23 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(!std::is_same_v<CDEElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough>)
|
||||
{
|
||||
if(arg.K_BATCH > 1)
|
||||
{
|
||||
// Using SplitK and a C element op would require a two stage kernel where the second
|
||||
// stage applies the op on the accumulated results
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "C element operators are not supported when using SplitK. Set "
|
||||
"K_BATCH to 1 or remove the operator."
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<ComputeTypeA, f8_t> || std::is_same_v<ComputeTypeA, bf8_t> ||
|
||||
std::is_same_v<ComputeTypeB, f8_t> || std::is_same_v<ComputeTypeB, bf8_t>)
|
||||
{
|
||||
@@ -681,14 +728,15 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
|
||||
|
||||
static auto MakeArgument(std::vector<const void*>& p_As,
|
||||
std::vector<const void*>& p_Bs,
|
||||
std::vector<std::array<const void*, NumDTensor>>&,
|
||||
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
|
||||
std::vector<void*>& p_Es,
|
||||
std::vector<GemmDesc> gemm_descs,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation)
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation c_element_op)
|
||||
{
|
||||
return Argument{p_As, p_Bs, p_Es, gemm_descs};
|
||||
return Argument{
|
||||
p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
@@ -697,14 +745,15 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(std::vector<const void*>& p_As,
|
||||
std::vector<const void*>& p_Bs,
|
||||
std::vector<std::array<const void*, NumDTensor>>&,
|
||||
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
|
||||
std::vector<void*>& p_Es,
|
||||
std::vector<GemmDesc>& gemm_descs,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation) override
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation c_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(p_As, p_Bs, p_Es, gemm_descs);
|
||||
return std::make_unique<Argument>(
|
||||
p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
|
||||
@@ -31,6 +31,7 @@ using S = ck::Sequence<Is...>;
|
||||
|
||||
using Empty_Tuple = ck::Tuple<>;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
|
||||
|
||||
using AccDataType = F32;
|
||||
using DsDataType = Empty_Tuple;
|
||||
@@ -38,10 +39,6 @@ using DsDataType = Empty_Tuple;
|
||||
using DsLayout = Empty_Tuple;
|
||||
using ELayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = PassThrough;
|
||||
|
||||
static constexpr auto PipelineV1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto PipelineV3 = BlockGemmPipelineVersion::v3;
|
||||
static constexpr auto IntrawaveScheduler = BlockGemmPipelineScheduler::Intrawave;
|
||||
@@ -54,6 +51,9 @@ template <typename T,
|
||||
device::GemmSpecialization GemmSpec,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CDEElementOp,
|
||||
enable_if_t<sizeof(T) == 2, bool> = false>
|
||||
using device_grouped_gemm_wmma_universal_km_kn_mn_instances =
|
||||
std::tuple<
|
||||
@@ -73,6 +73,9 @@ template <typename T,
|
||||
device::GemmSpecialization GemmSpec,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CDEElementOp,
|
||||
enable_if_t<sizeof(T) == 2, bool> = false>
|
||||
using device_grouped_gemm_wmma_universal_km_nk_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
@@ -91,6 +94,9 @@ template <typename T,
|
||||
device::GemmSpecialization GemmSpec,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CDEElementOp,
|
||||
enable_if_t<sizeof(T) == 2, bool> = false>
|
||||
using device_grouped_gemm_wmma_universal_mk_kn_mn_instances =
|
||||
std::tuple<
|
||||
@@ -110,6 +116,9 @@ template <typename T,
|
||||
device::GemmSpecialization GemmSpec,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CDEElementOp,
|
||||
enable_if_t<sizeof(T) == 2, bool> = false>
|
||||
using device_grouped_gemm_wmma_universal_mk_nk_mn_instances =
|
||||
std::tuple<
|
||||
@@ -124,17 +133,38 @@ using device_grouped_gemm_wmma_universal_mk_nk_mn_instances =
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
// List of instance variants to add (pipeline/scheduler/padding combinations)
|
||||
// Some are disabled now, can be re-enabled if needed
|
||||
using InstanceVariant =
|
||||
ck::Tuple<device::GemmSpecialization, BlockGemmPipelineScheduler, BlockGemmPipelineVersion>;
|
||||
static constexpr InstanceVariant InstanceVariants[] = {
|
||||
|
||||
make_tuple(GemmDefault, IntrawaveScheduler, PipelineV1),
|
||||
// make_tuple(GemmDefault, InterwaveScheduler, PipelineV1),
|
||||
make_tuple(GemmDefault, IntrawaveScheduler, PipelineV3),
|
||||
|
||||
make_tuple(GemmMNKPadding, IntrawaveScheduler, PipelineV1),
|
||||
// make_tuple(GemmMNKPadding, InterwaveScheduler, PipelineV1),
|
||||
// make_tuple(GemmMNKPadding, IntrawaveScheduler, PipelineV3),
|
||||
};
|
||||
|
||||
// Helper function to add a list of layout instances with specific A/B/E datatypes for all supported
|
||||
// padding/scheduler/pipeline version combinations
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
template <device::GemmSpecialization GemmSpec,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer>
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CDEElementOp>
|
||||
typename LayoutInstances,
|
||||
typename ADataType, // NOTE: type parameters as last so that they can be inferred from the
|
||||
typename BDataType, // vector argument
|
||||
typename EDataType>
|
||||
typename EDataType,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CDEElementOp>
|
||||
void add_device_grouped_gemm_wmma_universal_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<ALayout,
|
||||
BLayout,
|
||||
@@ -148,18 +178,17 @@ void add_device_grouped_gemm_wmma_universal_instances(
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
LayoutInstances<GemmDefault, IntrawaveScheduler, PipelineV1>{});
|
||||
add_device_operation_instances(instances,
|
||||
LayoutInstances<GemmDefault, InterwaveScheduler, PipelineV1>{});
|
||||
add_device_operation_instances(instances,
|
||||
LayoutInstances<GemmDefault, IntrawaveScheduler, PipelineV3>{});
|
||||
add_device_operation_instances(
|
||||
instances, LayoutInstances<GemmMNKPadding, IntrawaveScheduler, PipelineV1>{});
|
||||
add_device_operation_instances(
|
||||
instances, LayoutInstances<GemmMNKPadding, InterwaveScheduler, PipelineV1>{});
|
||||
add_device_operation_instances(
|
||||
instances, LayoutInstances<GemmMNKPadding, IntrawaveScheduler, PipelineV3>{});
|
||||
// Add all instances from our instance list
|
||||
static_for<0, std::size(InstanceVariants), 1>{}([&](auto i) {
|
||||
constexpr auto instance = InstanceVariants[i];
|
||||
add_device_operation_instances(instances,
|
||||
LayoutInstances<instance.At(Number<0>{}),
|
||||
instance.At(Number<1>{}),
|
||||
instance.At(Number<2>{}),
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>{});
|
||||
});
|
||||
}
|
||||
|
||||
// Helper function to add a list of layout instances for instances with matching A/B/E data types
|
||||
@@ -170,8 +199,14 @@ template <typename T,
|
||||
template <typename T2,
|
||||
device::GemmSpecialization GemmSpec,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer>
|
||||
typename LayoutInstances>
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CDEElementOp>
|
||||
typename LayoutInstances,
|
||||
typename AElementOp, // NOTE: element-wise op parameters as last so that they can be
|
||||
typename BElementOp, // inferred from the vector argument
|
||||
typename CDEElementOp>
|
||||
void add_device_grouped_gemm_wmma_universal_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<ALayout,
|
||||
BLayout,
|
||||
@@ -185,18 +220,18 @@ void add_device_grouped_gemm_wmma_universal_instances(
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, LayoutInstances<T, GemmDefault, IntrawaveScheduler, PipelineV1>{});
|
||||
add_device_operation_instances(
|
||||
instances, LayoutInstances<T, GemmDefault, InterwaveScheduler, PipelineV1>{});
|
||||
add_device_operation_instances(
|
||||
instances, LayoutInstances<T, GemmDefault, IntrawaveScheduler, PipelineV3>{});
|
||||
add_device_operation_instances(
|
||||
instances, LayoutInstances<T, GemmMNKPadding, IntrawaveScheduler, PipelineV1>{});
|
||||
add_device_operation_instances(
|
||||
instances, LayoutInstances<T, GemmMNKPadding, InterwaveScheduler, PipelineV1>{});
|
||||
add_device_operation_instances(
|
||||
instances, LayoutInstances<T, GemmMNKPadding, IntrawaveScheduler, PipelineV3>{});
|
||||
// Add all instances from our instance list
|
||||
static_for<0, std::size(InstanceVariants), 1>{}([&](auto i) {
|
||||
constexpr auto instance = InstanceVariants[i];
|
||||
add_device_operation_instances(instances,
|
||||
LayoutInstances<T,
|
||||
instance.At(Number<0>{}),
|
||||
instance.At(Number<1>{}),
|
||||
instance.At(Number<2>{}),
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>{});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -15,6 +15,64 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
#if defined(CK_ENABLE_FP16)
|
||||
void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
Row,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
FastGelu>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
Col,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
FastGelu>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
|
||||
Row,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
FastGelu>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
|
||||
Col,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
FastGelu>>>& instances);
|
||||
#endif // CK_ENABLE_FP16
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
#if defined(CK_ENABLE_FP16)
|
||||
void add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
Row,
|
||||
@@ -66,6 +124,8 @@ void add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
FastGelu>>>& instances);
|
||||
#endif // CK_ENABLE_FP16
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
// GroupedGEMM + GELU
|
||||
template <typename ALayout,
|
||||
@@ -102,30 +162,52 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#if defined(CK_ENABLE_FP16)
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<EDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
#if defined(CK_USE_XDL)
|
||||
add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
#endif
|
||||
#if defined(CK_USE_WMMA)
|
||||
add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
#if defined(CK_USE_XDL)
|
||||
add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
#endif
|
||||
#if defined(CK_USE_WMMA)
|
||||
add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
#if defined(CK_USE_XDL)
|
||||
add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
#endif
|
||||
#if defined(CK_USE_WMMA)
|
||||
add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
#if defined(CK_USE_XDL)
|
||||
add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
#endif
|
||||
#if defined(CK_USE_WMMA)
|
||||
add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instances(
|
||||
BF16,
|
||||
DsDataType,
|
||||
BF16,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
BF16,
|
||||
|
||||
@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instances(
|
||||
BF16,
|
||||
DsDataType,
|
||||
BF16,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
BF16,
|
||||
|
||||
@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instances(
|
||||
BF16,
|
||||
DsDataType,
|
||||
BF16,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
BF16,
|
||||
|
||||
@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instances(
|
||||
BF16,
|
||||
DsDataType,
|
||||
BF16,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
BF16,
|
||||
|
||||
@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instances(
|
||||
F16,
|
||||
DsDataType,
|
||||
F16,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
F16,
|
||||
|
||||
@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instances(
|
||||
F16,
|
||||
DsDataType,
|
||||
F16,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
F16,
|
||||
|
||||
@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instances(
|
||||
F16,
|
||||
DsDataType,
|
||||
F16,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
|
||||
@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instances(
|
||||
F16,
|
||||
DsDataType,
|
||||
F16,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
|
||||
@@ -17,7 +17,10 @@ using EDataType = F16;
|
||||
|
||||
template <device::GemmSpecialization GemmSpec,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer>
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CDEElementOp>
|
||||
using device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
@@ -40,9 +43,9 @@ void add_device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances(
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
|
||||
@@ -17,7 +17,10 @@ using EDataType = F16;
|
||||
|
||||
template <device::GemmSpecialization GemmSpec,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer>
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CDEElementOp>
|
||||
using device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
@@ -40,9 +43,9 @@ void add_device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances(
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# ONLY XDL_KERNELS
|
||||
# ONLY XDL_AND_WMMA_KERNELS
|
||||
add_instance_library(device_grouped_gemm_fastgelu_instance
|
||||
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
|
||||
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instance.cpp
|
||||
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instance.cpp
|
||||
|
||||
device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_nk_mn_instance.cpp
|
||||
device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_kn_mn_instance.cpp
|
||||
device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_nk_mn_instance.cpp
|
||||
)
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
|
||||
Row,
|
||||
DsLayout,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
DsDataType,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
FastGelu>>>& instances)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
F16,
|
||||
Col,
|
||||
Row,
|
||||
device_grouped_gemm_wmma_universal_km_kn_mn_instances>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,37 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
|
||||
Col,
|
||||
DsLayout,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
DsDataType,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
FastGelu>>>& instances)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
F16,
|
||||
Col,
|
||||
Col,
|
||||
device_grouped_gemm_wmma_universal_km_nk_mn_instances>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,38 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
Row,
|
||||
DsLayout,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
DsDataType,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
FastGelu>>>& instances)
|
||||
{
|
||||
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
F16,
|
||||
Row,
|
||||
Row,
|
||||
device_grouped_gemm_wmma_universal_mk_kn_mn_instances>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,38 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
Col,
|
||||
DsLayout,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
DsDataType,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
FastGelu>>>& instances)
|
||||
{
|
||||
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
F16,
|
||||
Row,
|
||||
Col,
|
||||
device_grouped_gemm_wmma_universal_mk_nk_mn_instances>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -17,6 +17,8 @@
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "profile_grouped_gemm_impl.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
@@ -38,242 +40,30 @@ bool profile_grouped_gemm_fastgelu_impl(int do_verification,
|
||||
const std::vector<int>& StrideBs,
|
||||
const std::vector<int>& StrideCs)
|
||||
{
|
||||
|
||||
bool pass = true;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
std::size_t group_count = Ms.size();
|
||||
|
||||
if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() &&
|
||||
group_count == StrideBs.size() && group_count == StrideCs.size()))
|
||||
{
|
||||
throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideA/B/Cs size\n");
|
||||
}
|
||||
|
||||
std::vector<Tensor<ADataType>> a_m_k;
|
||||
std::vector<Tensor<BDataType>> b_k_n;
|
||||
std::vector<Tensor<CDataType>> c_m_n_device_results;
|
||||
|
||||
for(std::size_t i = 0; i < group_count; i++)
|
||||
{
|
||||
a_m_k.push_back(
|
||||
Tensor<ADataType>(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{})));
|
||||
b_k_n.push_back(
|
||||
Tensor<BDataType>(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{})));
|
||||
|
||||
c_m_n_device_results.push_back(
|
||||
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
|
||||
|
||||
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i
|
||||
<< "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
|
||||
<< "]:" << c_m_n_device_results[i].mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
ck::utils::FillUniformDistributionIntegerValue<ADataType>{}(a_m_k[i]);
|
||||
ck::utils::FillUniformDistributionIntegerValue<BDataType>{}(b_k_n[i]);
|
||||
break;
|
||||
default:
|
||||
ck::utils::FillUniformDistribution<ADataType>{0.0, 1.0}(a_m_k[i]);
|
||||
ck::utils::FillUniformDistribution<BDataType>{-0.5, 0.5}(b_k_n[i]);
|
||||
}
|
||||
|
||||
ck::utils::FillConstant<CDataType>{}(c_m_n_device_results[i]);
|
||||
}
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::FastGelu;
|
||||
|
||||
const auto a_element_op = AElementOp{};
|
||||
const auto b_element_op = BElementOp{};
|
||||
const auto c_element_op = CElementOp{};
|
||||
|
||||
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
|
||||
std::vector<DeviceMemPtr> a_device_buf, b_device_buf, c_device_buf;
|
||||
|
||||
a_device_buf.reserve(group_count);
|
||||
b_device_buf.reserve(group_count);
|
||||
c_device_buf.reserve(group_count);
|
||||
|
||||
std::vector<const void*> p_a, p_b;
|
||||
std::vector<void*> p_c;
|
||||
|
||||
p_a.reserve(group_count);
|
||||
p_b.reserve(group_count);
|
||||
p_c.reserve(group_count);
|
||||
|
||||
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
|
||||
|
||||
gemm_descs.reserve(group_count);
|
||||
|
||||
for(std::size_t i = 0; i < group_count; i++)
|
||||
{
|
||||
a_device_buf.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpaceSize()));
|
||||
b_device_buf.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize()));
|
||||
c_device_buf.emplace_back(std::make_unique<DeviceMem>(
|
||||
sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpaceSize()));
|
||||
|
||||
a_device_buf[i]->ToDevice(a_m_k[i].mData.data());
|
||||
b_device_buf[i]->ToDevice(b_k_n[i].mData.data());
|
||||
c_device_buf[i]->SetZero();
|
||||
|
||||
gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
|
||||
|
||||
p_a.push_back(a_device_buf[i]->GetDeviceBuffer());
|
||||
p_b.push_back(b_device_buf[i]->GetDeviceBuffer());
|
||||
p_c.push_back(c_device_buf[i]->GetDeviceBuffer());
|
||||
}
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemm<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<>,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<>,
|
||||
CDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
if(op_ptrs.size() <= 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! no device GEMM instance found");
|
||||
}
|
||||
|
||||
std::string best_gemm_name;
|
||||
float best_ave_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
|
||||
auto p_ds = std::vector<std::array<const void*, 0>>{};
|
||||
|
||||
// profile device GEMM instances
|
||||
for(auto& gemm_ptr : op_ptrs)
|
||||
{
|
||||
auto argument_ptr = gemm_ptr->MakeArgumentPointer(
|
||||
p_a, p_b, p_ds, p_c, gemm_descs, a_element_op, b_element_op, c_element_op);
|
||||
|
||||
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
|
||||
DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get()));
|
||||
gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
|
||||
|
||||
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
std::string gemm_name = gemm_ptr->GetTypeString();
|
||||
|
||||
float ave_time =
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i];
|
||||
num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + sizeof(BDataType) * Ks[i] * Ns[i] +
|
||||
sizeof(CDataType) * Ms[i] * Ns[i];
|
||||
}
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << gemm_name << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
best_gemm_name = gemm_name;
|
||||
best_tflops = tflops;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
|
||||
c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data());
|
||||
Tensor<CDataType> c_m_n_host_result(
|
||||
f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}));
|
||||
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
auto ref_argument = ref_gemm.MakeArgument(a_m_k[i],
|
||||
b_k_n[i],
|
||||
c_m_n_host_result,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
bool group_pass =
|
||||
ck::utils::check_err(c_m_n_device_results[i], c_m_n_host_result);
|
||||
pass = pass && group_pass;
|
||||
|
||||
std::cout << "group: " << i << " verification result: " << std::boolalpha
|
||||
<< group_pass << std::endl;
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "a : ", a_m_k[i].mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(std::cout << "b: ", b_k_n[i].mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "c_device: ", c_m_n_device_results[i].mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "c_host : ", c_m_n_host_result.mData, ",")
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "does not support this GEMM problem" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
|
||||
}
|
||||
|
||||
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
|
||||
<< best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
|
||||
|
||||
return pass;
|
||||
return profile_grouped_gemm_impl<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>(do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
Ms,
|
||||
Ns,
|
||||
Ks,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideCs,
|
||||
{1});
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fastgelu.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
@@ -25,13 +26,18 @@
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
typename CLayout,
|
||||
typename AElementOp = PassThrough,
|
||||
typename BElementOp = PassThrough,
|
||||
typename CElementOp = PassThrough>
|
||||
bool profile_grouped_gemm_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
@@ -43,8 +49,8 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
const std::vector<int>& StrideBs,
|
||||
const std::vector<int>& StrideCs,
|
||||
const std::vector<int>& kbatches = {},
|
||||
int n_warmup = 1,
|
||||
int n_iter = 10,
|
||||
int n_warmup = -1,
|
||||
int n_iter = -1,
|
||||
int instance_index = -1,
|
||||
bool fail_if_no_supported_instance = false)
|
||||
{
|
||||
@@ -93,7 +99,7 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
|
||||
c_m_n_host_results.push_back(
|
||||
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
if(do_log)
|
||||
{
|
||||
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n["
|
||||
<< i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
|
||||
@@ -103,21 +109,17 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-2.f, 2.f}(a_m_k[i]);
|
||||
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-2.f, 2.f}(b_k_n[i]);
|
||||
max_abs_in_val = 2.f;
|
||||
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k[i]);
|
||||
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n[i]);
|
||||
max_abs_in_val = 5.f;
|
||||
break;
|
||||
default:
|
||||
ck::utils::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k[i]);
|
||||
ck::utils::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_m_k[i]);
|
||||
ck::utils::FillUniformDistribution<BDataType>{-0.5f, 0.5f}(b_k_n[i]);
|
||||
max_abs_in_val = 0.5f;
|
||||
max_abs_in_val = 1.0f;
|
||||
}
|
||||
}
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
const auto a_element_op = AElementOp{};
|
||||
const auto b_element_op = BElementOp{};
|
||||
const auto c_element_op = CElementOp{};
|
||||
@@ -200,6 +202,17 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
int num_kernel = 0;
|
||||
auto p_ds = std::vector<std::array<const void*, 0>>{};
|
||||
|
||||
StreamConfig stream_config{nullptr, time_kernel};
|
||||
if(n_warmup >= 0)
|
||||
{
|
||||
stream_config.cold_niters_ = n_warmup;
|
||||
}
|
||||
|
||||
if(n_iter >= 0)
|
||||
{
|
||||
stream_config.nrepeat_ = n_iter;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
@@ -225,19 +238,33 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
ref_invoker.Run(ref_argument);
|
||||
}
|
||||
}
|
||||
|
||||
// If the user will provide not empty kbatches list, then we test predefined set of kbatch
|
||||
// values.
|
||||
std::vector<int> kbatch_list = {1, 2, 4, 8, 12, 16, 20, 24, 32, 48, 64};
|
||||
if(!kbatches.empty())
|
||||
{
|
||||
kbatch_list = kbatches;
|
||||
}
|
||||
|
||||
// Check if the operation requested any KBatch size > 1
|
||||
bool operation_requires_splitk_support = false;
|
||||
for(auto kbatch : kbatch_list)
|
||||
{
|
||||
if(kbatch > 1)
|
||||
{
|
||||
operation_requires_splitk_support = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// profile device GEMM instances
|
||||
int instances_supporting_all_batch_sizes = 0;
|
||||
int instances_supported = 0;
|
||||
int instances_supporting_splitk = 0;
|
||||
for(auto& gemm_ptr : op_ptrs)
|
||||
{
|
||||
auto argument_ptr =
|
||||
gemm_ptr->MakeArgumentPointer(p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_c,
|
||||
gemm_descs,
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
auto argument_ptr = gemm_ptr->MakeArgumentPointer(
|
||||
p_a, p_b, p_ds, p_c, gemm_descs, a_element_op, b_element_op, c_element_op);
|
||||
|
||||
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
|
||||
|
||||
@@ -261,16 +288,9 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
|
||||
std::string gemm_name = gemm_ptr->GetTypeString();
|
||||
|
||||
std::vector<int> kbatch_list = {1, 2, 4, 8, 12, 16, 20, 24, 32, 48, 64};
|
||||
|
||||
// If the user will provide not empty kbatches list, then we test predefined set of kbatch
|
||||
// values.
|
||||
if(!kbatches.empty())
|
||||
{
|
||||
kbatch_list = kbatches;
|
||||
}
|
||||
|
||||
bool all_batch_sizes_supported = true;
|
||||
// Keep track if we found any supported instance
|
||||
bool any_supported_instance = false;
|
||||
bool any_supported_nontrivial_kbatch = false;
|
||||
for(std::size_t j = 0; j < kbatch_list.size(); j++)
|
||||
{
|
||||
auto kbatch_curr = kbatch_list[j];
|
||||
@@ -290,11 +310,17 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
continue;
|
||||
}
|
||||
|
||||
// Keep track of which supported instances we found
|
||||
any_supported_instance = true;
|
||||
if(kbatch_curr > 1)
|
||||
{
|
||||
any_supported_nontrivial_kbatch = true;
|
||||
}
|
||||
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
c_device_buf[i]->SetZero();
|
||||
|
||||
invoker_ptr->Run(argument_ptr.get(),
|
||||
StreamConfig{nullptr, false, 0, n_warmup, n_iter});
|
||||
float ave_time = invoker_ptr->Run(argument_ptr.get(), stream_config);
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
@@ -329,7 +355,7 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Instance: " << gemm_name << " verification "
|
||||
std::cout << "Instance: " << gemm_name << "; KBatch: " << kbatch_curr << " "
|
||||
<< (instance_pass ? "SUCCEED" : "FAILED") << std::endl;
|
||||
|
||||
pass = pass && instance_pass;
|
||||
@@ -337,10 +363,6 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
float ave_time =
|
||||
invoker_ptr->Run(argument_ptr.get(),
|
||||
StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter});
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
@@ -370,24 +392,38 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
}
|
||||
else
|
||||
{
|
||||
all_batch_sizes_supported = false;
|
||||
std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem"
|
||||
std::cout << "Instance: " << gemm_name
|
||||
<< ", does not support this GEMM problem (KBatch: " << kbatch_curr << ")"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// If all batch sizes were supported by this instance, the instance can be marked as
|
||||
// If any kbatch sizes > 1 were supported by this instance, the instance can be marked as
|
||||
// 'supported' for this problem
|
||||
if(all_batch_sizes_supported)
|
||||
if(any_supported_nontrivial_kbatch)
|
||||
{
|
||||
++instances_supporting_all_batch_sizes;
|
||||
++instances_supporting_splitk;
|
||||
}
|
||||
|
||||
if(any_supported_instance)
|
||||
{
|
||||
++instances_supported;
|
||||
}
|
||||
}
|
||||
|
||||
// Warn if not a single instance was supported
|
||||
if(instances_supporting_all_batch_sizes == 0)
|
||||
if(instances_supported == 0)
|
||||
{
|
||||
std::cout << "Warning! No instance found that supported all of the batch sizes."
|
||||
std::cout << "Warning! No supported instance found." << std::endl;
|
||||
|
||||
if(fail_if_no_supported_instance)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if(operation_requires_splitk_support && instances_supporting_splitk == 0)
|
||||
{
|
||||
std::cout << "Warning! No instance found that supported any of the kbatch sizes."
|
||||
<< std::endl;
|
||||
|
||||
if(fail_if_no_supported_instance)
|
||||
|
||||
@@ -12,6 +12,12 @@ if (CK_USE_XDL OR CK_USE_WMMA)
|
||||
target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance)
|
||||
add_dependencies(test_grouped_gemm test_grouped_gemm_splitk)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_grouped_gemm_fastgelu test_grouped_gemm_fastgelu.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_grouped_gemm_fastgelu PRIVATE utility device_grouped_gemm_fastgelu_instance)
|
||||
add_dependencies(test_grouped_gemm test_grouped_gemm_fastgelu)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp)
|
||||
|
||||
62
test/grouped_gemm/test_grouped_gemm_fastgelu.cpp
Normal file
62
test/grouped_gemm/test_grouped_gemm_fastgelu.cpp
Normal file
@@ -0,0 +1,62 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test_grouped_gemm_util.hpp"
|
||||
|
||||
ck::index_t param_mask = 0xffffff;
|
||||
ck::index_t instance_index = -1;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F8 = ck::f8_t;
|
||||
using I8 = int8_t;
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CDEElementOp = ck::tensor_operation::element_wise::FastGelu;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedGemm : public ck::test::TestGroupedGemm<Tuple, true>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F16, AElementOp, BElementOp, CDEElementOp>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F16, AElementOp, BElementOp, CDEElementOp>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F16, AElementOp, BElementOp, CDEElementOp>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F16, AElementOp, BElementOp, CDEElementOp>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestGroupedGemm, KernelTypes);
|
||||
|
||||
#include "test_grouped_gemm_ut_cases.inc"
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
testing::InitGoogleTest(&argc, argv);
|
||||
if(argc == 1) {}
|
||||
else if(argc == 3)
|
||||
{
|
||||
param_mask = strtol(argv[1], nullptr, 0);
|
||||
instance_index = atoi(argv[2]);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Usage of " << argv[0] << std::endl;
|
||||
std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl;
|
||||
}
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
@@ -65,12 +65,11 @@ TYPED_TEST(TestGroupedGemm, MNKPadded)
|
||||
|
||||
TYPED_TEST(TestGroupedGemm, TestLargeKBatch)
|
||||
{
|
||||
// gfx11 does not support split-K due to missing atomic add for fp16/bf16
|
||||
// Technically, we could still run the tests for fp32, but we currently don't have instances for
|
||||
// it so we disable it entirely
|
||||
if(ck::is_gfx11_supported())
|
||||
GTEST_SKIP() << "Split-K not supported for FP16/BF16 on GFX11 due to missing atomic add "
|
||||
"instructions";
|
||||
// In some cases Split K is not supported. Running this test would fail since no instance will
|
||||
// be supported, so we skip the test
|
||||
if(!this->IsSplitKSupported())
|
||||
GTEST_SKIP() << "Split-K not supported for for the current configuration (FP16/BF16 on "
|
||||
"GFX11, or using CDE element-wise operation)";
|
||||
|
||||
const std::vector<int> Ms{188, 210};
|
||||
constexpr int N = 768;
|
||||
|
||||
@@ -7,11 +7,14 @@
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "profiler/profile_grouped_gemm_impl.hpp"
|
||||
|
||||
extern ck::index_t param_mask;
|
||||
@@ -32,16 +35,46 @@ std::string serialize_range(const Range& range)
|
||||
return std::string(str.begin(), str.end() - 2);
|
||||
}
|
||||
|
||||
// Helper primary template (will be specialized on the boolean)
|
||||
template <std::size_t N,
|
||||
typename Tuple,
|
||||
typename Default,
|
||||
bool InRange = (N < std::tuple_size_v<std::remove_reference_t<Tuple>>)>
|
||||
struct tuple_element_or_impl;
|
||||
|
||||
// Specialization for the in-range case: use std::tuple_element_t
|
||||
template <std::size_t N, typename Tuple, typename Default>
|
||||
struct tuple_element_or_impl<N, Tuple, Default, true>
|
||||
{
|
||||
using type = std::tuple_element_t<N, std::remove_reference_t<Tuple>>;
|
||||
};
|
||||
|
||||
// Specialization for the out-of-range case: use Default
|
||||
template <std::size_t N, typename Tuple, typename Default>
|
||||
struct tuple_element_or_impl<N, Tuple, Default, false>
|
||||
{
|
||||
using type = Default;
|
||||
};
|
||||
|
||||
// User-facing alias
|
||||
template <std::size_t N, typename Tuple, typename Default>
|
||||
using tuple_element_or_t = typename tuple_element_or_impl<N, Tuple, Default>::type;
|
||||
|
||||
template <typename Tuple, bool FailIfNoSupportedInstances = false>
|
||||
class TestGroupedGemm : public testing::Test
|
||||
{
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using ELayout = std::tuple_element_t<2, Tuple>;
|
||||
using ADataType = std::tuple_element_t<3, Tuple>;
|
||||
using BDataType = std::tuple_element_t<4, Tuple>;
|
||||
using EDataType = std::tuple_element_t<5, Tuple>;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using ELayout = std::tuple_element_t<2, Tuple>;
|
||||
using ADataType = std::tuple_element_t<3, Tuple>;
|
||||
using BDataType = std::tuple_element_t<4, Tuple>;
|
||||
using EDataType = std::tuple_element_t<5, Tuple>;
|
||||
using AElementOp = tuple_element_or_t<6, Tuple, PassThrough>;
|
||||
using BElementOp = tuple_element_or_t<7, Tuple, PassThrough>;
|
||||
using CDEElementOp = tuple_element_or_t<8, Tuple, PassThrough>;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
@@ -57,15 +90,25 @@ class TestGroupedGemm : public testing::Test
|
||||
bool fail_if_no_supported_instances_ = FailIfNoSupportedInstances;
|
||||
std::vector<int> k_batches_;
|
||||
|
||||
void SetUp() override
|
||||
bool IsSplitKSupported()
|
||||
{
|
||||
// gfx11 does not support split-K due to missing atomic add for fp16/bf16
|
||||
// Technically, we could still use split-K for fp32, but we currently don't have
|
||||
// instances for it so we disable it entirely
|
||||
constexpr bool require_16bit_atomic_add =
|
||||
std::is_same_v<EDataType, ck::half_t> || std::is_same_v<EDataType, ck::bhalf_t>;
|
||||
if(require_16bit_atomic_add && ck::is_gfx11_supported())
|
||||
bool missing_atomic_add = require_16bit_atomic_add && ck::is_gfx11_supported();
|
||||
|
||||
// CDE element operators are not supported in combination with split K
|
||||
constexpr bool has_cde_element_operator = !std::is_same_v<CDEElementOp, PassThrough>;
|
||||
|
||||
return !missing_atomic_add && !has_cde_element_operator;
|
||||
}
|
||||
|
||||
void SetUp() override
|
||||
{
|
||||
if(!IsSplitKSupported())
|
||||
{
|
||||
// gfx11 does not support split-K due to missing atomic add for fp16/bf16
|
||||
// Technically, we could still use split-K for fp32, but we currently don't have
|
||||
// instances for it so we disable it entirely
|
||||
k_batches_ = {1};
|
||||
}
|
||||
else
|
||||
@@ -147,21 +190,24 @@ class TestGroupedGemm : public testing::Test
|
||||
float,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout>(verify_,
|
||||
init_method_,
|
||||
log_,
|
||||
bench_,
|
||||
Ms,
|
||||
Ns,
|
||||
Ks,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideCs,
|
||||
kbatches,
|
||||
n_warmup_,
|
||||
n_iter_,
|
||||
instance_index,
|
||||
fail_if_no_supported_instances_);
|
||||
ELayout,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>(verify_,
|
||||
init_method_,
|
||||
log_,
|
||||
bench_,
|
||||
Ms,
|
||||
Ns,
|
||||
Ks,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideCs,
|
||||
kbatches,
|
||||
n_warmup_,
|
||||
n_iter_,
|
||||
instance_index,
|
||||
fail_if_no_supported_instances_);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user