mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Optimizing fp8_fp16 mixedprec gemm (#1150)
* add delayed cvt
* extend fp16 gemm_splitk instances for fp8_fp16 gemm
* add f8 example
* add 128 kperblk instances for fp8
* add kpb128 instance
* added more instances into kpb128
* clean code
* clean code
* fix
* fix
* fixed
* Update example/35_splitK_gemm/splitK_gemm_xdl_fp16_fp8.cpp
Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>
* Update include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>
* Update library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_kpb128_instance.cpp
Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>
---------
Co-authored-by: Jing Zhang <jizha@amd.com>
Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>
[ROCm/composable_kernel commit: 602c4cc0d9]
This commit is contained in:
@@ -10,6 +10,9 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp)
|
||||
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16)
|
||||
|
||||
add_example_executable(example_splitK_gemm_xdl_fp16_fp8 splitK_gemm_xdl_fp16_fp8.cpp)
|
||||
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16_fp8)
|
||||
|
||||
add_example_executable(example_splitK_gemm_xdl_lds_direct_load_fp16 splitK_gemm_xdl_lds_direct_load_fp16.cpp)
|
||||
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_lds_direct_load_fp16)
|
||||
|
||||
|
||||
60
example/35_splitK_gemm/splitK_gemm_xdl_fp16_fp8.cpp
Normal file
60
example/35_splitK_gemm/splitK_gemm_xdl_fp16_fp8.cpp
Normal file
@@ -0,0 +1,60 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F8;
|
||||
using AccDataType = F32;
|
||||
using CDataType = F16;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle
|
||||
// clang-format off
|
||||
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 128, 16, 64, 8, 16, 16, 16, 1, 2, S<1, 8, 8, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 4, F16, ck::PipelineVersion::v1, ck::LoopScheduler::Default, ADataType, BDataType>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
#include "run_splitK_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_splitK_gemm_example(argc, argv); }
|
||||
@@ -37,7 +37,9 @@ template <index_t BlockSize,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack>
|
||||
index_t KPack,
|
||||
typename ComputeTypeA = FloatA,
|
||||
typename ComputeTypeB = FloatB>
|
||||
struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -59,7 +61,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
|
||||
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
|
||||
|
||||
static constexpr auto xdlops_gemm = XdlopsGemm<FloatA, MPerXDL, NPerXDL, KPack, FloatB>{};
|
||||
static constexpr auto xdlops_gemm =
|
||||
XdlopsGemm<ComputeTypeA, MPerXDL, NPerXDL, KPack, ComputeTypeB>{};
|
||||
|
||||
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
|
||||
|
||||
@@ -295,9 +298,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
const BBlockBuffer& b_block_buf,
|
||||
CThreadBuffer& c_thread_buf) const
|
||||
{
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
@@ -319,20 +322,20 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
b_thread_buf);
|
||||
|
||||
static_for<0, KPerThread, KPack>{}([&](auto k) {
|
||||
vector_type<FloatA, KPack> a_thread_vec;
|
||||
vector_type<FloatB, KPack> b_thread_vec;
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto i) {
|
||||
a_thread_vec.template AsType<FloatA>()(i) = a_thread_buf
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(i) = a_thread_buf
|
||||
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
|
||||
b_thread_vec.template AsType<FloatB>()(i) = b_thread_buf
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(i) = b_thread_buf
|
||||
[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<FloatA, xdlops_gemm.K1PerXdlops>::type;
|
||||
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<FloatB, xdlops_gemm.K1PerXdlops>::type;
|
||||
typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
@@ -360,7 +363,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
|
||||
FloatA,
|
||||
ComputeTypeA,
|
||||
decltype(a_block_desc_m0_m1_m2_k),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, KPerThread>,
|
||||
@@ -370,7 +373,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
A_K1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
|
||||
FloatB,
|
||||
ComputeTypeB,
|
||||
decltype(b_block_desc_n0_n1_n2_k),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<1, 1, 1, KPerThread>,
|
||||
@@ -398,6 +401,8 @@ template <index_t BlockSize,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
typename ComputeTypeA = FloatA,
|
||||
typename ComputeTypeB = FloatB,
|
||||
index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS>
|
||||
struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
: public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
@@ -410,7 +415,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
KPack,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>
|
||||
{
|
||||
using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
FloatA,
|
||||
@@ -422,7 +429,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
KPack,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
|
||||
using Base::a_block_desc_m0_m1_m2_k;
|
||||
@@ -446,9 +455,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
const BBlockBuffer& b_block_buf,
|
||||
CThreadBuffer& c_thread_buf) const
|
||||
{
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) {
|
||||
@@ -485,22 +494,22 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<FloatA, KPack> a_thread_vec;
|
||||
vector_type<FloatB, KPack> b_thread_vec;
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto i) {
|
||||
a_thread_vec.template AsType<FloatA>()(i) =
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(i) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, 0, 0, k_ + i))>{}];
|
||||
b_thread_vec.template AsType<FloatB>()(i) =
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(i) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, 0, 0, k_ + i))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<FloatA, xdlops_gemm.K1PerXdlops>::type;
|
||||
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<FloatB, xdlops_gemm.K1PerXdlops>::type;
|
||||
typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
@@ -550,7 +559,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
make_tuple(Number<NRepeat>{}, I1, I1, Number<KPerInnerLoop>{}));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
|
||||
FloatA,
|
||||
ComputeTypeA,
|
||||
decltype(a_block_desc_m0_m1_m2_k),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, KPerInnerLoop>,
|
||||
@@ -560,7 +569,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
A_K1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
|
||||
FloatB,
|
||||
ComputeTypeB,
|
||||
decltype(b_block_desc_n0_n1_n2_k),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<1, 1, 1, KPerInnerLoop>,
|
||||
@@ -586,7 +595,9 @@ template <index_t BlockSize,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
LoopScheduler LoopSched>
|
||||
LoopScheduler LoopSched,
|
||||
typename ComputeTypeA = FloatA,
|
||||
typename ComputeTypeB = FloatB>
|
||||
constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
|
||||
{
|
||||
if constexpr(LoopSched == LoopScheduler::Default)
|
||||
@@ -601,7 +612,9 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
KPack,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>{};
|
||||
}
|
||||
else if constexpr(LoopSched == LoopScheduler::Interwave)
|
||||
{
|
||||
@@ -615,7 +628,9 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
KPack,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>{};
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -60,7 +60,9 @@ template <typename ADataType,
|
||||
index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
|
||||
typename ComputeType = CDataType,
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
typename LDSTypeA = ComputeType,
|
||||
typename LDSTypeB = ComputeType>
|
||||
|
||||
struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
BLayout,
|
||||
@@ -81,6 +83,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
// TODO: should be exposed as Tparams.
|
||||
static constexpr index_t NumGemmKPrefetchStage = 1;
|
||||
|
||||
using ComputeTypeA = ComputeType;
|
||||
using ComputeTypeB = ComputeType;
|
||||
|
||||
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
|
||||
BlockSize,
|
||||
ADataType,
|
||||
@@ -125,7 +130,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
LoopSched,
|
||||
PipelineVer,
|
||||
ComputeType>;
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
LDSTypeA,
|
||||
LDSTypeB>;
|
||||
|
||||
struct Argument : public GridwiseGemm::Argument
|
||||
{
|
||||
|
||||
@@ -21,50 +21,11 @@ struct PassThroughPack2
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const;
|
||||
|
||||
__host__ __device__ constexpr void operator()(ck::f8x2_t& y, const ck::half2_t& x) const
|
||||
{
|
||||
// fake conversion
|
||||
uint16_t t = ck::bit_cast<uint32_t>(x);
|
||||
y = ck::bit_cast<ck::f8x2_t>(t);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::f8x2_t& x) const
|
||||
{
|
||||
auto t = type_convert<float2_t>(x);
|
||||
y = type_convert<half2_t>(t);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::half2_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr void operator()(ck::f8x2_t& y, const ck::f8x2_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr void operator()(ck::float2_t& y, const ck::float2_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr void operator()(ck::int8x2_t& y, const ck::int8x2_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr void operator()(ck::bhalf2_t& y, const ck::bhalf2_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr void operator()(ck::double2_t& y, const ck::double2_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
constexpr const static bool is_pack2_invocable = true;
|
||||
};
|
||||
|
||||
struct PassThrough
|
||||
|
||||
@@ -9,7 +9,6 @@
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
|
||||
@@ -96,7 +95,10 @@ template <index_t BlockSize,
|
||||
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1,
|
||||
typename ComputeType = FloatC>
|
||||
typename ComputeTypeA = FloatC,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
typename LDSTypeA = ComputeTypeA,
|
||||
typename LDSTypeB = ComputeTypeB>
|
||||
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -430,7 +432,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
constexpr auto c_block_size =
|
||||
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize();
|
||||
|
||||
return math::max((a_block_space_size + b_block_space_size) * sizeof(ComputeType),
|
||||
return math::max(a_block_space_size * sizeof(LDSTypeA) +
|
||||
b_block_space_size * sizeof(LDSTypeB),
|
||||
c_block_size * sizeof(FloatC));
|
||||
}
|
||||
|
||||
@@ -785,7 +788,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
FloatA,
|
||||
ComputeType,
|
||||
LDSTypeA,
|
||||
decltype(a_b_k0_m_k1_grid_desc),
|
||||
decltype(a_b_k0_m_k1_block_desc),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
@@ -815,7 +818,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
FloatB,
|
||||
ComputeType,
|
||||
LDSTypeB,
|
||||
decltype(b_b_k0_n_k1_grid_desc),
|
||||
decltype(b_b_k0_n_k1_block_desc),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
@@ -845,8 +848,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
ComputeType, // ComputeType A
|
||||
ComputeType, // ComputeType B
|
||||
LDSTypeA,
|
||||
LDSTypeB,
|
||||
FloatAcc,
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
decltype(b_k0_n_k1_block_desc),
|
||||
@@ -855,7 +858,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
K1,
|
||||
LoopSched>();
|
||||
LoopSched,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>();
|
||||
|
||||
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
@@ -863,8 +868,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
constexpr auto a_block_space_size =
|
||||
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
ComputeType* p_a_block = static_cast<ComputeType*>(p_shared_block);
|
||||
ComputeType* p_b_block = static_cast<ComputeType*>(p_shared_block) + a_block_space_size;
|
||||
auto p_a_block = reinterpret_cast<LDSTypeA*>(p_shared_block);
|
||||
auto p_b_block = reinterpret_cast<LDSTypeB*>(p_a_block + a_block_space_size);
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
|
||||
|
||||
@@ -8,6 +8,8 @@
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
|
||||
@@ -1156,27 +1158,56 @@ struct ThreadwiseTensorSliceTransfer_v4
|
||||
src_ref_to_origin_disp_idx + data_to_origin_disp_idx +
|
||||
i * src_scalar_step_in_vector);
|
||||
|
||||
// apply type convert
|
||||
src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
|
||||
});
|
||||
}
|
||||
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
|
||||
// DstData)
|
||||
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
|
||||
|
||||
// TODO: if SrcData and DstData are vetor type, then static_cast may not compile
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
dst_tmp_vector.template AsType<DstData>()(i) =
|
||||
type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]);
|
||||
});
|
||||
if constexpr(is_same<remove_cvref_t<SrcData>, f8_t>::value &&
|
||||
is_same<remove_cvref_t<DstData>, half_t>::value &&
|
||||
SrcScalarPerVector % 2 == 0)
|
||||
{
|
||||
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
|
||||
// DstData)
|
||||
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
|
||||
|
||||
// copy data from dst_tmp_vector into dst_buf
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
constexpr index_t dst_offset = dst_desc.CalculateOffset(
|
||||
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
|
||||
constexpr index_t pack_size = 2;
|
||||
|
||||
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
|
||||
});
|
||||
using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
|
||||
using src_v_t = typename vector_type_maker_t<SrcData, pack_size>::type;
|
||||
static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
|
||||
ck::tensor_operation::element_wise::PassThroughPack2{}(
|
||||
dst_tmp_vector.template AsType<dst_v_t>()(i),
|
||||
src_tmp_vector.template AsType<src_v_t>()[i]);
|
||||
});
|
||||
|
||||
// copy data from dst_tmp_vector into dst_buf
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
constexpr index_t dst_offset = dst_desc.CalculateOffset(
|
||||
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
|
||||
|
||||
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
|
||||
// DstData)
|
||||
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
|
||||
|
||||
// TODO: if SrcData and DstData are vetor type, then static_cast may not compile
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
dst_tmp_vector.template AsType<DstData>()(i) =
|
||||
type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]);
|
||||
});
|
||||
|
||||
// copy data from dst_tmp_vector into dst_buf
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
constexpr index_t dst_offset = dst_desc.CalculateOffset(
|
||||
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
|
||||
|
||||
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -189,6 +189,11 @@ void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v2_instances(
|
||||
DeviceGemmSplitK<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_kpb128_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmSplitK<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmSplitK<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough, F8>>>&
|
||||
@@ -352,6 +357,7 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v1_instances(op_ptrs);
|
||||
add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v1_interwave_instances(op_ptrs);
|
||||
add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v2_instances(op_ptrs);
|
||||
add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_kpb128_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
|
||||
@@ -1,42 +1,45 @@
|
||||
set(GEMM_SPLITK_INSTANCES)
|
||||
|
||||
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_irregular_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_irregular_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_irregular_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_irregular_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_irregular_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_irregular_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v1_instance.cpp
|
||||
device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v1_interwave_instance.cpp
|
||||
device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v2_instance.cpp
|
||||
device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_fp8_f16_f16_mk_nk_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_fp8_f16_f16_km_kn_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_fp8_f16_f16_km_nk_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v1_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v1_interwave_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v2_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_irregular_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v1_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v1_interwave_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v2_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_fp8_f16_km_kn_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_fp8_f16_km_nk_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_kn_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_nk_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_kn_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_SPLITK_INSTANCES
|
||||
device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_irregular_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_irregular_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_irregular_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_irregular_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_irregular_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_irregular_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v1_instance.cpp
|
||||
device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v1_interwave_instance.cpp
|
||||
device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v2_instance.cpp
|
||||
device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_fp8_f16_f16_mk_nk_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_fp8_f16_f16_km_kn_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_fp8_f16_f16_km_nk_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v1_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v1_interwave_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v2_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_irregular_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_kpb128_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v1_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v1_interwave_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v2_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_fp8_f16_km_kn_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_fp8_f16_km_nk_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_kn_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_nk_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_kn_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_nk_mn_instance.cpp
|
||||
)
|
||||
|
||||
add_instance_library(device_gemm_splitk_instance ${GEMM_SPLITK_INSTANCES})
|
||||
|
||||
@@ -0,0 +1,147 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
|
||||
static constexpr auto GemmKPadding = ck::tensor_operation::device::GemmSpecialization::KPadding;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
template <ck::tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
ck::PipelineVersion PipVer,
|
||||
ck::LoopScheduler LoopSche>
|
||||
using device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
|
||||
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 8, 16, 16, 16, 1, 1, S<1, 8, 8, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche, F16, F8>,
|
||||
DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 8, 16, 16, 16, 1, 2, S<1, 8, 8, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche, F16, F8>,
|
||||
DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 8, 16, 16, 16, 1, 4, S<1, 8, 8, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche, F16, F8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_kpb128_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmSplitK<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
// default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
|
||||
GemmDefault,
|
||||
ck::PipelineVersion::v2,
|
||||
ck::LoopScheduler::Default>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
|
||||
GemmDefault,
|
||||
ck::PipelineVersion::v1,
|
||||
ck::LoopScheduler::Interwave>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
|
||||
GemmDefault,
|
||||
ck::PipelineVersion::v1,
|
||||
ck::LoopScheduler::Default>{});
|
||||
|
||||
// MNKPadding
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
|
||||
GemmMNKPadding,
|
||||
ck::PipelineVersion::v2,
|
||||
ck::LoopScheduler::Default>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
|
||||
GemmMNKPadding,
|
||||
ck::PipelineVersion::v1,
|
||||
ck::LoopScheduler::Interwave>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
|
||||
GemmMNKPadding,
|
||||
ck::PipelineVersion::v1,
|
||||
ck::LoopScheduler::Default>{});
|
||||
|
||||
// KPadding
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
|
||||
GemmKPadding,
|
||||
ck::PipelineVersion::v2,
|
||||
ck::LoopScheduler::Default>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
|
||||
GemmKPadding,
|
||||
ck::PipelineVersion::v1,
|
||||
ck::LoopScheduler::Interwave>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
|
||||
GemmKPadding,
|
||||
ck::PipelineVersion::v1,
|
||||
ck::LoopScheduler::Default>{});
|
||||
|
||||
// MNPadding
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
|
||||
GemmMNPadding,
|
||||
ck::PipelineVersion::v2,
|
||||
ck::LoopScheduler::Default>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
|
||||
GemmMNPadding,
|
||||
ck::PipelineVersion::v1,
|
||||
ck::LoopScheduler::Interwave>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances<
|
||||
GemmMNPadding,
|
||||
ck::PipelineVersion::v1,
|
||||
ck::LoopScheduler::Default>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user