mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
Grouped Gemm + SplitK + simplified Kernel Args (#669)
* simplify karg in device/grid split-k op * fix mk_kn_mn instances * add more instances * B2C with 3D grid for KSplit * Remove unused code. * Use default B2C (3D grid) in grid gemm v2r4r2. * Device gemm splitk use B2C map. * Device GroupedGemmXdlSplitKCShuffle * Example for GroupedGemm Xdl SplitK * Introduce Device GroupedGemmSplitK * Fix updating kbatch size. * Add instance mk-nk-mn * Enable set kbatch in profiler. * Add GGemmSplitK mk-kn-mn instances * Add more instances & split into multiple files. * minor fix * tuning * clean * disabled failed instances * use pipe v2 * Ignore arg on not supported arch. * fix warning --------- Co-authored-by: carlushuang <carlus.huang@amd.com> Co-authored-by: Adam Osewski <aosewski@amd.com> Co-authored-by: zjing14 <zhangjing14@gmail.com> Co-authored-by: Jing Zhang <jizhan@amd.com> Co-authored-by: root <root@ctr-ubbsmc15.amd.com>
This commit is contained in:
@@ -5,6 +5,7 @@ add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp)
|
||||
add_example_executable(example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp)
|
||||
add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp)
|
||||
add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp)
|
||||
add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp)
|
||||
|
||||
|
||||
add_dependencies(example_grouped_gemm_xdl
|
||||
@@ -12,7 +13,8 @@ add_dependencies(example_grouped_gemm_xdl
|
||||
example_grouped_gemm_xdl_fp16
|
||||
example_grouped_gemm_xdl_bfp16
|
||||
example_grouped_gemm_xdl_int8
|
||||
example_grouped_gemm_multiple_d_dl_fp16)
|
||||
example_grouped_gemm_multiple_d_dl_fp16
|
||||
example_grouped_gemm_xdl_splitk_fp16)
|
||||
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)
|
||||
|
||||
97
example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp
Normal file
97
example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp
Normal file
@@ -0,0 +1,97 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F16;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using EDataType = F16;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdlSplitKCShuffle
|
||||
// clang-format off
|
||||
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_grouped_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ProblemSize problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
problem_size.group_count = 16;
|
||||
|
||||
problem_size.Ms = {
|
||||
167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
|
||||
|
||||
for(int i = 0; i < problem_size.group_count; i++)
|
||||
{
|
||||
problem_size.Ns.push_back(768);
|
||||
problem_size.Ks.push_back(4608);
|
||||
|
||||
problem_size.stride_As.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
|
||||
}
|
||||
|
||||
if(argc == 4)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
return !run_grouped_gemm(problem_size, config);
|
||||
}
|
||||
@@ -147,6 +147,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
#else
|
||||
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
|
||||
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
|
||||
c_tensors_device[i]->SetZero();
|
||||
#endif
|
||||
|
||||
p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
|
||||
|
||||
@@ -31,7 +31,7 @@ struct DeviceGroupedGemm : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsisiten NumDTensor");
|
||||
static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsistent NumDTensor");
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(std::vector<const void*>& p_a,
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
#pragma once
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "device_grouped_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceGroupedGemmSplitK : public DeviceGroupedGemm<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -114,7 +114,8 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXDL,
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
using DefaultBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap;
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
@@ -138,8 +139,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
"setting");
|
||||
}
|
||||
|
||||
const auto b2c_map = DefaultBlock2CTileMap{};
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(karg);
|
||||
std::tie(gdx, gdy, gdz) = b2c_map.CalculateGridSize(karg.M, karg.N, karg.k_batch);
|
||||
const auto K0 = karg.K0;
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
@@ -152,7 +154,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg, b2c_map);
|
||||
};
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
@@ -162,7 +164,8 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
const auto kernel =
|
||||
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set>;
|
||||
InMemoryDataOperationEnum::Set,
|
||||
DefaultBlock2CTileMap>;
|
||||
|
||||
Run(kernel);
|
||||
}
|
||||
@@ -171,7 +174,8 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
const auto kernel =
|
||||
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd>;
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
DefaultBlock2CTileMap>;
|
||||
|
||||
Run(kernel);
|
||||
}
|
||||
@@ -183,7 +187,8 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
const auto kernel =
|
||||
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set>;
|
||||
InMemoryDataOperationEnum::Set,
|
||||
DefaultBlock2CTileMap>;
|
||||
|
||||
Run(kernel);
|
||||
}
|
||||
@@ -192,7 +197,8 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
const auto kernel =
|
||||
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd>;
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
DefaultBlock2CTileMap>;
|
||||
|
||||
Run(kernel);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,612 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename GemmDesc,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
|
||||
const index_t group_count)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
|
||||
__shared__ uint8_t p_shared[shared_size];
|
||||
|
||||
const index_t block_id = get_block_1d_id();
|
||||
const auto gemm_desc_ptr =
|
||||
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
|
||||
|
||||
index_t left = 0;
|
||||
index_t right = group_count;
|
||||
index_t group_id = index_t((left + right) / 2);
|
||||
while((!(block_id >= gemm_desc_ptr[group_id].block_start_ &&
|
||||
block_id < gemm_desc_ptr[group_id].block_end_)) &&
|
||||
left <= right)
|
||||
{
|
||||
if(block_id < gemm_desc_ptr[group_id].block_start_)
|
||||
{
|
||||
right = group_id;
|
||||
}
|
||||
else
|
||||
{
|
||||
left = group_id;
|
||||
}
|
||||
group_id = index_t((left + right) / 2);
|
||||
}
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
|
||||
gemm_desc_ptr[group_id].karg_,
|
||||
static_cast<void*>(p_shared),
|
||||
gemm_desc_ptr[group_id].block_2_ctile_map_);
|
||||
#else
|
||||
ignore = gemm_descs_const;
|
||||
ignore = group_count;
|
||||
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
|
||||
}
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
ck::index_t NumPrefetch,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t AK1,
|
||||
ck::index_t BK1,
|
||||
ck::index_t MPerXDL,
|
||||
ck::index_t NPerXDL,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
// Current implementation does not support multiple D fusions.
|
||||
enable_if_t<AK1 == BK1 && is_same_v<DsLayout, ck::Tuple<>> &&
|
||||
is_same_v<DsDataType, ck::Tuple<>>,
|
||||
bool> = false>
|
||||
struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static_assert(KPerBlock % AK1 == 0);
|
||||
static constexpr index_t K0PerBlock = KPerBlock / AK1;
|
||||
|
||||
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
|
||||
BlockSize,
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
AccDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
GemmSpec,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
AK1,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
|
||||
|
||||
using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
|
||||
using Block2ETileMapKSplit =
|
||||
BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
|
||||
// Block2CTileMap configuration parameter.
|
||||
static constexpr index_t B2E_M01 = 8;
|
||||
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
|
||||
using KernelArgument = typename GridwiseGemm::Argument;
|
||||
|
||||
struct GemmTransKernelArg
|
||||
{
|
||||
KernelArgument karg_;
|
||||
GroupedGemmBlock2ETileMap block_2_ctile_map_;
|
||||
index_t block_start_, block_end_;
|
||||
|
||||
GemmTransKernelArg() = default;
|
||||
GemmTransKernelArg(KernelArgument&& karg,
|
||||
GroupedGemmBlock2ETileMap&& b2c_map,
|
||||
index_t block_start,
|
||||
index_t block_end)
|
||||
: karg_{karg},
|
||||
block_2_ctile_map_{b2c_map},
|
||||
block_start_{block_start},
|
||||
block_end_{block_end}
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr index_t DefaultKBatch = 1;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
|
||||
Argument(std::vector<const void*>& p_As,
|
||||
std::vector<const void*>& p_Bs,
|
||||
std::vector<void*>& p_Es,
|
||||
std::vector<GemmDesc>& gemm_descs)
|
||||
: Argument(p_As, p_Bs, p_Es, gemm_descs, DefaultKBatch)
|
||||
{
|
||||
// TODO: use occupancy api to calculate appropriate batch size.
|
||||
}
|
||||
|
||||
Argument(std::vector<const void*>& p_As,
|
||||
std::vector<const void*>& p_Bs,
|
||||
std::vector<void*>& p_Es,
|
||||
std::vector<GemmDesc>& gemm_descs,
|
||||
index_t kbatch)
|
||||
: K_BATCH{kbatch}
|
||||
{
|
||||
grid_size_ = 0;
|
||||
group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
|
||||
|
||||
if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
|
||||
group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) &&
|
||||
group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
|
||||
{
|
||||
throw std::runtime_error("wrong! group_count_ != p_As/b/c.size");
|
||||
}
|
||||
|
||||
gemm_kernel_args_.reserve(group_count_);
|
||||
|
||||
skipped_group_count_ = 0;
|
||||
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); ++i)
|
||||
{
|
||||
const index_t M = gemm_descs[i].M_;
|
||||
const index_t N = gemm_descs[i].N_;
|
||||
const index_t K = gemm_descs[i].K_;
|
||||
|
||||
if(M == 0)
|
||||
{
|
||||
skipped_group_count_++;
|
||||
continue;
|
||||
}
|
||||
|
||||
const index_t stride_a = gemm_descs[i].stride_A_;
|
||||
const index_t stride_b = gemm_descs[i].stride_B_;
|
||||
const index_t stride_c = gemm_descs[i].stride_C_;
|
||||
|
||||
const index_t m_padded = GridwiseGemm::CalculateMPadded(M);
|
||||
const index_t n_padded = GridwiseGemm::CalculateNPadded(N);
|
||||
const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH);
|
||||
const index_t k0 = GridwiseGemm::CalculateK0(K, K_BATCH);
|
||||
|
||||
const auto c_grid_desc_m_n =
|
||||
GridwiseGemm::MakeCGridDescriptor_M_N(M, N, m_padded, n_padded, stride_c);
|
||||
|
||||
const auto local_b2c_tile_map =
|
||||
Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
|
||||
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
|
||||
|
||||
const index_t block_start = grid_size_;
|
||||
const index_t block_end = grid_size_ + grid_size_grp;
|
||||
|
||||
grid_size_ += grid_size_grp;
|
||||
|
||||
// block-to-e-tile map
|
||||
auto grouped_block_2_ctile_map =
|
||||
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
|
||||
|
||||
auto karg = KernelArgument{type_convert<const ADataType*>(p_As[i]),
|
||||
type_convert<const BDataType*>(p_Bs[i]),
|
||||
type_convert<EDataType*>(p_Es[i]),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c,
|
||||
m_padded,
|
||||
n_padded,
|
||||
k_padded,
|
||||
k0,
|
||||
K_BATCH};
|
||||
|
||||
gemm_kernel_args_.emplace_back(
|
||||
std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Recalculate group grid size for all gemms and update B2C maps.
|
||||
*
|
||||
* @param[in] kbatch The new splitK parameter value.
|
||||
*/
|
||||
void UpdateKBatch(index_t kbatch)
|
||||
{
|
||||
K_BATCH = kbatch;
|
||||
grid_size_ = 0;
|
||||
|
||||
for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i)
|
||||
{
|
||||
|
||||
auto& karg = gemm_kernel_args_[i].karg_;
|
||||
|
||||
const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH);
|
||||
const index_t k0 = GridwiseGemm::CalculateK0(karg.K, K_BATCH);
|
||||
|
||||
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(
|
||||
karg.M, karg.N, karg.MPadded, karg.NPadded, karg.StrideC);
|
||||
|
||||
const auto local_b2c_tile_map =
|
||||
Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
|
||||
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
|
||||
|
||||
const index_t block_start = grid_size_;
|
||||
const index_t block_end = grid_size_ + grid_size_grp;
|
||||
|
||||
grid_size_ += grid_size_grp;
|
||||
|
||||
// block-to-e-tile map
|
||||
auto grouped_block_2_ctile_map =
|
||||
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
|
||||
|
||||
karg.KPadded = k_padded;
|
||||
karg.K0 = k0;
|
||||
karg.k_batch = K_BATCH;
|
||||
gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
|
||||
gemm_kernel_args_[i].block_start_ = block_start;
|
||||
gemm_kernel_args_[i].block_end_ = block_end;
|
||||
}
|
||||
}
|
||||
|
||||
// private:
|
||||
index_t K_BATCH;
|
||||
index_t group_count_;
|
||||
index_t skipped_group_count_;
|
||||
|
||||
std::vector<GemmTransKernelArg> gemm_kernel_args_;
|
||||
index_t grid_size_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
index_t K0 = arg.gemm_kernel_args_[0].karg_.K0;
|
||||
bool all_have_kbatch_gt_one = arg.gemm_kernel_args_[0].karg_.k_batch > 1;
|
||||
bool all_have_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
|
||||
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
|
||||
{
|
||||
const auto& karg = arg.gemm_kernel_args_[i].karg_;
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
karg.Print();
|
||||
}
|
||||
|
||||
auto kbatch = karg.k_batch;
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(karg))
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__
|
||||
<< ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
|
||||
K0 = karg.K0;
|
||||
bool not_all_have_main_k0_block_loop_same =
|
||||
all_have_main_k0_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);
|
||||
|
||||
if(not_all_have_main_k0_block_loop_same)
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
|
||||
<< ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
|
||||
if(not_all_have_kbatch_value_same)
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Not all gemms have same kbatch value (=1 or >1)! "
|
||||
<< "group [" << i << "], kbatch: " << kbatch
|
||||
<< ", group [0], kbatch: " << arg.gemm_kernel_args_[0].karg_.k_batch
|
||||
<< " in " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
|
||||
hip_check_error(hipMemcpy(arg.p_workspace_,
|
||||
arg.gemm_kernel_args_.data(),
|
||||
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
|
||||
hipMemcpyHostToDevice));
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(all_have_kbatch_gt_one)
|
||||
{
|
||||
for(const auto& trans_arg : arg.gemm_kernel_args_)
|
||||
{
|
||||
const auto& karg = trans_arg.karg_;
|
||||
hip_check_error(
|
||||
hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(EDataType)));
|
||||
}
|
||||
}
|
||||
|
||||
ave_time =
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(arg.grid_size_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
cast_pointer_to_constant_address_space(arg.p_workspace_),
|
||||
arg.gemm_kernel_args_.size());
|
||||
};
|
||||
|
||||
if(all_have_main_k0_block_loop)
|
||||
{
|
||||
if(all_have_kbatch_gt_one)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
|
||||
GemmTransKernelArg,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd>;
|
||||
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
|
||||
GemmTransKernelArg,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set>;
|
||||
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(all_have_kbatch_gt_one)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
|
||||
GemmTransKernelArg,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd>;
|
||||
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
|
||||
GemmTransKernelArg,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set>;
|
||||
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
|
||||
arg.skipped_group_count_) != arg.group_count_)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
bool supported = true;
|
||||
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
|
||||
{
|
||||
const auto& a = arg.gemm_kernel_args_[i].karg_;
|
||||
bool group_arg_valid = GridwiseGemm::CheckValidity(a);
|
||||
#if DEBUG_LOG
|
||||
if(not group_arg_valid)
|
||||
{
|
||||
std::cout << "[" << __func__ << "] group id: " << i << " is not supported!\n";
|
||||
a.Print();
|
||||
}
|
||||
#endif // DEBUG_LOG
|
||||
supported &= group_arg_valid;
|
||||
}
|
||||
return supported;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(std::vector<const void*>& p_As,
|
||||
std::vector<const void*>& p_Bs,
|
||||
std::vector<std::array<const void*, NumDTensor>>&,
|
||||
std::vector<void*>& p_Es,
|
||||
std::vector<GemmDesc> gemm_descs,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation)
|
||||
{
|
||||
return Argument{p_As, p_Bs, p_Es, gemm_descs};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(std::vector<const void*>& p_As,
|
||||
std::vector<const void*>& p_Bs,
|
||||
std::vector<std::array<const void*, NumDTensor>>&,
|
||||
std::vector<void*>& p_Es,
|
||||
std::vector<GemmDesc>& gemm_descs,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation) override
|
||||
{
|
||||
return std::make_unique<Argument>(p_As, p_Bs, p_Es, gemm_descs);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGroupedGemm_XdlSplitK"
|
||||
<< "<"
|
||||
<< std::string(ALayout::name)[0] << ","
|
||||
<< std::string(BLayout::name)[0] << ","
|
||||
<< std::string(ELayout::name)[0] << ","
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< AK1 << ", "
|
||||
<< BK1 << ", "
|
||||
<< MPerXDL << ", "
|
||||
<< NPerXDL << ", "
|
||||
<< MXdlPerWave << ", "
|
||||
<< NXdlPerWave << ", "
|
||||
<< ABlockTransferSrcScalarPerVector << ", "
|
||||
<< BBlockTransferSrcScalarPerVector << ", "
|
||||
<< CShuffleMXdlPerWavePerShuffle << ", "
|
||||
<< CShuffleNXdlPerWavePerShuffle << ", "
|
||||
<< getGemmSpecializationString(GemmSpec)
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
|
||||
{
|
||||
return dynamic_cast<const Argument*>(p_arg)->gemm_kernel_args_.size() *
|
||||
sizeof(GemmTransKernelArg);
|
||||
}
|
||||
|
||||
static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); }
|
||||
|
||||
// polymorphic
|
||||
void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
|
||||
{
|
||||
return SetKBatchSize(*dynamic_cast<Argument*>(p_arg), kbatch);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -587,4 +587,52 @@ struct OffsettedBlockToCTileMap
|
||||
index_t block_start_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Simple tile mapping which creates 3D grid of block of threads.
|
||||
*
|
||||
* @paragraph Description
|
||||
* This Block-to-C-tile-map creates a 3D grid (n_blocks, m_blocks, z_blocks) of thread
|
||||
* blocks. The first 2D are regular 2D tiles created by division of output GEMM
|
||||
* dimenions by corresponding tile size. The third dimension (Z) is a k-split dimension,
|
||||
* which denotes the number of blocks we use to divide work on GEMM K dimension onto.
|
||||
*
|
||||
* @tparam MPerBlock Output block tile size in M dimension.
|
||||
* @tparam NPerBlock Output block tile size in N dimension.
|
||||
*/
|
||||
template <index_t MPerBlock, index_t NPerBlock>
|
||||
struct BlockToCTileMap_3DGrid_KSplit
|
||||
{
|
||||
|
||||
__host__ __device__ BlockToCTileMap_3DGrid_KSplit() = default;
|
||||
|
||||
__host__ __device__ constexpr auto
|
||||
CalculateGridSize(index_t M, index_t N, index_t k_split) const
|
||||
{
|
||||
// Create 3D grid
|
||||
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
|
||||
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
|
||||
|
||||
return std::make_tuple(N0, M0, k_split);
|
||||
}
|
||||
|
||||
template <typename TopIdx>
|
||||
__device__ constexpr auto CalculateBottomIndex(const TopIdx&) const
|
||||
{
|
||||
return make_tuple(blockIdx.z, blockIdx.y, blockIdx.x);
|
||||
}
|
||||
|
||||
template <typename CTileIdx, typename CTileDim>
|
||||
__host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
|
||||
const CTileDim& /* c_tile_dim */) const
|
||||
{
|
||||
return true; // always valid provided that user gets grid size from CalculateGridSize()
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
|
||||
{
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -15,16 +15,20 @@
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation>
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
typename Block2CTileMap>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg)
|
||||
kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg,
|
||||
const Block2CTileMap& b2c_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
|
||||
@@ -32,9 +36,10 @@ __global__ void
|
||||
__shared__ uint8_t p_shared[shared_size];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
|
||||
karg, static_cast<void*>(p_shared));
|
||||
karg, static_cast<void*>(p_shared), b2c_map);
|
||||
#else
|
||||
ignore = karg;
|
||||
ignore = b2c_map;
|
||||
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
|
||||
}
|
||||
|
||||
@@ -478,8 +483,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
Number<CShuffleNRepeatPerShuffle * NWave * NPerXDL>{}));
|
||||
}
|
||||
|
||||
template <bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation>
|
||||
__device__ static void Run(const Argument& karg, void* __restrict__ p_shared_block)
|
||||
// return block_id to C matrix tile idx (m0, n0, k_split) mapping
|
||||
__host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap()
|
||||
{
|
||||
return BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>();
|
||||
}
|
||||
|
||||
using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1))>;
|
||||
using DefaultBlock2CTileMap = remove_cvref_t<decltype(MakeDefaultBlock2CTileMap())>;
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
typename Block2CTileMap>
|
||||
__device__ static void Run(const Argument& karg,
|
||||
void* __restrict__ p_shared_block,
|
||||
const Block2CTileMap& block_2_ctile_map)
|
||||
{
|
||||
const FloatAB* p_a_grid = karg.p_a_grid;
|
||||
const FloatAB* p_b_grid = karg.p_b_grid;
|
||||
@@ -504,11 +522,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
|
||||
// divide block work by [KBatch, M, N]
|
||||
const auto block_work_idx =
|
||||
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
const index_t block_m_id = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
const index_t block_n_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
|
||||
const index_t k_batch_id = __builtin_amdgcn_readfirstlane(blockIdx.z);
|
||||
if(!block_2_ctile_map.ValidCTileIndex(
|
||||
block_work_idx,
|
||||
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
|
||||
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I2]);
|
||||
const index_t k_batch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
|
||||
|
||||
// HACK: this force m/n_block_data_idx_on_grid into SGPR
|
||||
const index_t m_block_data_idx_on_grid =
|
||||
@@ -651,6 +679,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
// sanity check
|
||||
#if 1
|
||||
auto blockwise_gemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
FloatAB,
|
||||
@@ -662,6 +691,20 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
K1>{};
|
||||
#else
|
||||
auto blockwise_gemm = BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
decltype(b_k0_n_k1_block_desc),
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
K1>{};
|
||||
|
||||
#endif
|
||||
|
||||
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
@@ -680,6 +723,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
|
||||
|
||||
#if 0
|
||||
// preload data into LDS
|
||||
{
|
||||
a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
|
||||
@@ -725,6 +769,31 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
}
|
||||
#else
|
||||
// gridwise GEMM pipeline
|
||||
const auto gridwise_gemm_pipeline =
|
||||
GridwiseGemmPipeline_Selector<PipelineVersion::v2, 1, LoopScheduler::Default>();
|
||||
|
||||
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
|
||||
(a_b_k0_m_k1_grid_desc.GetLength(I1) * a_b_k0_m_k1_grid_desc.GetLength(I3)) /
|
||||
(K0PerBlock * K1));
|
||||
|
||||
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_b_k0_m_k1_grid_desc,
|
||||
a_b_k0_m_k1_block_desc,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
b_b_k0_n_k1_grid_desc,
|
||||
b_b_k0_n_k1_block_desc,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
blockwise_gemm,
|
||||
c_thread_buf,
|
||||
num_k_block_main_loop);
|
||||
#endif
|
||||
|
||||
// output: register to global memory
|
||||
{
|
||||
|
||||
@@ -68,6 +68,58 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
Col,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
Row,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
Col,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
Row,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
@@ -109,11 +161,17 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
|
||||
@@ -3,4 +3,8 @@ add_instance_library(device_grouped_gemm_instance
|
||||
device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
|
||||
device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp
|
||||
device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp
|
||||
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
|
||||
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
|
||||
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
|
||||
)
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// a[m, k] * b[k, n] = e[m, n]
|
||||
using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// Currently AK1 must equal BK1 !
|
||||
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 16,16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
Row,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,87 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_tile_instances = std::tuple<
|
||||
// clang-format off
|
||||
//################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// Currently AK1 must equal BK1 !
|
||||
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 16,16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
Row,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_tile_instances{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,75 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// a[m, k] * b[n, k] = e[m, n]
|
||||
using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
Col,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,81 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_tile_instances = std::tuple<
|
||||
// clang-format off
|
||||
//################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 32, 8, 8, 32, 32, 1, 4, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
Col,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_tile_instances{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp"
|
||||
@@ -39,7 +40,8 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
const std::vector<int>& Ks,
|
||||
const std::vector<int>& StrideAs,
|
||||
const std::vector<int>& StrideBs,
|
||||
const std::vector<int>& StrideCs)
|
||||
const std::vector<int>& StrideCs,
|
||||
int kbatch = 1)
|
||||
{
|
||||
|
||||
bool pass = true;
|
||||
@@ -96,8 +98,6 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
a_m_k[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
|
||||
b_k_n[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
|
||||
}
|
||||
|
||||
c_m_n_device_results[i].GenerateTensorValue(GeneratorTensor_0<CDataType>{}, num_thread);
|
||||
}
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
@@ -132,13 +132,12 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpaceSize()));
|
||||
b_device_buf.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize()));
|
||||
|
||||
c_device_buf.emplace_back(std::make_unique<DeviceMem>(
|
||||
sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpaceSize()));
|
||||
|
||||
a_device_buf[i]->ToDevice(a_m_k[i].mData.data());
|
||||
b_device_buf[i]->ToDevice(b_k_n[i].mData.data());
|
||||
c_device_buf[i]->ToDevice(c_m_n_device_results[i].mData.data());
|
||||
c_device_buf[i]->SetZero();
|
||||
|
||||
gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
|
||||
|
||||
@@ -197,6 +196,28 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
{
|
||||
std::string gemm_name = gemm_ptr->GetTypeString();
|
||||
|
||||
if(kbatch > 1)
|
||||
{
|
||||
using DeviceOpSplitK =
|
||||
ck::tensor_operation::device::DeviceGroupedGemmSplitK<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<>,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<>,
|
||||
CDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
if(dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get()) != nullptr)
|
||||
{
|
||||
dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())
|
||||
->SetKBatchSize(argument_ptr.get(), kbatch);
|
||||
}
|
||||
}
|
||||
|
||||
float ave_time =
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
|
||||
@@ -52,20 +52,24 @@ std::vector<int> argToIntArray(char* input)
|
||||
|
||||
int profile_grouped_gemm(int argc, char* argv[])
|
||||
{
|
||||
if(!(argc == 14))
|
||||
if(argc < 14)
|
||||
{
|
||||
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
|
||||
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n");
|
||||
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
|
||||
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
|
||||
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
|
||||
printf(" 3: A[k, m] * B[n, k] = C[m, n])\n");
|
||||
printf("arg4: verification (0: no; 1: yes)\n");
|
||||
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
|
||||
printf("arg6: print tensor value (0: no; 1: yes)\n");
|
||||
printf("arg7: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
|
||||
"64,64 64,64 128,128)\n");
|
||||
std::cout
|
||||
<< "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
|
||||
<< "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"
|
||||
<< "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"
|
||||
<< " 1: A[m, k] * B[n, k] = C[m, n];\n"
|
||||
<< " 2: A[k, m] * B[k, n] = C[m, n];\n"
|
||||
<< " 3: A[k, m] * B[n, k] = C[m, n])\n"
|
||||
<< "arg4: verification (0: no; 1: yes)\n"
|
||||
<< "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"
|
||||
<< "arg6: print tensor value (0: no; 1: yes)\n"
|
||||
<< "arg7: time kernel (0=n0, 1=yes)\n"
|
||||
<< "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
|
||||
"64,64 64,64 128,128)\n"
|
||||
<< "arg15: kbatch value (default 4)\n"
|
||||
<< std::endl;
|
||||
|
||||
exit(1);
|
||||
}
|
||||
|
||||
@@ -83,6 +87,7 @@ int profile_grouped_gemm(int argc, char* argv[])
|
||||
const auto StrideAs = argToIntArray(argv[11]);
|
||||
const auto StrideBs = argToIntArray(argv[12]);
|
||||
const auto StrideCs = argToIntArray(argv[13]);
|
||||
const int kbatch = argc == 15 ? std::stoi(argv[14]) : 1;
|
||||
|
||||
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
@@ -101,7 +106,8 @@ int profile_grouped_gemm(int argc, char* argv[])
|
||||
Ks,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideCs);
|
||||
StrideCs,
|
||||
kbatch);
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
@@ -120,7 +126,8 @@ int profile_grouped_gemm(int argc, char* argv[])
|
||||
Ks,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideCs);
|
||||
StrideCs,
|
||||
kbatch);
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
@@ -139,7 +146,8 @@ int profile_grouped_gemm(int argc, char* argv[])
|
||||
Ks,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideCs);
|
||||
StrideCs,
|
||||
kbatch);
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
@@ -158,7 +166,8 @@ int profile_grouped_gemm(int argc, char* argv[])
|
||||
Ks,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideCs);
|
||||
StrideCs,
|
||||
kbatch);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user