mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Multiple fixes to GroupedGemm+SplitK (#707)
* Add license header.
* Reduce number of logged output. Add constant initialization.
* Add functional tests for grouped_gemm with different kbatch value.
* Add debug log informations + remove unused code.
* Don't pass kbatch to CalculateKPadded.
* Turn on logging in grouped gemm and gemm splitk profiler
* Debug: limit number of test cases to run;
* Log more information and initialize with constant value.
* Turn on DEBUG_LOG
* Add more debug log informations.
* Limit the number of instances to compile.
* Use GridwiseGemmPipeline
* Use KBatch to calculate K0
* Multiple DebugLog messages.
* Unit tests for multiple KBatch values.
* Refactoring
* Disable logging
* extract out of if statement KBatch update.
* Uncomment instances.
* Disable DebugLog.
* Use Kbatch when calculate KPadded.
* Fix CGridDesc padding.
* Use available helper functions.
* Uncomment code commented for debuggin.
* Remove unnecessary debug log messages.
* Uncomment previously commented code for debug purposes.
* Add KBatch info to profiler output summary log.
* Add gtests for gemm splitk using ckProfiler API.
* Add more test-cases for different data layout.
* Add more test cases for gemm splitk
* Remove old test.
* Unit tests for MKNK ggemm interface.
* Fix and add more unit-tests.
* Constepxr everything!
* Increase error threshold for fp16 and splitk.
Since we're using fp16 atomic add for splitk there's a
known precision loss.
---------
Co-authored-by: Adam Osewski <aosewski@amd.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
[ROCm/composable_kernel commit: 70e4eb567f]
This commit is contained in:
@@ -1,4 +1,8 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
|
||||
@@ -73,6 +73,11 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
// TODO: should be exposed as Tparams.
|
||||
static constexpr index_t NumGemmKPrefetchStage = 1;
|
||||
static constexpr LoopScheduler LoopSched = make_default_loop_scheduler();
|
||||
static constexpr PipelineVersion PipelineVer = PipelineVersion::v2;
|
||||
|
||||
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
|
||||
BlockSize,
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
@@ -85,6 +90,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
NumGemmKPrefetchStage,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
@@ -112,7 +118,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXDL,
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
LoopSched,
|
||||
PipelineVer>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
using DefaultBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap;
|
||||
@@ -257,7 +265,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
StrideC,
|
||||
GridwiseGemm::CalculateMPadded(M),
|
||||
GridwiseGemm::CalculateNPadded(N),
|
||||
GridwiseGemm::CalculateKPadded(K),
|
||||
GridwiseGemm::CalculateKPadded(K, KBatch),
|
||||
GridwiseGemm::CalculateK0(K, KBatch),
|
||||
KBatch};
|
||||
}
|
||||
@@ -290,7 +298,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
StrideC,
|
||||
GridwiseGemm::CalculateMPadded(M),
|
||||
GridwiseGemm::CalculateNPadded(N),
|
||||
GridwiseGemm::CalculateKPadded(K),
|
||||
GridwiseGemm::CalculateKPadded(K, KBatch),
|
||||
GridwiseGemm::CalculateK0(K, KBatch),
|
||||
KBatch);
|
||||
}
|
||||
|
||||
@@ -85,7 +85,7 @@ template <typename ALayout,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
ck::index_t NumPrefetch,
|
||||
ck::index_t NumGemmKPrefetchStage,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
@@ -152,6 +152,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
GemmSpec,
|
||||
NumGemmKPrefetchStage,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
@@ -179,7 +180,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
LoopSched,
|
||||
PipelineVersion::v2>;
|
||||
|
||||
using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
|
||||
using Block2ETileMapKSplit =
|
||||
@@ -265,8 +268,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH);
|
||||
const index_t k0 = GridwiseGemm::CalculateK0(K, K_BATCH);
|
||||
|
||||
const auto c_grid_desc_m_n =
|
||||
GridwiseGemm::MakeCGridDescriptor_M_N(M, N, m_padded, n_padded, stride_c);
|
||||
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_c);
|
||||
|
||||
const auto local_b2c_tile_map =
|
||||
Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
|
||||
@@ -319,8 +321,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH);
|
||||
const index_t k0 = GridwiseGemm::CalculateK0(karg.K, K_BATCH);
|
||||
|
||||
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(
|
||||
karg.M, karg.N, karg.MPadded, karg.NPadded, karg.StrideC);
|
||||
const auto c_grid_desc_m_n =
|
||||
GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
|
||||
|
||||
const auto local_b2c_tile_map =
|
||||
Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
|
||||
@@ -501,6 +503,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
|
||||
arg.skipped_group_count_) != arg.group_count_)
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
std::cout << "The group count is not equal to sum of skipped groups "
|
||||
"and kernel args size!"
|
||||
<< std::endl;
|
||||
#endif // DEBUG_LOG
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -509,14 +516,15 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
{
|
||||
const auto& a = arg.gemm_kernel_args_[i].karg_;
|
||||
bool group_arg_valid = GridwiseGemm::CheckValidity(a);
|
||||
#if DEBUG_LOG
|
||||
if(not group_arg_valid)
|
||||
{
|
||||
std::cout << "[" << __func__ << "] group id: " << i << " is not supported!\n";
|
||||
#if DEBUG_LOG
|
||||
std::cout << "[" << __func__ << "] group id: " << i
|
||||
<< " has invalid GridwiseGemm settings!" << std::endl;
|
||||
a.Print();
|
||||
}
|
||||
#endif // DEBUG_LOG
|
||||
supported &= group_arg_valid;
|
||||
}
|
||||
supported = supported && group_arg_valid;
|
||||
}
|
||||
return supported;
|
||||
}
|
||||
|
||||
@@ -8,14 +8,14 @@
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -55,6 +55,7 @@ template <index_t BlockSize,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t K0PerBlock,
|
||||
@@ -82,7 +83,9 @@ template <index_t BlockSize,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
|
||||
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>
|
||||
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1>
|
||||
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -99,8 +102,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
static constexpr auto M01 = 1;
|
||||
static constexpr auto N01 = 1;
|
||||
|
||||
static constexpr auto gemm_padder =
|
||||
tensor_operation::device::GemmPadder<GemmSpec, index_t, index_t, index_t>{
|
||||
MPerBlock, NPerBlock, K1* K0PerBlock};
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using GridwiseGemmPipe = remove_cvref_t<decltype(
|
||||
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
|
||||
|
||||
struct Argument : public ck::tensor_operation::device::BaseArgument
|
||||
{
|
||||
const FloatAB* p_a_grid;
|
||||
@@ -176,12 +186,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
// prefer this to be called on host
|
||||
__host__ __device__ static auto CalculateMPadded(index_t M)
|
||||
{
|
||||
return (M + MPerBlock - 1) / MPerBlock * MPerBlock;
|
||||
return math::integer_least_multiple(M, MPerBlock);
|
||||
}
|
||||
|
||||
__host__ __device__ static auto CalculateNPadded(index_t N)
|
||||
{
|
||||
return (N + NPerBlock - 1) / NPerBlock * NPerBlock;
|
||||
return math::integer_least_multiple(N, NPerBlock);
|
||||
}
|
||||
|
||||
__host__ __device__ static auto CalculateK0(index_t K, index_t K_Batch = 1)
|
||||
@@ -295,8 +305,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeCGridDescriptor_M_N(index_t M, index_t N, index_t MPad, index_t NPad, index_t StrideC)
|
||||
__host__ __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
|
||||
{
|
||||
const auto c_grid_desc_m_n = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
|
||||
@@ -309,22 +318,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
|
||||
{
|
||||
return transform_tensor_descriptor(c_grid_desc_m_n,
|
||||
make_tuple(make_right_pad_transform(M, MPad - M),
|
||||
make_right_pad_transform(N, NPad - N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
c_grid_desc_m_n,
|
||||
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
return gemm_padder.PadCDescriptor_M_N(c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
@@ -383,7 +377,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
|
||||
{
|
||||
if(!(karg.M % MPerBlock == 0))
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
|
||||
#endif // DEBUG_LOG
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
|
||||
@@ -391,40 +393,116 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
|
||||
{
|
||||
if(!(karg.N % NPerBlock == 0))
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
|
||||
#endif // DEBUG_LOG
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
|
||||
{
|
||||
if(karg.K % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
std::cout << "Arg K (" << karg.K
|
||||
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
|
||||
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
|
||||
#endif // DEBUG_LOG
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(karg.M % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
std::cout << "Arg M (" << karg.M
|
||||
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
|
||||
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
|
||||
#endif // DEBUG_LOG
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
|
||||
{
|
||||
if(karg.N % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
std::cout << "Arg N (" << karg.N
|
||||
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
|
||||
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
|
||||
#endif // DEBUG_LOG
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(karg.K % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
std::cout << "Arg K (" << karg.K
|
||||
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
|
||||
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
|
||||
#endif // DEBUG_LOG
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
|
||||
{
|
||||
if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
std::cout
|
||||
<< "Arg N (" << karg.N
|
||||
<< ") value is not a multiple of CBlockTransferScalarPerVector_NWaveNPerXDL ("
|
||||
<< CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
|
||||
#endif // DEBUG_LOG
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
std::cout
|
||||
<< "Arg M (" << karg.M
|
||||
<< ") value is not a multiple of CBlockTransferScalarPerVector_NWaveNPerXDL ("
|
||||
<< CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
|
||||
#endif // DEBUG_LOG
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
const auto num_k_loop = karg.K0 / K0PerBlock;
|
||||
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
std::cout << "The number of k loops (" << num_k_loop
|
||||
<< ") value is not supported by GridwiseGemm Pipeline."
|
||||
<< " K0: " << karg.K0 << ", K0PerBlock: " << K0PerBlock << " " << __FILE__
|
||||
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
#endif // DEBUG_LOG
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
@@ -439,9 +517,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
|
||||
{
|
||||
const bool has_main_k0_block_loop = K0 > K0PerBlock;
|
||||
|
||||
return has_main_k0_block_loop;
|
||||
const index_t num_loop = K0 / K0PerBlock;
|
||||
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
|
||||
}
|
||||
|
||||
template <typename CGridDesc>
|
||||
@@ -490,7 +567,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
return BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>();
|
||||
}
|
||||
|
||||
using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1))>;
|
||||
using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(1, 1, 1))>;
|
||||
using DefaultBlock2CTileMap = remove_cvref_t<decltype(MakeDefaultBlock2CTileMap())>;
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
@@ -507,8 +584,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0, karg.KPadded);
|
||||
const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1(
|
||||
karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0, karg.KPadded);
|
||||
const auto c_grid_desc_m_n =
|
||||
MakeCGridDescriptor_M_N(karg.M, karg.N, karg.MPadded, karg.NPadded, karg.StrideC);
|
||||
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
|
||||
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
|
||||
@@ -680,20 +756,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
// sanity check
|
||||
#if 1
|
||||
auto blockwise_gemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
decltype(b_k0_n_k1_block_desc),
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
K1>{};
|
||||
#else
|
||||
auto blockwise_gemm = BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
@@ -703,9 +767,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
K1>{};
|
||||
|
||||
#endif
|
||||
K1,
|
||||
LoopSched>();
|
||||
|
||||
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
@@ -761,7 +824,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
|
||||
|
||||
k0_block_data_begin += K0PerBlock;
|
||||
} while(k0_block_data_begin < (K0 - K0PerBlock));
|
||||
} while(k0_block_data_begin < (karg.K0 - K0PerBlock));
|
||||
}
|
||||
|
||||
// tail
|
||||
@@ -772,13 +835,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
}
|
||||
#else
|
||||
// gridwise GEMM pipeline
|
||||
const auto gridwise_gemm_pipeline =
|
||||
GridwiseGemmPipeline_Selector<PipelineVersion::v2, 1, LoopScheduler::Default>();
|
||||
|
||||
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
|
||||
(a_b_k0_m_k1_grid_desc.GetLength(I1) * a_b_k0_m_k1_grid_desc.GetLength(I3)) /
|
||||
(K0PerBlock * K1));
|
||||
|
||||
const auto gridwise_gemm_pipeline = GridwiseGemmPipe{};
|
||||
|
||||
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_b_k0_m_k1_grid_desc,
|
||||
a_b_k0_m_k1_block_desc,
|
||||
a_blockwise_copy,
|
||||
@@ -993,24 +1055,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Layout>
|
||||
struct LStr
|
||||
{
|
||||
static std::string Get() { return ""; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct LStr<ck::tensor_layout::gemm::RowMajor>
|
||||
{
|
||||
static std::string Get() { return "R"; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct LStr<ck::tensor_layout::gemm::ColumnMajor>
|
||||
{
|
||||
static std::string Get() { return "C"; }
|
||||
};
|
||||
|
||||
static std::string GetTypeString()
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
@@ -64,6 +64,7 @@ using device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_irregular_tile_instances = st
|
||||
//###################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//###################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
|
||||
@@ -44,14 +44,14 @@ using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_tile_instanc
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
|
||||
@@ -37,7 +37,7 @@ using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_tile_instanc
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
@@ -45,7 +45,7 @@ using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_tile_instanc
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 32, 8, 8, 32, 32, 1, 4, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
|
||||
@@ -246,9 +246,9 @@ bool profile_gemm_splitk_impl(int do_verification,
|
||||
}
|
||||
|
||||
std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA
|
||||
<< " StrideB = " << StrideB << " StrideC = " << StrideC << " : " << best_ave_time
|
||||
<< " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, "
|
||||
<< best_op_name << std::endl;
|
||||
<< " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << KBatch
|
||||
<< " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec
|
||||
<< " GB/s, " << best_op_name << std::endl;
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
@@ -43,7 +44,6 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
const std::vector<int>& StrideCs,
|
||||
int kbatch = 1)
|
||||
{
|
||||
|
||||
bool pass = true;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
@@ -81,11 +81,11 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
|
||||
c_m_n_device_results.push_back(
|
||||
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
|
||||
|
||||
#if DEBUG_LOG
|
||||
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i
|
||||
<< "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
|
||||
<< "]:" << c_m_n_device_results[i].mDesc << std::endl;
|
||||
|
||||
#endif // DEBUG_LOG
|
||||
std::size_t num_thread = 1;
|
||||
switch(init_method)
|
||||
{
|
||||
@@ -191,65 +191,71 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get()));
|
||||
|
||||
gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
|
||||
std::string gemm_name = gemm_ptr->GetTypeString();
|
||||
|
||||
if(kbatch > 1)
|
||||
{
|
||||
using DeviceOpSplitK =
|
||||
ck::tensor_operation::device::DeviceGroupedGemmSplitK<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<>,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<>,
|
||||
CDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
if(dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get()) != nullptr)
|
||||
{
|
||||
dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())
|
||||
->SetKBatchSize(argument_ptr.get(), kbatch);
|
||||
}
|
||||
}
|
||||
|
||||
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
std::string gemm_name = gemm_ptr->GetTypeString();
|
||||
|
||||
if(kbatch > 1)
|
||||
{
|
||||
using DeviceOpSplitK =
|
||||
ck::tensor_operation::device::DeviceGroupedGemmSplitK<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<>,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<>,
|
||||
CDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
if(dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get()) != nullptr)
|
||||
{
|
||||
dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())
|
||||
->SetKBatchSize(argument_ptr.get(), kbatch);
|
||||
}
|
||||
}
|
||||
|
||||
float ave_time =
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
if(time_kernel)
|
||||
{
|
||||
flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i];
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i];
|
||||
|
||||
num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + sizeof(BDataType) * Ks[i] * Ns[i] +
|
||||
sizeof(CDataType) * Ms[i] * Ns[i];
|
||||
}
|
||||
num_btype += sizeof(ADataType) * Ms[i] * Ks[i] +
|
||||
sizeof(BDataType) * Ks[i] * Ns[i] +
|
||||
sizeof(CDataType) * Ms[i] * Ns[i];
|
||||
}
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << gemm_name << std::endl;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
|
||||
<< " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
best_gemm_name = gemm_name;
|
||||
best_tflops = tflops;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
best_gemm_name = gemm_name;
|
||||
best_tflops = tflops;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
bool instance_pass = true;
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
|
||||
c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data());
|
||||
c_device_buf[i]->SetZero();
|
||||
|
||||
Tensor<CDataType> c_m_n_host_result(
|
||||
f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}));
|
||||
@@ -274,7 +280,20 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
c_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
pass = pass && ck::utils::check_err(c_m_n_device_results[i], c_m_n_host_result);
|
||||
if(std::is_same_v<CDataType, ck::half_t> && kbatch > 1)
|
||||
{
|
||||
instance_pass =
|
||||
instance_pass && ck::utils::check_err(c_m_n_device_results[i],
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
0.06);
|
||||
}
|
||||
else
|
||||
{
|
||||
instance_pass =
|
||||
instance_pass &&
|
||||
ck::utils::check_err(c_m_n_device_results[i], c_m_n_host_result);
|
||||
}
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
@@ -289,16 +308,25 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Instance: " << gemm_name << " verification "
|
||||
<< (instance_pass ? "SUCCEED" : "FAILED") << std::endl;
|
||||
|
||||
pass = pass && instance_pass;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "does not support this GEMM problem" << std::endl;
|
||||
std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
|
||||
<< best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
|
||||
if(time_kernel)
|
||||
{
|
||||
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
|
||||
<< best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
|
||||
add_test_executable(test_gemm_split_k gemm_split_k.cpp)
|
||||
target_link_libraries(test_gemm_split_k PRIVATE utility)
|
||||
target_link_libraries(test_gemm_split_k PRIVATE device_gemm_splitk_instance)
|
||||
add_gtest_executable(test_gemm_splitk test_gemm_splitk.cpp)
|
||||
target_link_libraries(test_gemm_splitk PRIVATE utility device_gemm_splitk_instance)
|
||||
endif()
|
||||
|
||||
@@ -1,261 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
|
||||
#include "ck/library/utility/host_gemm.hpp"
|
||||
|
||||
enum struct GemmMatrixLayout
|
||||
{
|
||||
MK_KN_MN, // 0
|
||||
MK_NK_MN, // 1
|
||||
KM_KN_MN, // 2
|
||||
KM_NK_MN, // 3
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
|
||||
{
|
||||
float max_diff = 1e-6;
|
||||
|
||||
for(std::size_t i = 0; i < ref.mData.size(); ++i)
|
||||
{
|
||||
float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
|
||||
if(max_diff < diff)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
struct gemmArgs
|
||||
{
|
||||
GemmMatrixLayout layout;
|
||||
int M;
|
||||
int N;
|
||||
int K;
|
||||
int StrideA;
|
||||
int StrideB;
|
||||
int StrideC;
|
||||
int KBatch;
|
||||
};
|
||||
|
||||
int test_gemm(const gemmArgs& args)
|
||||
{
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
bool a_row_major, b_row_major, c_row_major;
|
||||
|
||||
switch(args.layout)
|
||||
{
|
||||
case GemmMatrixLayout::MK_KN_MN:
|
||||
a_row_major = true;
|
||||
b_row_major = true;
|
||||
c_row_major = true;
|
||||
break;
|
||||
case GemmMatrixLayout::MK_NK_MN:
|
||||
a_row_major = true;
|
||||
b_row_major = false;
|
||||
c_row_major = true;
|
||||
break;
|
||||
case GemmMatrixLayout::KM_KN_MN:
|
||||
a_row_major = false;
|
||||
b_row_major = true;
|
||||
c_row_major = true;
|
||||
break;
|
||||
case GemmMatrixLayout::KM_NK_MN:
|
||||
a_row_major = false;
|
||||
b_row_major = false;
|
||||
c_row_major = true;
|
||||
break;
|
||||
default: printf("not supported layout"); return 1;
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, bool row_major) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(row_major)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<float> a_m_k(f_host_tensor_descriptor(args.M, args.K, args.StrideA, a_row_major));
|
||||
Tensor<float> b_k_n(f_host_tensor_descriptor(args.K, args.N, args.StrideB, b_row_major));
|
||||
Tensor<float> c_m_n_host_result(
|
||||
f_host_tensor_descriptor(args.M, args.N, args.StrideC, c_row_major));
|
||||
Tensor<float> c_m_n_device_result(
|
||||
f_host_tensor_descriptor(args.M, args.N, args.StrideC, c_row_major));
|
||||
|
||||
// init data
|
||||
std::size_t num_thread = 1;
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<float>{-5, 5}, num_thread);
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<float>{-5, 5}, num_thread);
|
||||
// set zero to c_device_buf
|
||||
c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0<float>{}, num_thread);
|
||||
|
||||
host_gemm_mk_kn_mn(a_m_k,
|
||||
b_k_n,
|
||||
c_m_n_host_result,
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
|
||||
DeviceMem a_device_buf(sizeof(float) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(float) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_device_buf(sizeof(float) * c_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_device_buf.ToDevice(b_k_n.mData.data());
|
||||
c_device_buf.ToDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
|
||||
bool success = false;
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGemmSplitK<decltype(a_layout),
|
||||
decltype(b_layout),
|
||||
decltype(c_layout),
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
|
||||
const auto gemm_ptrs =
|
||||
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
for(auto& gemm_ptr : gemm_ptrs)
|
||||
{
|
||||
auto argument_ptr =
|
||||
gemm_ptr->MakeArgumentPointer(static_cast<float*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<float*>(b_device_buf.GetDeviceBuffer()),
|
||||
static_cast<float*>(c_device_buf.GetDeviceBuffer()),
|
||||
args.M,
|
||||
args.N,
|
||||
args.K,
|
||||
args.StrideA,
|
||||
args.StrideB,
|
||||
args.StrideC,
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
args.KBatch);
|
||||
|
||||
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
|
||||
|
||||
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
invoker_ptr->Run(argument_ptr.get());
|
||||
|
||||
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
if(!check_out(c_m_n_host_result, c_m_n_device_result))
|
||||
{
|
||||
success = false;
|
||||
break;
|
||||
}
|
||||
success = true;
|
||||
}
|
||||
}
|
||||
|
||||
return success;
|
||||
};
|
||||
|
||||
bool success = false;
|
||||
|
||||
if(args.layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
success = test(Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(args.layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
success = test(Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(args.layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
success = test(Col{}, Row{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
success = test(Col{}, Col{}, Row{});
|
||||
}
|
||||
|
||||
auto error_code = 0;
|
||||
if(success)
|
||||
{
|
||||
std::cout << "test split k : Pass" << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "test split k: Fail " << std::endl;
|
||||
error_code = -1; // test needs to report failure
|
||||
}
|
||||
return error_code;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
std::vector<gemmArgs> test_cases;
|
||||
if(argc == 1)
|
||||
{
|
||||
test_cases = {{GemmMatrixLayout::MK_KN_MN, 1024, 1024, 1024, 1024, 1024, 1024, 2},
|
||||
{GemmMatrixLayout::MK_KN_MN, 1024, 1024, 1024, 1024, 1024, 1024, 8}};
|
||||
}
|
||||
else if(argc == 9)
|
||||
{
|
||||
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[1]));
|
||||
|
||||
const int M = std::stoi(argv[2]);
|
||||
const int N = std::stoi(argv[3]);
|
||||
const int K = std::stoi(argv[4]);
|
||||
|
||||
const int StrideA = std::stoi(argv[5]);
|
||||
const int StrideB = std::stoi(argv[6]);
|
||||
const int StrideC = std::stoi(argv[7]);
|
||||
const int KBatch = std::stoi(argv[8]);
|
||||
test_cases = {{layout, M, N, K, StrideA, StrideB, StrideC, KBatch}};
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
|
||||
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
|
||||
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
|
||||
printf(" 3: A[k, m] * B[n, k] = C[m, n])\n");
|
||||
printf("arg2 to 7: M, N, K, StrideA, StrideB, StrideC KBatch\n");
|
||||
return -1;
|
||||
}
|
||||
bool error = false;
|
||||
for(const auto& kinder : test_cases)
|
||||
{
|
||||
error |= test_gemm(kinder);
|
||||
}
|
||||
return error ? 1 : 0;
|
||||
}
|
||||
66
test/gemm_split_k/test_gemm_splitk.cpp
Normal file
66
test/gemm_split_k/test_gemm_splitk.cpp
Normal file
@@ -0,0 +1,66 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "test_gemm_splitk_util.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename X, typename Y>
|
||||
struct tuple_concat;
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
|
||||
{
|
||||
using type = std::tuple<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmSplitK_MK_KN
|
||||
: public ck::test::TestGemmSplitK<typename tuple_concat<std::tuple<Row, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmSplitK_MK_NK
|
||||
: public ck::test::TestGemmSplitK<typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmSplitK_KM_KN
|
||||
: public ck::test::TestGemmSplitK<typename tuple_concat<std::tuple<Col, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmSplitK_KM_NK
|
||||
: public ck::test::TestGemmSplitK<typename tuple_concat<std::tuple<Col, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
// ADataType, BDataType, CDataType
|
||||
std::tuple< F16, F16, F16>,
|
||||
std::tuple< F32, F32, F32>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmSplitK_MK_KN, KernelTypes);
|
||||
TYPED_TEST_SUITE(TestGemmSplitK_MK_NK, KernelTypes);
|
||||
TYPED_TEST_SUITE(TestGemmSplitK_KM_KN, KernelTypes);
|
||||
TYPED_TEST_SUITE(TestGemmSplitK_KM_NK, KernelTypes);
|
||||
|
||||
#include "test_gemm_splitk_ut_cases.inc"
|
||||
217
test/gemm_split_k/test_gemm_splitk_ut_cases.inc
Normal file
217
test/gemm_split_k/test_gemm_splitk_ut_cases.inc
Normal file
@@ -0,0 +1,217 @@
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_MK_KN, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{0, 1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_MK_NK, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{0, 1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_KM_KN, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{0, 1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, M, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_KM_NK, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{0, 1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, M, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_MK_KN, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_MK_NK, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_KM_KN, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, M, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_KM_NK, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, M, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_MK_KN, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 437;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_MK_NK, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 437;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_KM_KN, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 437;
|
||||
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, M, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_KM_NK, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 437;
|
||||
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, M, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_MK_KN, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_MK_NK, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_KM_KN, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, M, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_KM_NK, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, M, StrideB, StrideC);
|
||||
}
|
||||
78
test/gemm_split_k/test_gemm_splitk_util.hpp
Normal file
78
test/gemm_split_k/test_gemm_splitk_util.hpp
Normal file
@@ -0,0 +1,78 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "include/ck/utility/data_type.hpp"
|
||||
#include "profiler/profile_gemm_splitk_impl.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace test {
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmSplitK : public testing::Test
|
||||
{
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using F32 = float;
|
||||
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = Row;
|
||||
using ADataType = std::tuple_element_t<2, Tuple>;
|
||||
using BDataType = std::tuple_element_t<3, Tuple>;
|
||||
using CDataType = std::tuple_element_t<4, Tuple>;
|
||||
|
||||
public:
|
||||
static constexpr bool verify_ = true;
|
||||
static constexpr int init_method_ = 1; // decimal value initialization
|
||||
static constexpr bool log_ = false;
|
||||
static constexpr bool bench_ = false; // measure kernel performance
|
||||
std::vector<int> k_batches_;
|
||||
|
||||
void SetUp() override { k_batches_ = {1, 2, 3, 5, 8}; }
|
||||
|
||||
void Run(const int M,
|
||||
const int N,
|
||||
const int K,
|
||||
const int StrideA,
|
||||
const int StrideB,
|
||||
const int StrideC)
|
||||
{
|
||||
for(auto kb : k_batches_)
|
||||
{
|
||||
RunSingle(M, N, K, StrideA, StrideB, StrideC, kb);
|
||||
}
|
||||
}
|
||||
|
||||
void RunSingle(const int M,
|
||||
const int N,
|
||||
const int K,
|
||||
const int StrideA,
|
||||
const int StrideB,
|
||||
const int StrideC,
|
||||
int kbatch = 1)
|
||||
{
|
||||
bool pass = ck::profiler::profile_gemm_splitk_impl<ADataType,
|
||||
BDataType,
|
||||
F32,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(
|
||||
verify_, init_method_, log_, bench_, M, N, K, StrideA, StrideB, StrideC, kbatch);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace test
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,9 @@
|
||||
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
|
||||
add_test_executable(test_grouped_gemm_fp16 grouped_gemm_fp16.cpp)
|
||||
target_link_libraries(test_grouped_gemm_fp16 PRIVATE utility)
|
||||
target_link_libraries(test_grouped_gemm_fp16 PRIVATE device_grouped_gemm_instance)
|
||||
add_custom_target(test_grouped_gemm)
|
||||
add_gtest_executable(test_grouped_gemm_splitk test_grouped_gemm_splitk.cpp)
|
||||
add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface.cpp)
|
||||
target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance)
|
||||
target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance)
|
||||
|
||||
add_dependencies(test_grouped_gemm test_grouped_gemm_splitk test_grouped_gemm_interface)
|
||||
endif()
|
||||
|
||||
@@ -1,69 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
|
||||
#include "profiler/profile_grouped_gemm_impl.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using ADataType = ck::half_t;
|
||||
using BDataType = ck::half_t;
|
||||
using CDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
bool TestGroupedGemm()
|
||||
{
|
||||
|
||||
std::mt19937 gen(19391);
|
||||
std::uniform_int_distribution<> distrib(1, 10);
|
||||
int group_count = distrib(gen);
|
||||
|
||||
// GEMM shape
|
||||
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
|
||||
std::vector<const void*> p_a, p_b;
|
||||
std::vector<void*> p_c;
|
||||
|
||||
std::vector<int> Ms, Ns, Ks, StrideAs, StrideBs, StrideCs;
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
Ms.push_back(256 + 256 * distrib(gen));
|
||||
Ns.push_back(256 + 256 * distrib(gen));
|
||||
Ks.push_back(128 + 128 * distrib(gen));
|
||||
|
||||
StrideAs.push_back(std::is_same<Row, ALayout>::value ? Ks[i] : Ms[i]);
|
||||
StrideBs.push_back(std::is_same<Row, BLayout>::value ? Ns[i] : Ks[i]);
|
||||
StrideCs.push_back(std::is_same<Row, CLayout>::value ? Ns[i] : Ms[i]);
|
||||
}
|
||||
|
||||
return ck::profiler::profile_grouped_gemm_impl<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(
|
||||
true, 1, false, 1, Ms, Ns, Ks, StrideAs, StrideBs, StrideCs);
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
int main()
|
||||
{
|
||||
bool res = true;
|
||||
|
||||
res = res && TestGroupedGemm<Row, Row, Row>();
|
||||
res = res && TestGroupedGemm<Row, Col, Row>();
|
||||
res = res && TestGroupedGemm<Col, Row, Row>();
|
||||
res = res && TestGroupedGemm<Col, Col, Row>();
|
||||
|
||||
std::cout << "TestGroupedGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
|
||||
|
||||
return res ? 0 : 1;
|
||||
}
|
||||
202
test/grouped_gemm/test_grouped_gemm_interface.cpp
Normal file
202
test/grouped_gemm/test_grouped_gemm_interface.cpp
Normal file
@@ -0,0 +1,202 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "test_grouped_gemm_util.hpp"
|
||||
|
||||
class TestGGemmSplitKInterface_MKNKMN : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using ELayout = Row;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
template <ck::tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock>
|
||||
using GGemmInstance =
|
||||
ck::test::DeviceGroupedGemmSplitkInstanceWrapper<ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
GemmSpec,
|
||||
KPerBlock,
|
||||
K1,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock>;
|
||||
|
||||
using DefaultGGemmInstance = GGemmInstance<GemmDefault, 32, 8, 4, 8, 8>;
|
||||
};
|
||||
|
||||
TEST_F(TestGGemmSplitKInterface_MKNKMN, TileSize)
|
||||
{
|
||||
std::vector<int> Ms{128, 256, 188, 512};
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 128;
|
||||
|
||||
std::vector<int> Ns(Ms.size(), N);
|
||||
std::vector<int> Ks(Ms.size(), K);
|
||||
std::vector<int> StrideAs(Ms.size(), K);
|
||||
std::vector<int> StrideBs(Ms.size(), K);
|
||||
std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
// M % MPerBlock
|
||||
EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
|
||||
|
||||
Ms = std::vector<int>{256, 128, 128, 512};
|
||||
Ns = std::vector<int>{256, 177, 128, 512};
|
||||
// N % NPerBlock
|
||||
EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
|
||||
}
|
||||
|
||||
TEST_F(TestGGemmSplitKInterface_MKNKMN, VectorLoadWidth)
|
||||
{
|
||||
static constexpr auto GemmMNKPadding =
|
||||
ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
using PaddedGGemmInstance = GGemmInstance<GemmMNKPadding, 32, 8, 4, 8, 8>;
|
||||
|
||||
std::vector<int> Ms{128, 256, 256, 512};
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 512;
|
||||
|
||||
std::vector<int> Ns(Ms.size(), N);
|
||||
std::vector<int> Ks(Ms.size(), K);
|
||||
std::vector<int> StrideAs(Ms.size(), K);
|
||||
std::vector<int> StrideBs(Ms.size(), K);
|
||||
std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
// K % ABlockTransferSrcScalarPerVector
|
||||
Ks = std::vector<int>{256, 177, 128, 512};
|
||||
EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
|
||||
|
||||
Ks = std::vector<int>{256, 164, 128, 512};
|
||||
// K % BBlockTransferSrcScalarPerVector
|
||||
EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
|
||||
|
||||
Ks = std::vector<int>(4, 128);
|
||||
Ns = std::vector<int>{256, 127, 128, 512};
|
||||
// N % CBlockTransferScalarPerVector_NWaveNPerXDL
|
||||
EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
|
||||
}
|
||||
|
||||
TEST_F(TestGGemmSplitKInterface_MKNKMN, KLoops)
|
||||
{
|
||||
std::vector<int> Ms{128, 256, 256, 512};
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 128;
|
||||
constexpr int kbatch = 4;
|
||||
|
||||
std::vector<int> Ns(Ms.size(), N);
|
||||
std::vector<int> Ks(Ms.size(), K);
|
||||
std::vector<int> StrideAs(Ms.size(), K);
|
||||
std::vector<int> StrideBs(Ms.size(), K);
|
||||
std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
// kloops % 2
|
||||
Ks = std::vector<int>{256, 512, 320, 768};
|
||||
EXPECT_FALSE(
|
||||
DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch));
|
||||
|
||||
// Not all gemms have same value for main_k0_block_loop!
|
||||
Ks = std::vector<int>{256, 512, 512, 512};
|
||||
EXPECT_THROW(DefaultGGemmInstance{}.Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch),
|
||||
std::runtime_error);
|
||||
}
|
||||
|
||||
class TestGGemmSplitKInterface_KMKNNM : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using ALayout = Col;
|
||||
using BLayout = Row;
|
||||
using ELayout = Col;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
template <ck::tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock>
|
||||
using GGemmInstance =
|
||||
ck::test::DeviceGroupedGemmSplitkInstanceWrapper<ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
GemmSpec,
|
||||
KPerBlock,
|
||||
K1,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock>;
|
||||
|
||||
using DefaultGGemmInstance = GGemmInstance<GemmDefault, 32, 8, 4, 8, 4>;
|
||||
};
|
||||
|
||||
TEST_F(TestGGemmSplitKInterface_KMKNNM, TileSize)
|
||||
{
|
||||
std::vector<int> Ms{128, 256, 188, 512};
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 128;
|
||||
|
||||
std::vector<int> Ns(Ms.size(), N);
|
||||
std::vector<int> Ks(Ms.size(), K);
|
||||
std::vector<int> StrideAs(Ms.size(), K);
|
||||
std::vector<int> StrideBs(Ms.size(), K);
|
||||
std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
// M % MPerBlock
|
||||
EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
|
||||
|
||||
Ms = std::vector<int>{128, 256, 256, 512};
|
||||
Ns = std::vector<int>{256, 177, 128, 512};
|
||||
// N % NPerBlock
|
||||
EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
|
||||
}
|
||||
|
||||
TEST_F(TestGGemmSplitKInterface_KMKNNM, VectorLoadWidth)
|
||||
{
|
||||
static constexpr auto GemmMNKPadding =
|
||||
ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
using PaddedGGemmInstance = GGemmInstance<GemmMNKPadding, 32, 8, 2, 8, 4>;
|
||||
|
||||
std::vector<int> Ms{128, 256, 256, 512};
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 512;
|
||||
|
||||
std::vector<int> Ns(Ms.size(), N);
|
||||
std::vector<int> Ks(Ms.size(), K);
|
||||
std::vector<int> StrideAs(Ms.size(), K);
|
||||
std::vector<int> StrideBs(Ms.size(), K);
|
||||
std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
// M % ABlockTransferSrcScalarPerVector
|
||||
Ms = std::vector<int>{256, 177, 128, 512};
|
||||
EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
|
||||
|
||||
Ms = std::vector<int>{128, 256, 256, 512};
|
||||
Ns = std::vector<int>{256, 164, 128, 512};
|
||||
// N % BBlockTransferSrcScalarPerVector
|
||||
EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
|
||||
|
||||
Ns = std::vector<int>{128, 256, 256, 512};
|
||||
Ms = std::vector<int>{256, 130, 128, 512};
|
||||
// M % CBlockTransferScalarPerVector_NWaveNPerXDL
|
||||
EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
|
||||
}
|
||||
34
test/grouped_gemm/test_grouped_gemm_splitk.cpp
Normal file
34
test/grouped_gemm/test_grouped_gemm_splitk.cpp
Normal file
@@ -0,0 +1,34 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test_grouped_gemm_util.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using RRR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Row, Row, F16, F16, F16>>;
|
||||
using RCR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Col, Row, F16, F16, F16>>;
|
||||
|
||||
using RRR_F16_F16_F16_LargeK = ck::test::TestGroupedGemm<std::tuple<Row, Row, Row, F16, F16, F16>>;
|
||||
using RCR_F16_F16_F16_LargeK = ck::test::TestGroupedGemm<std::tuple<Row, Col, Row, F16, F16, F16>>;
|
||||
|
||||
const std::vector<int> KBATCH{1, 2, 3, 5, 8};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_KN, RRR_F16_F16_F16, testing::ValuesIn(KBATCH));
|
||||
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_NK, RCR_F16_F16_F16, testing::ValuesIn(KBATCH));
|
||||
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_LargeK_MK_KN,
|
||||
RRR_F16_F16_F16_LargeK,
|
||||
testing::Values(32, 64));
|
||||
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_LargeK_MK_NK,
|
||||
RCR_F16_F16_F16_LargeK,
|
||||
testing::Values(32, 64));
|
||||
|
||||
#include "test_grouped_gemm_ut_cases.inc"
|
||||
180
test/grouped_gemm/test_grouped_gemm_ut_cases.inc
Normal file
180
test/grouped_gemm/test_grouped_gemm_ut_cases.inc
Normal file
@@ -0,0 +1,180 @@
|
||||
#pragma once
|
||||
|
||||
TEST_P(RRR_F16_F16_F16, TinyCases)
|
||||
{
|
||||
const std::vector<int> Ms{0, 1};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 544;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
const std::vector<int> StrideBs(Ms.size(), N);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
|
||||
TEST_P(RRR_F16_F16_F16, SmallCases)
|
||||
{
|
||||
const std::vector<int> Ms{2, 1, 3, 4, 5, 0};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 544;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
const std::vector<int> StrideBs(Ms.size(), N);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
|
||||
TEST_P(RRR_F16_F16_F16, MidCases)
|
||||
{
|
||||
const std::vector<int> Ms{167, 183, 177, 153, 139, 204};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 544;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
const std::vector<int> StrideBs(Ms.size(), N);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
|
||||
TEST_P(RRR_F16_F16_F16, Regular)
|
||||
{
|
||||
const std::vector<int> Ms{64, 128, 256};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 320;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
const std::vector<int> StrideBs(Ms.size(), N);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
|
||||
TEST_P(RRR_F16_F16_F16, MNKPadded)
|
||||
{
|
||||
const std::vector<int> Ms{127, 150, 188, 210};
|
||||
constexpr int N = 136;
|
||||
constexpr int K = 280;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
const std::vector<int> StrideBs(Ms.size(), N);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
|
||||
TEST_P(RCR_F16_F16_F16, TinyCases)
|
||||
{
|
||||
const std::vector<int> Ms{0, 1};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 544;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
const std::vector<int> StrideBs(Ms.size(), K);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
|
||||
TEST_P(RCR_F16_F16_F16, SmallCases)
|
||||
{
|
||||
const std::vector<int> Ms{2, 1, 3, 4, 5, 0};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 544;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
const std::vector<int> StrideBs(Ms.size(), K);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
|
||||
TEST_P(RCR_F16_F16_F16, MidCases)
|
||||
{
|
||||
const std::vector<int> Ms{167, 183, 177, 153, 139, 204};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 544;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
const std::vector<int> StrideBs(Ms.size(), K);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
|
||||
TEST_P(RCR_F16_F16_F16, Regular)
|
||||
{
|
||||
const std::vector<int> Ms{32, 64, 128, 256};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 320;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
const std::vector<int> StrideBs(Ms.size(), K);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
|
||||
TEST_P(RCR_F16_F16_F16, MNKPadded)
|
||||
{
|
||||
const std::vector<int> Ms{127, 150, 188, 210};
|
||||
constexpr int N = 136;
|
||||
constexpr int K = 280;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
const std::vector<int> StrideBs(Ms.size(), K);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
|
||||
TEST_P(RRR_F16_F16_F16_LargeK, TestLargeKBatch)
|
||||
{
|
||||
const std::vector<int> Ms{188, 210};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 4096;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
const std::vector<int> StrideBs(Ms.size(), N);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
|
||||
TEST_P(RCR_F16_F16_F16_LargeK, TestLargeKBatch)
|
||||
{
|
||||
const std::vector<int> Ms{188, 210};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 4096;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
const std::vector<int> StrideBs(Ms.size(), K);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
249
test/grouped_gemm/test_grouped_gemm_util.hpp
Normal file
249
test/grouped_gemm/test_grouped_gemm_util.hpp
Normal file
@@ -0,0 +1,249 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/stream_config.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/utility/number.hpp"
|
||||
#include "profiler/profile_grouped_gemm_impl.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace test {
|
||||
|
||||
template <typename Range>
|
||||
std::string serialize_range(const Range& range)
|
||||
{
|
||||
std::stringstream ss;
|
||||
for(auto& r : range)
|
||||
{
|
||||
ss << r << ", ";
|
||||
}
|
||||
std::string str = ss.str();
|
||||
return std::string(str.begin(), str.end() - 2);
|
||||
}
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedGemm : public testing::TestWithParam<int>
|
||||
{
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using ELayout = std::tuple_element_t<2, Tuple>;
|
||||
using ADataType = std::tuple_element_t<3, Tuple>;
|
||||
using BDataType = std::tuple_element_t<4, Tuple>;
|
||||
using EDataType = std::tuple_element_t<5, Tuple>;
|
||||
|
||||
public:
|
||||
static constexpr bool verify_ = true;
|
||||
static constexpr int init_method_ = 1; // decimal value initialization
|
||||
static constexpr bool log_ = false;
|
||||
static constexpr bool bench_ = false; // measure kernel performance
|
||||
|
||||
void SetUp() override {}
|
||||
|
||||
void Run(const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
const std::vector<int>& Ks,
|
||||
const std::vector<int>& StrideAs,
|
||||
const std::vector<int>& StrideBs,
|
||||
const std::vector<int>& StrideCs,
|
||||
int kbatch = 1)
|
||||
{
|
||||
bool pass = ck::profiler::profile_grouped_gemm_impl<ADataType,
|
||||
BDataType,
|
||||
EDataType,
|
||||
float,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout>(
|
||||
verify_, init_method_, log_, bench_, Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock>
|
||||
struct DeviceGroupedGemmSplitkInstanceWrapper
|
||||
{
|
||||
using F16 = half_t;
|
||||
using F32 = float;
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using PassThrough = tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using EmptyTuple = ck::Tuple<>;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
template <ck::index_t N>
|
||||
using I = ck::Number<N>;
|
||||
|
||||
using ABlockTransferThreadClusterArrageOrder =
|
||||
std::conditional_t<std::is_same_v<ALayout, Row>, S<0, 2, 1, 3>, S<0, 1, 3, 2>>;
|
||||
using ABlockTransferSrcAccessOrder =
|
||||
std::conditional_t<std::is_same_v<ALayout, Row>, S<0, 2, 1, 3>, S<0, 1, 3, 2>>;
|
||||
using ABlockTransferSrcVectorDim = std::conditional_t<std::is_same_v<ALayout, Row>, I<3>, I<2>>;
|
||||
using ABlockTransferDstScalarPerVector_K1 =
|
||||
std::conditional_t<std::is_same_v<ALayout, Row>, I<8>, I<2>>;
|
||||
using ABlockLdsAddExtraM = std::conditional_t<std::is_same_v<ALayout, Row>, I<1>, I<0>>;
|
||||
|
||||
using BBlockTransferThreadClusterArrageOrder =
|
||||
std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 1, 3, 2>, S<0, 2, 1, 3>>;
|
||||
using BBlockTransferSrcAccessOrder =
|
||||
std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 1, 3, 2>, S<0, 2, 1, 3>>;
|
||||
using BBlockTransferSrcVectorDim = std::conditional_t<std::is_same_v<BLayout, Row>, I<2>, I<3>>;
|
||||
using BBlockTransferDstScalarPerVector_K1 =
|
||||
std::conditional_t<std::is_same_v<ALayout, Row>, I<2>, I<8>>;
|
||||
using BBlockLdsAddExtraM = std::conditional_t<std::is_same_v<ALayout, Row>, I<0>, I<1>>;
|
||||
|
||||
using DeviceGroupedGemmSplitKInstance =
|
||||
tensor_operation::device::DeviceGroupedGemmXdlSplitKCShuffle<
|
||||
ALayout,
|
||||
BLayout,
|
||||
EmptyTuple,
|
||||
ELayout,
|
||||
F16,
|
||||
F16,
|
||||
F32,
|
||||
F16,
|
||||
EmptyTuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
GemmSpec,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
128,
|
||||
KPerBlock,
|
||||
K1,
|
||||
K1,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<1, 4, 32, 1>,
|
||||
ABlockTransferThreadClusterArrageOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim::value,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1::value,
|
||||
ABlockLdsAddExtraM::value,
|
||||
S<1, 4, 32, 1>,
|
||||
BBlockTransferThreadClusterArrageOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim::value,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1::value,
|
||||
BBlockLdsAddExtraM::value,
|
||||
1,
|
||||
1,
|
||||
S<1, 16, 1, 8>,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock>;
|
||||
|
||||
bool IsSupported(const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
const std::vector<int>& Ks,
|
||||
const std::vector<int>& StrideAs,
|
||||
const std::vector<int>& StrideBs,
|
||||
const std::vector<int>& StrideCs,
|
||||
int kbatch = 1) const
|
||||
{
|
||||
std::size_t n_groups = Ms.size();
|
||||
EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups &&
|
||||
StrideBs.size() == n_groups && StrideCs.size() == n_groups)
|
||||
<< "The number of groups is not consistent!";
|
||||
|
||||
std::vector<tensor_operation::device::GemmDesc> gemm_descs;
|
||||
|
||||
for(std::size_t i = 0; i < n_groups; ++i)
|
||||
{
|
||||
gemm_descs.push_back(tensor_operation::device::GemmDesc{
|
||||
Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
|
||||
}
|
||||
|
||||
std::vector<const void*> p_As(n_groups, nullptr);
|
||||
std::vector<const void*> p_Bs(n_groups, nullptr);
|
||||
std::vector<void*> p_Cs(n_groups, nullptr);
|
||||
auto p_Ds = std::vector<std::array<const void*, 0>>{};
|
||||
|
||||
auto ggemm_instance = DeviceGroupedGemmSplitKInstance{};
|
||||
auto argument = ggemm_instance.MakeArgument(
|
||||
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
if(kbatch > 1)
|
||||
{
|
||||
ggemm_instance.SetKBatchSize(argument, kbatch);
|
||||
}
|
||||
|
||||
return ggemm_instance.IsSupportedArgument(argument);
|
||||
}
|
||||
|
||||
float Run(const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
const std::vector<int>& Ks,
|
||||
const std::vector<int>& StrideAs,
|
||||
const std::vector<int>& StrideBs,
|
||||
const std::vector<int>& StrideCs,
|
||||
int kbatch = 1) const
|
||||
{
|
||||
std::size_t n_groups = Ms.size();
|
||||
EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups &&
|
||||
StrideBs.size() == n_groups && StrideCs.size() == n_groups)
|
||||
<< "The number of groups is not consistent!";
|
||||
|
||||
std::vector<tensor_operation::device::GemmDesc> gemm_descs;
|
||||
|
||||
for(std::size_t i = 0; i < n_groups; ++i)
|
||||
{
|
||||
gemm_descs.push_back(tensor_operation::device::GemmDesc{
|
||||
Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
|
||||
}
|
||||
|
||||
std::vector<const void*> p_As(n_groups, nullptr);
|
||||
std::vector<const void*> p_Bs(n_groups, nullptr);
|
||||
std::vector<void*> p_Cs(n_groups, nullptr);
|
||||
auto p_Ds = std::vector<std::array<const void*, 0>>{};
|
||||
|
||||
auto ggemm_instance = DeviceGroupedGemmSplitKInstance{};
|
||||
auto argument = ggemm_instance.MakeArgument(
|
||||
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
if(kbatch > 1)
|
||||
{
|
||||
ggemm_instance.SetKBatchSize(argument, kbatch);
|
||||
}
|
||||
|
||||
EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument));
|
||||
auto invoker = ggemm_instance.MakeInvoker();
|
||||
DeviceMem gemm_desc_workspace(ggemm_instance.GetWorkSpaceSize(&argument));
|
||||
ggemm_instance.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());
|
||||
return invoker.Run(argument, StreamConfig{nullptr, false});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace test
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user