mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 14:54:47 +00:00
Added Multi_ABD support into Gemm and GroupedGemmFixedNK (#978)
* added an example grouped_gemm_multi_abd * fixed ci * add setElementwiseOp * changed API * clean code: add multiA into example * fixed v7r2 copy * add transpose * clean * fixed vector_load check * Update example/15_grouped_gemm/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update example/15_grouped_gemm/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update example/15_grouped_gemm/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * add reduce * testing * add example_b16_i8 * refactor example * clean * add mpading * disable reduce for kbatch = 1 * seperate reduce device op * add reduce op * add guard for workspace_size * add instances * format * fixed * add client example * add a colmajor * add instances * Update cmake-ck-dev.sh * Update profile_gemm_splitk.cpp * Update gridwise_gemm_xdlops_v2r4r2.hpp * format * Update profile_gemm_splitk.cpp * fixed * fixed * adjust test * adjust precision loss * adjust test * fixed * add bf16_i8 scale bias * fixed scale * fixed scale elementwise_op * revert contraction deviceop changes * fixed * Add AddFastGelu * Revert "Merge branch 'jizhan/gemm_splitk_reduce' into grouped_gemm_multi_abd_fixed_nk_example" This reverts commit3b5d001efd, reversing changes made to943199a991. * add Scales into elementwise * add gemm_multi_abd client example * add client examples * add rcr and crr * add grouped gemm client example * add grouped gemm client example * add instance for rcr crr * format * fixed * fixed cmake * fixed * fixed client_example * format * fixed contraction isSupport * Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update device_reduce_threadwise.hpp * clean * Fixes * Fix example --------- Co-authored-by: Jing Zhang <jizha@amd.com> Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>
This commit is contained in:
@@ -0,0 +1,468 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using Scales = ck::tensor_operation::element_wise::Scales;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
|
||||
|
||||
#ifdef CK_ENABLE_INT8
|
||||
// RRR
|
||||
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Scales,
|
||||
AddFastGelu>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Scales,
|
||||
Add>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Scales,
|
||||
FastGelu>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Scales,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
// RCR
|
||||
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_v1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Row>,
|
||||
ck::Tuple<Col, Col>,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Scales,
|
||||
AddFastGelu>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_v1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Row>,
|
||||
ck::Tuple<Col, Col>,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Scales,
|
||||
Add>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_v1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Row>,
|
||||
ck::Tuple<Col, Col>,
|
||||
ck::Tuple<>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Scales,
|
||||
FastGelu>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_v1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Row>,
|
||||
ck::Tuple<Col, Col>,
|
||||
ck::Tuple<>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Scales,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
// CRR
|
||||
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_v1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Col>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Scales,
|
||||
AddFastGelu>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_bias_v1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Col>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Scales,
|
||||
Add>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_v1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Col>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Scales,
|
||||
FastGelu>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_v1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Col>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Scales,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
// GEMM + Add + Gelu
|
||||
template <typename AsLayout,
|
||||
typename BsLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType>
|
||||
struct DeviceOperationInstanceFactory<
|
||||
ck::tensor_operation::device::DeviceGemmMultipleABD<AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
Scales,
|
||||
AddFastGelu>>
|
||||
{
|
||||
using DeviceOp = DeviceGemmMultipleABD<AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
Scales,
|
||||
AddFastGelu>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#ifdef CK_ENABLE_INT8
|
||||
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
|
||||
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
|
||||
is_same_v<DsDataType, ck::Tuple<BF16>> && is_same_v<EDataType, BF16>)
|
||||
{
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances(op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Col>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_v1_instances(op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Col, Col>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_v1_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
// GEMM + Add
|
||||
template <typename AsLayout,
|
||||
typename BsLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType>
|
||||
struct DeviceOperationInstanceFactory<
|
||||
ck::tensor_operation::device::DeviceGemmMultipleABD<AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
Scales,
|
||||
Add>>
|
||||
{
|
||||
using DeviceOp = DeviceGemmMultipleABD<AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
Scales,
|
||||
Add>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#ifdef CK_ENABLE_INT8
|
||||
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
|
||||
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
|
||||
is_same_v<DsDataType, ck::Tuple<BF16>> && is_same_v<EDataType, BF16>)
|
||||
{
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances(op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Col>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_bias_v1_instances(op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Col, Col>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_v1_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
// GEMM + Gelu
|
||||
template <typename AsLayout,
|
||||
typename BsLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType>
|
||||
struct DeviceOperationInstanceFactory<
|
||||
ck::tensor_operation::device::DeviceGemmMultipleABD<AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
Scales,
|
||||
FastGelu>>
|
||||
{
|
||||
using DeviceOp = DeviceGemmMultipleABD<AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
Scales,
|
||||
FastGelu>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#ifdef CK_ENABLE_INT8
|
||||
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
|
||||
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
|
||||
is_same_v<DsDataType, ck::Tuple<>> && is_same_v<EDataType, BF16>)
|
||||
{
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Col>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_v1_instances(op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Col, Col>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_v1_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
// GEMM
|
||||
template <typename AsLayout,
|
||||
typename BsLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType>
|
||||
struct DeviceOperationInstanceFactory<
|
||||
ck::tensor_operation::device::DeviceGemmMultipleABD<AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
Scales,
|
||||
PassThrough>>
|
||||
{
|
||||
using DeviceOp = DeviceGemmMultipleABD<AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
Scales,
|
||||
PassThrough>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#ifdef CK_ENABLE_INT8
|
||||
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
|
||||
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
|
||||
is_same_v<DsDataType, ck::Tuple<>> && is_same_v<EDataType, BF16>)
|
||||
{
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances(op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Col>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_v1_instances(op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Col, Col>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_v1_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,470 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using Scales = ck::tensor_operation::element_wise::Scales;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
|
||||
|
||||
// RRR
|
||||
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Scales,
|
||||
AddFastGelu>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Scales,
|
||||
Add>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Scales,
|
||||
FastGelu>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Scales,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
// RCR
|
||||
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
|
||||
ck::Tuple<Col, Col>,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Scales,
|
||||
AddFastGelu>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
|
||||
ck::Tuple<Col, Col>,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Scales,
|
||||
Add>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
|
||||
ck::Tuple<Col, Col>,
|
||||
ck::Tuple<>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Scales,
|
||||
FastGelu>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
|
||||
ck::Tuple<Col, Col>,
|
||||
ck::Tuple<>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Scales,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
// CRR
|
||||
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Col>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Scales,
|
||||
AddFastGelu>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Col>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Scales,
|
||||
Add>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Col>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Scales,
|
||||
FastGelu>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Col>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Scales,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
// GEMM + Add + Gelu
|
||||
template <typename AsLayout,
|
||||
typename BsLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType>
|
||||
struct DeviceOperationInstanceFactory<
|
||||
ck::tensor_operation::device::DeviceGroupedGemmMultiABDFixedNK<AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
Scales,
|
||||
AddFastGelu>>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedGemmMultiABDFixedNK<AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
Scales,
|
||||
AddFastGelu>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
|
||||
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
|
||||
is_same_v<DsDataType, ck::Tuple<BF16>> && is_same_v<EDataType, BF16>)
|
||||
{
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Col>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Col, Col>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
// GEMM + Add
|
||||
template <typename AsLayout,
|
||||
typename BsLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType>
|
||||
struct DeviceOperationInstanceFactory<
|
||||
ck::tensor_operation::device::DeviceGroupedGemmMultiABDFixedNK<AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
Scales,
|
||||
Add>>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedGemmMultiABDFixedNK<AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
Scales,
|
||||
Add>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
|
||||
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
|
||||
is_same_v<DsDataType, ck::Tuple<BF16>> && is_same_v<EDataType, BF16>)
|
||||
{
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Col>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Col, Col>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
// GEMM + Gelu
|
||||
template <typename AsLayout,
|
||||
typename BsLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType>
|
||||
struct DeviceOperationInstanceFactory<
|
||||
ck::tensor_operation::device::DeviceGroupedGemmMultiABDFixedNK<AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
Scales,
|
||||
FastGelu>>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedGemmMultiABDFixedNK<AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
Scales,
|
||||
FastGelu>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
|
||||
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
|
||||
is_same_v<DsDataType, ck::Tuple<>> && is_same_v<EDataType, BF16>)
|
||||
{
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Col>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Col, Col>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
// GEMM
|
||||
template <typename AsLayout,
|
||||
typename BsLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType>
|
||||
struct DeviceOperationInstanceFactory<
|
||||
ck::tensor_operation::device::DeviceGroupedGemmMultiABDFixedNK<AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
Scales,
|
||||
PassThrough>>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedGemmMultiABDFixedNK<AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
Scales,
|
||||
PassThrough>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
|
||||
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
|
||||
is_same_v<DsDataType, ck::Tuple<>> && is_same_v<EDataType, BF16>)
|
||||
{
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Col>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Col, Col>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user