mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 21:39:15 +00:00
This reverts commit1a8bd3d34b. [ROCm/composable_kernel commit:569640dc70]
This commit is contained in:
@@ -1,194 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/utility/functional4.hpp"
|
||||
#include "ck/utility/tuple_helper.hpp"
|
||||
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace host {
|
||||
|
||||
template <typename AsTensorTuple,
|
||||
typename BsTensorTuple,
|
||||
typename DsTensorTuple,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
typename AComputeType,
|
||||
typename BComputeType>
|
||||
struct ReferenceGemmMultiABD : public device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
struct Argument : public device::BaseArgument
|
||||
{
|
||||
Argument(const AsTensorTuple& as_m_k,
|
||||
const BsTensorTuple& bs_k_n,
|
||||
const DsTensorTuple& ds_m_n,
|
||||
Tensor<EDataType>& e_m_n,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
: as_m_k_{as_m_k},
|
||||
bs_k_n_{bs_k_n},
|
||||
ds_m_n_{ds_m_n},
|
||||
e_m_n_{e_m_n},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op}
|
||||
{
|
||||
}
|
||||
|
||||
const AsTensorTuple& as_m_k_;
|
||||
const BsTensorTuple& bs_k_n_;
|
||||
const DsTensorTuple& ds_m_n_;
|
||||
Tensor<EDataType>& e_m_n_;
|
||||
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public device::BaseInvoker
|
||||
{
|
||||
using Argument = ReferenceGemmMultiABD::Argument;
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
static constexpr index_t NumATensor = AsTensorTuple::Size();
|
||||
static constexpr index_t NumBTensor = BsTensorTuple::Size();
|
||||
static constexpr index_t NumDTensor = DsTensorTuple::Size();
|
||||
|
||||
const int M = arg.as_m_k_[Number<0>{}].mDesc.GetLengths()[0];
|
||||
const int K = arg.as_m_k_[Number<0>{}].mDesc.GetLengths()[1];
|
||||
const int N = arg.bs_k_n_[Number<0>{}].mDesc.GetLengths()[1];
|
||||
|
||||
Tensor<AComputeType> a_m_k({M, K});
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
// result
|
||||
auto data_refs1 = ck::tie(a_m_k(m, k));
|
||||
// inputs
|
||||
auto data_refs2 = generate_tie(
|
||||
[&](auto i) -> auto& { return arg.as_m_k_[Number<i>{}](m, k); },
|
||||
Number<NumATensor>{});
|
||||
auto data_refs = concat_tuple_of_reference(data_refs1, data_refs2);
|
||||
unpack(arg.a_element_op_, data_refs);
|
||||
}
|
||||
}
|
||||
|
||||
Tensor<BComputeType> b_k_n({K, N});
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
// result
|
||||
auto data_refs1 = ck::tie(b_k_n(k, n));
|
||||
// inputs
|
||||
auto data_refs2 = generate_tie(
|
||||
[&](auto i) -> auto& { return arg.bs_k_n_[Number<i>{}](k, n); },
|
||||
Number<NumBTensor>{});
|
||||
auto data_refs = concat_tuple_of_reference(data_refs1, data_refs2);
|
||||
unpack(arg.b_element_op_, data_refs);
|
||||
}
|
||||
}
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
Tensor<AccDataType> c_m_n({M, N});
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<AComputeType,
|
||||
BComputeType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
// compulsory
|
||||
auto data_refs1 = ck::tie(arg.e_m_n_(m, n), c_m_n(m, n));
|
||||
// optional (if multiple Ds)
|
||||
auto data_refs2 = generate_tie(
|
||||
[&](auto i) -> auto& { return arg.ds_m_n_[Number<i>{}](m, n); },
|
||||
Number<NumDTensor>{});
|
||||
auto data_refs = concat_tuple_of_reference(data_refs1, data_refs2);
|
||||
unpack(arg.cde_element_op_, data_refs);
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const device::BaseArgument* p_arg,
|
||||
const StreamConfig& /* stream_config */ = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
|
||||
|
||||
static auto MakeArgument(const AsTensorTuple& as_m_k,
|
||||
const BsTensorTuple& bs_k_n,
|
||||
const DsTensorTuple& ds_m_n,
|
||||
Tensor<EDataType>& e_m_n,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{as_m_k, bs_k_n, ds_m_n, e_m_n, a_element_op, b_element_op, cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "ReferenceGemmMultiABD"
|
||||
<< std::endl;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace host
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -10,6 +10,7 @@
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -20,7 +21,6 @@ using Multiply = ck::tensor_operation::element_wise::Multiply;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
// RRR
|
||||
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
|
||||
@@ -179,167 +179,6 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instan
|
||||
PassThrough,
|
||||
Multiply,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
AddFastGelu>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
Add>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
FastGelu>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
// RCR
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
|
||||
ck::Tuple<Col, Col>,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
AddFastGelu>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
|
||||
ck::Tuple<Col, Col>,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
Add>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
|
||||
ck::Tuple<Col, Col>,
|
||||
ck::Tuple<>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
FastGelu>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
|
||||
ck::Tuple<Col, Col>,
|
||||
ck::Tuple<>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
// CRR
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Col>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
AddFastGelu>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Col>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
Add>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Col>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
FastGelu>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Col>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
PassThrough>>>& instances);
|
||||
#endif // CK_USE
|
||||
|
||||
// GEMM + Add + Gelu
|
||||
template <typename AsLayout,
|
||||
@@ -379,7 +218,6 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
|
||||
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
|
||||
is_same_v<DsDataType, ck::Tuple<BF16>> && is_same_v<EDataType, BF16>)
|
||||
@@ -408,38 +246,6 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
|
||||
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
|
||||
is_same_v<DsDataType, ck::Tuple<BF16>> && is_same_v<EDataType, BF16>)
|
||||
{
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Col>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Col, Col>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
@@ -483,7 +289,6 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
|
||||
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
|
||||
is_same_v<DsDataType, ck::Tuple<BF16>> && is_same_v<EDataType, BF16>)
|
||||
@@ -512,38 +317,6 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
|
||||
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
|
||||
is_same_v<DsDataType, ck::Tuple<BF16>> && is_same_v<EDataType, BF16>)
|
||||
{
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Col>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Col, Col>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
@@ -587,7 +360,6 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
|
||||
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
|
||||
is_same_v<DsDataType, ck::Tuple<>> && is_same_v<EDataType, BF16>)
|
||||
@@ -616,38 +388,6 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
|
||||
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
|
||||
is_same_v<DsDataType, ck::Tuple<>> && is_same_v<EDataType, BF16>)
|
||||
{
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Col>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Col, Col>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
@@ -691,7 +431,6 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
|
||||
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
|
||||
is_same_v<DsDataType, ck::Tuple<>> && is_same_v<EDataType, BF16>)
|
||||
@@ -720,38 +459,6 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
|
||||
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
|
||||
is_same_v<DsDataType, ck::Tuple<>> && is_same_v<EDataType, BF16>)
|
||||
{
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Col>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Col, Col>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user