Revert "Implement device grouped gemm fixed nk multi abd for rdna4 (#3619)" (#3705)

This reverts commit 301eb5cf08.
This commit is contained in:
Illia Silin
2026-02-03 09:52:14 -08:00
committed by GitHub
parent 8cbd09c84a
commit 569640dc70
24 changed files with 120 additions and 3517 deletions

View File

@@ -1,194 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/utility/functional4.hpp"
#include "ck/utility/tuple_helper.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename AsTensorTuple,
typename BsTensorTuple,
typename DsTensorTuple,
typename EDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename AComputeType,
typename BComputeType>
struct ReferenceGemmMultiABD : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const AsTensorTuple& as_m_k,
const BsTensorTuple& bs_k_n,
const DsTensorTuple& ds_m_n,
Tensor<EDataType>& e_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: as_m_k_{as_m_k},
bs_k_n_{bs_k_n},
ds_m_n_{ds_m_n},
e_m_n_{e_m_n},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
}
const AsTensorTuple& as_m_k_;
const BsTensorTuple& bs_k_n_;
const DsTensorTuple& ds_m_n_;
Tensor<EDataType>& e_m_n_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
using Argument = ReferenceGemmMultiABD::Argument;
float Run(const Argument& arg)
{
static constexpr index_t NumATensor = AsTensorTuple::Size();
static constexpr index_t NumBTensor = BsTensorTuple::Size();
static constexpr index_t NumDTensor = DsTensorTuple::Size();
const int M = arg.as_m_k_[Number<0>{}].mDesc.GetLengths()[0];
const int K = arg.as_m_k_[Number<0>{}].mDesc.GetLengths()[1];
const int N = arg.bs_k_n_[Number<0>{}].mDesc.GetLengths()[1];
Tensor<AComputeType> a_m_k({M, K});
for(int m = 0; m < M; ++m)
{
for(int k = 0; k < K; ++k)
{
// result
auto data_refs1 = ck::tie(a_m_k(m, k));
// inputs
auto data_refs2 = generate_tie(
[&](auto i) -> auto& { return arg.as_m_k_[Number<i>{}](m, k); },
Number<NumATensor>{});
auto data_refs = concat_tuple_of_reference(data_refs1, data_refs2);
unpack(arg.a_element_op_, data_refs);
}
}
Tensor<BComputeType> b_k_n({K, N});
for(int k = 0; k < K; ++k)
{
for(int n = 0; n < N; ++n)
{
// result
auto data_refs1 = ck::tie(b_k_n(k, n));
// inputs
auto data_refs2 = generate_tie(
[&](auto i) -> auto& { return arg.bs_k_n_[Number<i>{}](k, n); },
Number<NumBTensor>{});
auto data_refs = concat_tuple_of_reference(data_refs1, data_refs2);
unpack(arg.b_element_op_, data_refs);
}
}
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
Tensor<AccDataType> c_m_n({M, N});
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<AComputeType,
BComputeType,
AccDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
// compulsory
auto data_refs1 = ck::tie(arg.e_m_n_(m, n), c_m_n(m, n));
// optional (if multiple Ds)
auto data_refs2 = generate_tie(
[&](auto i) -> auto& { return arg.ds_m_n_[Number<i>{}](m, n); },
Number<NumDTensor>{});
auto data_refs = concat_tuple_of_reference(data_refs1, data_refs2);
unpack(arg.cde_element_op_, data_refs);
}
}
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const AsTensorTuple& as_m_k,
const BsTensorTuple& bs_k_n,
const DsTensorTuple& ds_m_n,
Tensor<EDataType>& e_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{as_m_k, bs_k_n, ds_m_n, e_m_n, a_element_op, b_element_op, cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceGemmMultiABD"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck

View File

@@ -10,6 +10,7 @@
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp"
namespace ck {
namespace tensor_operation {
@@ -20,7 +21,6 @@ using Multiply = ck::tensor_operation::element_wise::Multiply;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
#if defined(CK_USE_XDL)
// RRR
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
@@ -179,167 +179,6 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instan
PassThrough,
Multiply,
PassThrough>>>& instances);
#endif
#if defined(CK_USE_WMMA)
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<Row>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<BF16>,
BF16,
PassThrough,
Multiply,
AddFastGelu>>>& instances);
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<Row>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<BF16>,
BF16,
PassThrough,
Multiply,
Add>>>& instances);
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<>,
BF16,
PassThrough,
Multiply,
FastGelu>>>& instances);
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<>,
BF16,
PassThrough,
Multiply,
PassThrough>>>& instances);
// RCR
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
ck::Tuple<Col, Col>,
ck::Tuple<Row>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<BF16>,
BF16,
PassThrough,
Multiply,
AddFastGelu>>>& instances);
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
ck::Tuple<Col, Col>,
ck::Tuple<Row>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<BF16>,
BF16,
PassThrough,
Multiply,
Add>>>& instances);
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
ck::Tuple<Col, Col>,
ck::Tuple<>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<>,
BF16,
PassThrough,
Multiply,
FastGelu>>>& instances);
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
ck::Tuple<Col, Col>,
ck::Tuple<>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<>,
BF16,
PassThrough,
Multiply,
PassThrough>>>& instances);
// CRR
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Col>,
ck::Tuple<Row, Row>,
ck::Tuple<Row>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<BF16>,
BF16,
PassThrough,
Multiply,
AddFastGelu>>>& instances);
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Col>,
ck::Tuple<Row, Row>,
ck::Tuple<Row>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<BF16>,
BF16,
PassThrough,
Multiply,
Add>>>& instances);
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Col>,
ck::Tuple<Row, Row>,
ck::Tuple<>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<>,
BF16,
PassThrough,
Multiply,
FastGelu>>>& instances);
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Col>,
ck::Tuple<Row, Row>,
ck::Tuple<>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<>,
BF16,
PassThrough,
Multiply,
PassThrough>>>& instances);
#endif // CK_USE
// GEMM + Add + Gelu
template <typename AsLayout,
@@ -379,7 +218,6 @@ struct DeviceOperationInstanceFactory<
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if defined(CK_USE_XDL)
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
is_same_v<DsDataType, ck::Tuple<BF16>> && is_same_v<EDataType, BF16>)
@@ -408,38 +246,6 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
}
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
is_same_v<DsDataType, ck::Tuple<BF16>> && is_same_v<EDataType, BF16>)
{
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances(
op_ptrs);
}
if constexpr(is_same_v<AsLayout, ck::Tuple<Col>> &&
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances(
op_ptrs);
}
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Col, Col>> &&
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances(
op_ptrs);
}
}
#endif // CK_USE_WMMA
return op_ptrs;
}
@@ -483,7 +289,6 @@ struct DeviceOperationInstanceFactory<
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if defined(CK_USE_XDL)
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
is_same_v<DsDataType, ck::Tuple<BF16>> && is_same_v<EDataType, BF16>)
@@ -512,38 +317,6 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
}
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
is_same_v<DsDataType, ck::Tuple<BF16>> && is_same_v<EDataType, BF16>)
{
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances(
op_ptrs);
}
if constexpr(is_same_v<AsLayout, ck::Tuple<Col>> &&
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances(
op_ptrs);
}
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Col, Col>> &&
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances(
op_ptrs);
}
}
#endif // CK_USE_WMMA
return op_ptrs;
}
@@ -587,7 +360,6 @@ struct DeviceOperationInstanceFactory<
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if defined(CK_USE_XDL)
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
is_same_v<DsDataType, ck::Tuple<>> && is_same_v<EDataType, BF16>)
@@ -616,38 +388,6 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
}
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
is_same_v<DsDataType, ck::Tuple<>> && is_same_v<EDataType, BF16>)
{
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances(
op_ptrs);
}
if constexpr(is_same_v<AsLayout, ck::Tuple<Col>> &&
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances(
op_ptrs);
}
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Col, Col>> &&
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances(
op_ptrs);
}
}
#endif // CK_USE_WMMA
return op_ptrs;
}
@@ -691,7 +431,6 @@ struct DeviceOperationInstanceFactory<
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if defined(CK_USE_XDL)
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
is_same_v<DsDataType, ck::Tuple<>> && is_same_v<EDataType, BF16>)
@@ -720,38 +459,6 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
}
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
is_same_v<DsDataType, ck::Tuple<>> && is_same_v<EDataType, BF16>)
{
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances(
op_ptrs);
}
if constexpr(is_same_v<AsLayout, ck::Tuple<Col>> &&
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances(
op_ptrs);
}
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Col, Col>> &&
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances(
op_ptrs);
}
}
#endif // CK_USE_WMMA
return op_ptrs;
}

View File

@@ -1,17 +1,13 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# ONLY XDL_AND_WMMA_KERNELS
# ONLY XDL_KERNELS
set(GROUPED_GEMM_FIXED_NK_MULTI_ABD_INSTANCES)
list(APPEND GROUPED_GEMM_FIXED_NK_MULTI_ABD_INSTANCES
device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp
device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp
device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp
device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp
)
add_instance_library(device_grouped_gemm_fixed_nk_multi_abd_instance ${GROUPED_GEMM_FIXED_NK_MULTI_ABD_INSTANCES})

View File

@@ -1,144 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <index_t... Is>
using S = Sequence<Is...>;
using BF16 = bhalf_t;
using I8 = int8_t;
using F32 = float;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
using Multiply = element_wise::Multiply;
using PassThrough = element_wise::PassThrough;
using AddFastGelu = element_wise::AddFastGelu;
using Add = element_wise::Add;
using FastGelu = element_wise::FastGelu;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
template <typename DsLayout,
typename DsDataType,
typename CDEElementOp,
GemmSpecialization GemmSpec>
using device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances = std::tuple<
// clang-format off
//#######################################| AsLayout| BsLayout| DsLayout| ELayout| AsData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector|
//#######################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | _NWaveNPerXdl|
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple<Col>, Tuple<Row, Row>, DsLayout, Row, Tuple<BF16>, Tuple<I8, BF16>, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple<Col>, Tuple<Row, Row>, DsLayout, Row, Tuple<BF16>, Tuple<I8, BF16>, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple<Col>, Tuple<Row, Row>, DsLayout, Row, Tuple<BF16>, Tuple<I8, BF16>, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>
// clang-format on
>;
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<Tuple<Col>,
Tuple<Row, Row>,
Tuple<Row>,
Row,
Tuple<BF16>,
Tuple<I8, BF16>,
Tuple<BF16>,
BF16,
PassThrough,
Multiply,
AddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances<
Tuple<Row>,
Tuple<BF16>,
AddFastGelu,
GemmMNKPadding>{});
}
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<Tuple<Col>,
Tuple<Row, Row>,
Tuple<Row>,
Row,
Tuple<BF16>,
Tuple<I8, BF16>,
Tuple<BF16>,
BF16,
PassThrough,
Multiply,
Add>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances<
Tuple<Row>,
Tuple<BF16>,
Add,
GemmMNKPadding>{});
}
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<Tuple<Col>,
Tuple<Row, Row>,
Tuple<>,
Row,
Tuple<BF16>,
Tuple<I8, BF16>,
Tuple<>,
BF16,
PassThrough,
Multiply,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances<
Tuple<>,
Tuple<>,
PassThrough,
GemmMNKPadding>{});
}
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<Tuple<Col>,
Tuple<Row, Row>,
Tuple<>,
Row,
Tuple<BF16>,
Tuple<I8, BF16>,
Tuple<>,
BF16,
PassThrough,
Multiply,
FastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances<
Tuple<>,
Tuple<>,
FastGelu,
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,144 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <index_t... Is>
using S = Sequence<Is...>;
using BF16 = bhalf_t;
using I8 = int8_t;
using F32 = float;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
using Multiply = element_wise::Multiply;
using PassThrough = element_wise::PassThrough;
using AddFastGelu = element_wise::AddFastGelu;
using Add = element_wise::Add;
using FastGelu = element_wise::FastGelu;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
template <typename DsLayout,
typename DsDataType,
typename CDEElementOp,
GemmSpecialization GemmSpec>
using device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances = std::tuple<
// clang-format off
//#######################################| AsLayout| BsLayout| DsLayout| ELayout| AsData| BsData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector|
//#######################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | |
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple<Row>, Tuple<Row, Row>, DsLayout, Row, Tuple<BF16>, Tuple<I8, BF16>, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple<Row>, Tuple<Row, Row>, DsLayout, Row, Tuple<BF16>, Tuple<I8, BF16>, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple<Row>, Tuple<Row, Row>, DsLayout, Row, Tuple<BF16>, Tuple<I8, BF16>, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>
// clang-format on
>;
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<Tuple<Row>,
Tuple<Row, Row>,
Tuple<Row>,
Row,
Tuple<BF16>,
Tuple<I8, BF16>,
Tuple<BF16>,
BF16,
PassThrough,
Multiply,
AddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances<
Tuple<Row>,
Tuple<BF16>,
AddFastGelu,
GemmMNKPadding>{});
}
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<Tuple<Row>,
Tuple<Row, Row>,
Tuple<Row>,
Row,
Tuple<BF16>,
Tuple<I8, BF16>,
Tuple<BF16>,
BF16,
PassThrough,
Multiply,
Add>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances<
Tuple<Row>,
Tuple<BF16>,
Add,
GemmMNKPadding>{});
}
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<Tuple<Row>,
Tuple<Row, Row>,
Tuple<>,
Row,
Tuple<BF16>,
Tuple<I8, BF16>,
Tuple<>,
BF16,
PassThrough,
Multiply,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances<
Tuple<>,
Tuple<>,
PassThrough,
GemmMNKPadding>{});
}
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<Tuple<Row>,
Tuple<Row, Row>,
Tuple<>,
Row,
Tuple<BF16>,
Tuple<I8, BF16>,
Tuple<>,
BF16,
PassThrough,
Multiply,
FastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances<
Tuple<>,
Tuple<>,
FastGelu,
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,144 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <index_t... Is>
using S = Sequence<Is...>;
using BF16 = bhalf_t;
using I8 = int8_t;
using F32 = float;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
using Multiply = element_wise::Multiply;
using PassThrough = element_wise::PassThrough;
using AddFastGelu = element_wise::AddFastGelu;
using Add = element_wise::Add;
using FastGelu = element_wise::FastGelu;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
template <typename DsLayout,
typename DsDataType,
typename CDEElementOp,
GemmSpecialization GemmSpec>
using device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances = std::tuple<
// clang-format off
//######################################| AsLayout| BsLayout| DsLayout| ELayout| AsData| BsData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector|
//######################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | |
//######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple<Row>, Tuple<Col, Col>, DsLayout, Row, Tuple<BF16>, Tuple<I8, BF16>, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>,
DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple<Row>, Tuple<Col, Col>, DsLayout, Row, Tuple<BF16>, Tuple<I8, BF16>, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8>,
DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple<Row>, Tuple<Col, Col>, DsLayout, Row, Tuple<BF16>, Tuple<I8, BF16>, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>
// clang-format on
>;
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<Tuple<Row>,
Tuple<Col, Col>,
Tuple<Row>,
Row,
Tuple<BF16>,
Tuple<I8, BF16>,
Tuple<BF16>,
BF16,
PassThrough,
Multiply,
AddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances<
Tuple<Row>,
Tuple<BF16>,
AddFastGelu,
GemmMNKPadding>{});
}
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<Tuple<Row>,
Tuple<Col, Col>,
Tuple<Row>,
Row,
Tuple<BF16>,
Tuple<I8, BF16>,
Tuple<BF16>,
BF16,
PassThrough,
Multiply,
Add>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances<
Tuple<Row>,
Tuple<BF16>,
Add,
GemmMNKPadding>{});
}
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<Tuple<Row>,
Tuple<Col, Col>,
Tuple<>,
Row,
Tuple<BF16>,
Tuple<I8, BF16>,
Tuple<>,
BF16,
PassThrough,
Multiply,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances<
Tuple<>,
Tuple<>,
PassThrough,
GemmMNKPadding>{});
}
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<Tuple<Row>,
Tuple<Col, Col>,
Tuple<>,
Row,
Tuple<BF16>,
Tuple<I8, BF16>,
Tuple<>,
BF16,
PassThrough,
Multiply,
FastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances<
Tuple<>,
Tuple<>,
FastGelu,
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -61,8 +61,6 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecial
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// NOTE: After adding unit tests for DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK it tuned out that
// portion of the instances are failing. As a workaround these have been commented out.
template <typename DsLayout,
typename DsDataType,
typename CDEElementOp,
@@ -74,14 +72,14 @@ using device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances
//######################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
// DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>
// clang-format on
>;

View File

@@ -61,8 +61,6 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecial
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// NOTE: After adding unit tests for DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK it tuned out that
// portion of the instances are failing. As a workaround these have been commented out.
template <typename DsLayout,
typename DsDataType,
typename CDEElementOp,
@@ -74,14 +72,14 @@ using device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances
//######################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
// DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>
// clang-format on
>;

View File

@@ -61,8 +61,6 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecial
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// NOTE: After adding unit tests for DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK it tuned out that
// portion of the instances are failing. As a workaround these have been commented out.
template <typename DsLayout,
typename DsDataType,
typename CDEElementOp,