Switch to universal gemm in grouped gemm tile loop (#1335)

* switch to universal gemm in grouped gemm tile loop

* minor fixes

* add reviewers comments

---------

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
jakpiase
2024-06-18 16:01:49 +02:00
committed by GitHub
parent 933951ed48
commit e2d139201b
34 changed files with 1953 additions and 321 deletions

View File

@@ -13,7 +13,7 @@
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multply.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multiply.hpp"
#include "ck/host_utility/hip_check_error.hpp"

View File

@@ -13,7 +13,7 @@
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multply.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multiply.hpp"
#include "ck/host_utility/hip_check_error.hpp"

View File

@@ -63,7 +63,7 @@ using DeviceGemmInstance =
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<4,4,4>>;
// clang-format on
struct ProblemSize final

View File

@@ -144,12 +144,12 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
ignore = num_loop;
return TailNumber::Full;
@@ -446,12 +446,12 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
static constexpr index_t PrefetchStages = 1;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
ignore = num_loop;
return TailNumber::Full;

View File

@@ -153,12 +153,12 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = PrefetchStages;
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
if(num_loop % PrefetchStages == 1)
{
@@ -646,12 +646,12 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = PrefetchStages;
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
if(num_loop % PrefetchStages == 1)
{

View File

@@ -146,12 +146,12 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
ignore = num_loop;
return TailNumber::Full;

View File

@@ -147,12 +147,12 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
static constexpr index_t GlobalBufferNum = 2;
static constexpr index_t HotloopUnroll = 2;
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
if(num_loop % HotloopUnroll == 1)
{

View File

@@ -19,6 +19,7 @@
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp" // stare wywalic
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
@@ -42,16 +43,22 @@ namespace device {
template <typename GridwiseGemm,
typename GemmDesc,
GemmSpecialization GemmSpec,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
index_t KPerBlock,
typename OffsettedBlockToCTileMap,
typename LocalBlock2ETileMap,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
typename CDEElementwiseOperation,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -67,6 +74,7 @@ __global__ void
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ uint8_t p_shared[shared_size];
__shared__ uint8_t p_shared1[shared_size];
const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
@@ -81,27 +89,8 @@ __global__ void
index_t gemm_tile_id_start = 0;
index_t gemm_tile_id_end = 0;
using AGridDescMK =
remove_cvref_t<decltype(GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(
1, 1, 1))>;
using BGridDescNK =
remove_cvref_t<decltype(GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(
1, 1, 1))>;
using EGridDescMN =
remove_cvref_t<decltype(GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
1, 1, 1))>;
using DsGridDescMN =
remove_cvref_t<decltype(GridwiseGemm::template MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>(
{}, {}, {}))>;
index_t M = 0, N = 0, K = 0;
index_t StrideA, StrideB, StrideE;
std::array<index_t, NumDTensor> StrideDs;
AGridDescMK a_grid_desc_mk;
BGridDescNK b_grid_desc_nk;
EGridDescMN e_grid_desc_mn;
DsGridDescMN ds_grid_desc_mn;
auto b2c_tile_map = OffsettedBlockToCTileMap(LocalBlock2ETileMap(1, 1), 1, 1);
do
@@ -127,31 +116,13 @@ __global__ void
}
b2c_tile_map =
OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N), group_offset, tile_offset);
OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N, 4), group_offset, tile_offset);
grid_size_grp = b2c_tile_map.CalculateGridSize(M, N);
gemm_tile_id_start = group_offset;
gemm_tile_id_end = group_offset + grid_size_grp;
}
StrideA = gemm_desc_ptr[group_id].StrideA;
StrideB = gemm_desc_ptr[group_id].StrideB;
StrideDs = gemm_desc_ptr[group_id].StrideDs;
StrideE = gemm_desc_ptr[group_id].StrideE;
a_grid_desc_mk =
GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
b_grid_desc_nk =
GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(K, N, StrideB);
e_grid_desc_mn =
GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
static_for<0, NumDTensor, 1>{}([&](auto j) {
using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
ds_grid_desc_mn(j) = GridwiseGemm::template MakeEGridDescriptor_M_N<DLayout, GemmSpec>(
M, N, StrideDs[j]);
});
using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
DsGridPointer p_ds_grid;
@@ -160,42 +131,268 @@ __global__ void
p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
});
bool has_main_kblock_loop =
GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_mk.GetLength(Number<1>{}));
static constexpr index_t kbatch = 1;
static constexpr index_t k_grain = kbatch * KPerBlock;
index_t K_split = (K + k_grain - 1) / k_grain * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
// Update tile offset if we have moved within group
b2c_tile_map.UpdateTileOffset(tile_offset);
if(has_main_kblock_loop)
using Problem = typename GridwiseGemm::Problem;
auto problem = Problem(gemm_desc_ptr[group_id].M,
gemm_desc_ptr[group_id].N,
gemm_desc_ptr[group_id].K,
gemm_desc_ptr[group_id].StrideA,
gemm_desc_ptr[group_id].StrideB,
gemm_desc_ptr[group_id].StrideDs,
gemm_desc_ptr[group_id].StrideE,
kbatch);
if(has_main_k_block_loop)
{
GridwiseGemm::template Run<true>(gemm_desc_ptr[group_id].p_a_grid,
gemm_desc_ptr[group_id].p_b_grid,
p_ds_grid,
gemm_desc_ptr[group_id].p_e_grid,
static_cast<void*>(p_shared),
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_mk,
b_grid_desc_nk,
ds_grid_desc_mn,
e_grid_desc_mn,
b2c_tile_map);
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::Full>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::One>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
{
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::Full>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::Two>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three)
{
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::Three>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four)
{
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::Four>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five)
{
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::Five>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::Six>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven)
{
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::Seven>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
}
}
// Tail number could be Odd or Even
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
GridwiseGemm::template Run_2Lds<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::Odd>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
static_cast<void*>(p_shared1),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
else
{
GridwiseGemm::template Run_2Lds<OffsettedBlockToCTileMap,
true,
InMemoryDataOperationEnum::Set,
TailNumber::Even>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
static_cast<void*>(p_shared1),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
}
}
else
{
GridwiseGemm::template Run<false>(gemm_desc_ptr[group_id].p_a_grid,
gemm_desc_ptr[group_id].p_b_grid,
p_ds_grid,
gemm_desc_ptr[group_id].p_e_grid,
static_cast<void*>(p_shared),
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_mk,
b_grid_desc_nk,
ds_grid_desc_mn,
e_grid_desc_mn,
b2c_tile_map);
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
false,
InMemoryDataOperationEnum::Set,
TailNumber::Full>(
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
p_ds_grid,
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
static_cast<void*>(p_shared),
problem,
a_element_op,
b_element_op,
cde_element_op,
b2c_tile_map);
}
}
tile_id += get_grid_size();
@@ -253,10 +450,12 @@ template <typename ALayout,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
PipelineVersion PipelineVer = PipelineVersion::v1,
LoopScheduler LoopSched = make_default_loop_scheduler(),
typename ComputeDataType = EDataType>
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = EDataType,
typename ComputeTypeB = ComputeTypeA>
struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
: public DeviceGroupedGemmTileLoop<ALayout,
BLayout,
@@ -273,10 +472,13 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
using DeviceOp = DeviceGroupedGemmMultipleDXdlCShuffleTileLoop;
static constexpr index_t NumDTensor = DsDataType::Size();
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3<
ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
CShuffleDataType,
DsDataType,
@@ -284,8 +486,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
NumGemmKPrefetchStage,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
@@ -315,58 +516,15 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
PipelineVer>;
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
template <typename UnderlyingBlockToCTileMap>
struct OffsettedBlockToCTileMap
{
using underlying_type = UnderlyingBlockToCTileMap;
__host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map,
index_t group_offset,
index_t tile_offset)
: block_to_ctile_map_{block_to_ctile_map},
group_offset_{group_offset},
tile_offset_{tile_offset}
{
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
return block_to_ctile_map_.CalculateBottomIndex(
make_multi_index(idx_top[Number<0>{}] + tile_offset_ - group_offset_));
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
const CTileDim& c_tile_dim) const
{
return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
}
template <typename CGridDesc_M_N>
__host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{
return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
}
__host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
{
return block_to_ctile_map_.CalculateGridSize(M, N);
}
__device__ void UpdateTileOffset(index_t offset) { tile_offset_ = offset; }
UnderlyingBlockToCTileMap block_to_ctile_map_;
index_t group_offset_;
index_t tile_offset_;
};
using KernelArguments = GroupedGemmTileLoopKernelArguments<NumDTensor>;
using Block2ETileMap = BlockToCTileMap_N00_M0_N01Adapt<MPerBlock, NPerBlock>;
using OffsetedLocalBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMap>;
using KernelArguments = GroupedGemmTileLoopKernelArguments<NumDTensor>;
using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
using OffsettedLocalBlock2ETileMap = OffsettedBlockToCTileMap2<Block2ETileMap>;
// Argument
struct Argument : public BaseArgument
@@ -403,7 +561,6 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
const void* p_dev_gemm_args_;
int occupancy_num_blocks_;
int gpu_cu_count_;
const std::vector<GemmDesc>& gemm_descs_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
@@ -496,16 +653,22 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
KernelArguments,
GemmSpec,
ADataType,
BDataType,
DsDataType,
EDataType,
ALayout,
BLayout,
DsLayout,
ELayout,
OffsetedLocalBlock2ETileMap,
KPerBlock,
OffsettedLocalBlock2ETileMap,
Block2ETileMap,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>;
CDEElementwiseOperation,
BlkGemmPipeSched,
BlkGemmPipelineVer>;
return LaunchKernel(kernel, arg, dev_gemm_args, stream_config);
}
@@ -546,6 +709,8 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
<< std::endl;
}
// run multiple kernels
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
@@ -572,63 +737,41 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
return false;
}
using DsGridDescMN = remove_cvref_t<
decltype(GridwiseGemm::template MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>(
{}, {}, {}))>;
bool supported = true;
for(const auto& gdesc : arg.gemm_descs_)
constexpr index_t k_batch = 1;
for(index_t i = 0; i < arg.group_count_; ++i)
{
const auto M = gdesc.M_;
const auto N = gdesc.N_;
const auto K = gdesc.K_;
std::array<const void*, NumDTensor> placeholder_p_ds_grid{};
std::array<index_t, NumDTensor> stride_Ds;
std::copy_n(arg.gemm_descs_[i].stride_Ds_.begin(), NumDTensor, stride_Ds.begin());
using GridArg = typename GridwiseGemm::Argument;
GridArg gridwise_arg(nullptr, // p_a_grid,
nullptr, // p_b_grid,
placeholder_p_ds_grid, // p_ds_grid,
nullptr, // p_e_grid ,
arg.gemm_descs_[i].M_,
arg.gemm_descs_[i].N_,
arg.gemm_descs_[i].K_,
arg.gemm_descs_[i].stride_A_,
arg.gemm_descs_[i].stride_B_,
stride_Ds,
arg.gemm_descs_[i].stride_C_,
k_batch,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_);
const auto StrideA = gdesc.stride_A_;
const auto StrideB = gdesc.stride_B_;
const auto StrideE = gdesc.stride_C_;
const auto& StrideDs = gdesc.stride_Ds_;
// If M dimension is unknown at launch time then validate just NK.
// If N or K dim is zero (or unknown) then the vector loads responsibility lies on
// the user.
if(N * K == 0)
continue;
const auto a_grid_desc_mk =
GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
const auto b_grid_desc_nk =
GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(K, N, StrideB);
const auto e_grid_desc_mn =
GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
DsGridDescMN ds_grid_desc_mn;
static_for<0, NumDTensor, 1>{}([&](auto j) {
using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
ds_grid_desc_mn(j) =
GridwiseGemm::template MakeEGridDescriptor_M_N<DLayout, GemmSpec>(
M, N, StrideDs[j]);
});
const auto b2c_tile_map = Block2ETileMap(M, N);
if(!(GridwiseGemm::template CheckValidity(a_grid_desc_mk,
b_grid_desc_nk,
ds_grid_desc_mn,
e_grid_desc_mn,
b2c_tile_map) &&
GridwiseGemm::template CheckTensorTransfersValidity<ALayout, BLayout, ELayout>(
M, N, K)))
if((arg.gemm_descs_[i].K_ % AK1 != 0 || arg.gemm_descs_[i].K_ % BK1 != 0) &&
!(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "The provided GEMM problem size (M,N,K) [" << M << "," << N << ","
<< K << "] are not supported by current template parameters!"
<< " In " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__;
}
supported = false;
return false;
}
supported = supported && GridwiseGemm::CheckValidity(gridwise_arg);
}
return supported;
@@ -651,16 +794,22 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
KernelArguments,
GemmSpec,
ADataType,
BDataType,
DsDataType,
EDataType,
ALayout,
BLayout,
DsLayout,
ELayout,
OffsetedLocalBlock2ETileMap,
KPerBlock,
OffsettedLocalBlock2ETileMap,
Block2ETileMap,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>;
CDEElementwiseOperation,
BlkGemmPipeSched,
BlkGemmPipelineVer>;
int occupancy, num_cu;
hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
@@ -696,16 +845,22 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
KernelArguments,
GemmSpec,
ADataType,
BDataType,
DsDataType,
EDataType,
ALayout,
BLayout,
DsLayout,
ELayout,
OffsetedLocalBlock2ETileMap,
KPerBlock,
OffsettedLocalBlock2ETileMap,
Block2ETileMap,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>;
CDEElementwiseOperation,
BlkGemmPipeSched,
BlkGemmPipelineVer>;
int occupancy, num_cu;
hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
@@ -739,6 +894,17 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
{
auto str = std::ostringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"},
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceGroupedGemmMultipleDXdlCShuffleTileLoop"
<< "<"
@@ -760,8 +926,10 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< getGemmSpecializationString(GemmSpec) << ", "
<< PipelineVer << ", "
<< LoopSched
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer]
<< ">";
// clang-format on

View File

@@ -908,6 +908,51 @@ struct OffsettedBlockToCTileMap
UnderlyingBlockToCTileMap block_to_ctile_map_;
index_t block_start_;
};
// second version with 2 offsets
template <typename UnderlyingBlockToCTileMap>
struct OffsettedBlockToCTileMap2
{
using underlying_type = UnderlyingBlockToCTileMap;
__host__ __device__ OffsettedBlockToCTileMap2(UnderlyingBlockToCTileMap block_to_ctile_map,
index_t group_offset,
index_t tile_offset)
: block_to_ctile_map_{block_to_ctile_map},
group_offset_{group_offset},
tile_offset_{tile_offset}
{
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
return block_to_ctile_map_.CalculateBottomIndex(
make_multi_index(idx_top[Number<0>{}] + tile_offset_ - group_offset_));
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
const CTileDim& c_tile_dim) const
{
return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
}
template <typename CGridDesc_M_N>
__host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{
return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
}
__host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
{
return block_to_ctile_map_.CalculateGridSize(M, N);
}
__device__ void UpdateTileOffset(index_t offset) { tile_offset_ = offset; }
UnderlyingBlockToCTileMap block_to_ctile_map_;
index_t group_offset_;
index_t tile_offset_;
};
/**
* @brief Simple tile mapping which creates 3D grid of block of threads.

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -189,55 +189,55 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
{
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
return std::make_tuple(Block2CTileMapDefault::CalculateGridSize(M, N), 1, KBatch);
}
__host__ static auto CalculateMPadded(index_t M)
__host__ __device__ static auto CalculateMPadded(index_t M)
{
return math::integer_least_multiple(M, MPerBlock);
}
__host__ static auto CalculateNPadded(index_t N)
__host__ __device__ static auto CalculateNPadded(index_t N)
{
return math::integer_least_multiple(N, NPerBlock);
}
__host__ static auto CalculateKPadded(index_t K)
__host__ __device__ static auto CalculateKPadded(index_t K)
{
return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
}
__host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
__host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
{
auto K_t = K_Batch * KPerBlock;
return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
}
__host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
__host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
{
auto K_t = K_Batch * KPerBlock;
return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
}
__host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
__host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
{
auto K_t = K_Batch * KPerBlock;
return (K + K_t - 1) / K_t * KPerBlock;
}
__host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
__host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
{
constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
auto K_t = K_Batch * KReadVec;
return (K + K_t - 1) / K_t * KReadVec;
}
__host__ static auto CalculateMBlock(index_t M)
__host__ __device__ static auto CalculateMBlock(index_t M)
{
return math::integer_divide_ceil(M, MPerBlock);
}
__host__ static auto CalculateNBlock(index_t N)
__host__ __device__ static auto CalculateNBlock(index_t N)
{
return math::integer_divide_ceil(N, NPerBlock);
}
@@ -520,14 +520,14 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
struct Problem
{
__host__ Problem(index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideC_,
index_t KBatch_)
__host__ __device__ Problem(index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideC_,
index_t KBatch_)
: M{M_},
N{N_},
K{K_},
@@ -1180,14 +1180,14 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
return true;
}
__host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
}
__host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
__host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
{
const index_t num_loop = K / KPerBlock;
@@ -1210,8 +1210,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
// return block_id to C matrix tile idx (m0, n0) mapping
// if arch = gfx942
using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
// using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
@@ -1225,6 +1224,35 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
Run<Block2CTileMapDefault, HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
p_a_grid,
p_b_grid,
p_ds_grid,
p_c_grid,
p_shared,
problem,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map);
}
template <typename Block2CTileMap,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run(const ADataType* p_a_grid,
const BDataType* p_b_grid,
DsGridPointer& p_ds_grid,
CDataType* p_c_grid,
void* p_shared,
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
const Block2CTileMap& block_2_ctile_map)
{
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
@@ -1244,9 +1272,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
@@ -1653,6 +1678,38 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
Run_2Lds<Block2CTileMapDefault, HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
p_a_grid,
p_b_grid,
p_ds_grid,
p_c_grid,
p_shared_0,
p_shared_1,
problem,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map);
}
template <typename Block2CTileMap,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run_2Lds(const ADataType* p_a_grid,
const BDataType* p_b_grid,
DsGridPointer& p_ds_grid,
CDataType* p_c_grid,
void* p_shared_0,
void* p_shared_1,
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
const Block2CTileMap& block_2_ctile_map)
{
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
@@ -1672,9 +1729,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

View File

@@ -17,7 +17,150 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instances(
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
Row_Tuple,
Row,
BF16,
I8,
BF16_Tuple,
BF16,
PassThrough,
PassThrough,
Multiply>>>& instances);
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
Row_Tuple,
Row,
BF16,
I8,
BF16_Tuple,
BF16,
PassThrough,
PassThrough,
Multiply>>>& instances);
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
Row_Tuple,
Row,
BF16,
I8,
BF16_Tuple,
BF16,
PassThrough,
PassThrough,
Multiply>>>& instances);
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
Row_Tuple,
Row,
BF16,
I8,
BF16_Tuple,
BF16,
PassThrough,
PassThrough,
Multiply>>>& instances);
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
Row_Tuple,
Row,
BF16,
I8,
BF16_Tuple,
BF16,
PassThrough,
PassThrough,
Multiply>>>& instances);
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
Row_Tuple,
Row,
BF16,
I8,
BF16_Tuple,
BF16,
PassThrough,
PassThrough,
Multiply>>>& instances);
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
Row_Tuple,
Row,
BF16,
I8,
BF16_Tuple,
BF16,
PassThrough,
PassThrough,
Multiply>>>& instances);
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
Row_Tuple,
Row,
BF16,
I8,
BF16_Tuple,
BF16,
PassThrough,
PassThrough,
Multiply>>>& instances);
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
Row_Tuple,
Row,
BF16,
I8,
BF16_Tuple,
BF16,
PassThrough,
PassThrough,
Multiply>>>& instances);
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
Row_Tuple,
Row,
BF16,
I8,
BF16_Tuple,
BF16,
PassThrough,
PassThrough,
Multiply>>>& instances);
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
Row_Tuple,
Row,
BF16,
I8,
BF16_Tuple,
BF16,
PassThrough,
PassThrough,
Multiply>>>& instances);
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
Row_Tuple,
@@ -67,14 +210,35 @@ struct DeviceOperationInstanceFactory<
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
// fp16_output
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, int8_t> &&
is_same_v<EDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instances(
add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_default_instances(
op_ptrs);
add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instances(
op_ptrs);
add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instances(
op_ptrs);
add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instances(
op_ptrs);
add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_default_instances(
op_ptrs);
add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnpadding_instances(
op_ptrs);
add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_kpadding_instances(
op_ptrs);
add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instances(
op_ptrs);
add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
op_ptrs);
add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnpadding_instances(
op_ptrs);
add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instances(
op_ptrs);
}
}
@@ -132,7 +296,6 @@ struct DeviceOperationInstanceFactory<
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
// fp16_output
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, int8_t> &&
is_same_v<EDataType, bhalf_t>)
{
@@ -199,7 +362,6 @@ struct DeviceOperationInstanceFactory<
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
// fp16_output
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, int8_t> &&
is_same_v<EDataType, bhalf_t>)
{
@@ -266,7 +428,6 @@ struct DeviceOperationInstanceFactory<
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
// fp16_output
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, int8_t> &&
is_same_v<EDataType, bhalf_t>)
{

View File

@@ -5,8 +5,22 @@ set(GROUPED_GEMM_TILE_LOOP_INSTANCES)
list(APPEND GROUPED_GEMM_TILE_LOOP_INSTANCES
device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instance.cpp
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instance.cpp
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_default_instance.cpp
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnpadding_instance.cpp
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instance.cpp
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnpadding_instance.cpp
device_grouped_gemm_xdl_tile_loop_multiply_bias_bf16_i8_bf16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_tile_loop_multiply_bias_fastgelu_bf16_i8_bf16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_tile_loop_multiply_fastgelu_bf16_i8_bf16_mk_kn_mn_instance.cpp
)
add_instance_library(device_grouped_gemm_tile_loop_instance ${GROUPED_GEMM_TILE_LOOP_INSTANCES})

View File

@@ -38,16 +38,16 @@ using device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_irregular_tile_inst
//###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8>>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8>>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, S<8>>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8>>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, S<8>>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8>>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, S<8>>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8>>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, S<8>>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8>>
// clang-format on
>;

View File

@@ -37,19 +37,19 @@ using device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_irregular_tile_inst
//###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 8, 32, 32, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 64, 8, 8, 32, 32, 4, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 64, 8, 8, 32, 32, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 8, 32, 32, 1, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 8, 32, 32, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8>>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8>>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8>>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8>>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 64, 8, 8, 32, 32, 4, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8>>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8>>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8>>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8>>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8>>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 64, 8, 8, 32, 32, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8>>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8>>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8>>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 8, 32, 32, 1, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8>>
// clang-format on
>;

View File

@@ -0,0 +1,93 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
template <typename DsLayout,
typename DsDataType,
typename CDEElementwiseOp,
GemmSpecialization GemmSpec = GemmMNKPadding>
using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances = std::tuple<
// clang-format off
//###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S<C,D0...,D_N|
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 224, 256, 64, 8, 4, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 2, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
// clang-format on
>;
template <typename DsLayout,
typename DsDataType,
typename CDEElementwiseOp,
GemmSpecialization GemmSpec = GemmMNKPadding,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave>
using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances =
std::tuple<
// clang-format off
//###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S<C,D0...,D_N|
// Latency friendly
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
// Memory friendly
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 64, 128, 8, 4, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 32, 64, 128, 8, 4, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,35 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
ck::Tuple<Row>,
Row,
BF16,
I8,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Multiply>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,35 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
ck::Tuple<Row>,
Row,
BF16,
I8,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Multiply>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,35 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
ck::Tuple<Row>,
Row,
BF16,
I8,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Multiply>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,35 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
ck::Tuple<Row>,
Row,
BF16,
I8,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Multiply>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmMNPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -31,51 +31,63 @@ using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastG
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
template <typename DsLayout, typename DsDataType, typename CDEElementwiseOp>
using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_irregular_tile_instances = std::tuple<
// clang-format off
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
template <typename DsLayout,
typename DsDataType,
typename CDEElementwiseOp,
GemmSpecialization GemmSpec = GemmMNKPadding>
using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances = std::tuple<
// clang-format off
//###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if 1
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>
#endif
#if 0
//comp
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8>,
//###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S<C,D0...,D_N|
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 224, 256, 64, 8, 4, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 2, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
//latency
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4>,
//mem
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 16, 64, 128, 8, 4, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 32, 64, 128, 8, 4, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, 4>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, 8>
#endif
// clang-format on
>;
template <typename DsLayout,
typename DsDataType,
typename CDEElementwiseOp,
GemmSpecialization GemmSpec = GemmMNKPadding,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave>
using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances =
std::tuple<
// clang-format off
//###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S<C,D0...,D_N|
// Latency friendly
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
// Memory friendly
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 64, 128, 8, 4, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 32, 64, 128, 8, 4, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
// clang-format on
>;
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
@@ -89,33 +101,89 @@ void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instan
PassThrough,
Multiply>>>& instances)
{
// comp
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_irregular_tile_instances<
ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply>{});
}
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmDefault>{});
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmMNKPadding>{});
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmMNPadding>{});
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmKPadding>{});
// mem
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmDefault,
Intrawave>{});
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmMNKPadding,
Intrawave>{});
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmMNPadding,
Intrawave>{});
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmKPadding,
Intrawave>{});
void add_device_grouped_gemm_xdl_tile_loop_multiply_bias_bf16_i8_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
ck::Tuple<Row, Row>,
Row,
BF16,
I8,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAdd>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_irregular_tile_instances<
ck::Tuple<Row, Row>,
ck::Tuple<BF16, BF16>,
MultiplyAdd>{});
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmDefault,
Interwave>{});
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmMNKPadding,
Interwave>{});
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmMNPadding,
Interwave>{});
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmKPadding,
Interwave>{});
}
void add_device_grouped_gemm_xdl_tile_loop_multiply_bias_fastgelu_bf16_i8_bf16_mk_kn_mn_instances(

View File

@@ -0,0 +1,36 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
ck::Tuple<Row>,
Row,
BF16,
I8,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Multiply>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmDefault,
Intrawave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,36 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
ck::Tuple<Row>,
Row,
BF16,
I8,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Multiply>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmKPadding,
Intrawave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,36 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
ck::Tuple<Row>,
Row,
BF16,
I8,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Multiply>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmMNKPadding,
Intrawave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,36 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
ck::Tuple<Row>,
Row,
BF16,
I8,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Multiply>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmMNPadding,
Intrawave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,36 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
ck::Tuple<Row>,
Row,
BF16,
I8,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Multiply>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmDefault,
Interwave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,36 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
ck::Tuple<Row>,
Row,
BF16,
I8,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Multiply>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmKPadding,
Interwave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,36 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
ck::Tuple<Row>,
Row,
BF16,
I8,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Multiply>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmMNKPadding,
Interwave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,36 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
ck::Tuple<Row>,
Row,
BF16,
I8,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Multiply>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmMNPadding,
Interwave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,40 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_tile_loop_multiply_bias_bf16_i8_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
ck::Tuple<Row, Row>,
Row,
BF16,
I8,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAdd>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances<
ck::Tuple<Row, Row>,
ck::Tuple<BF16, BF16>,
MultiplyAdd>{});
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row, Row>,
ck::Tuple<BF16, BF16>,
MultiplyAdd>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,41 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_tile_loop_multiply_bias_fastgelu_bf16_i8_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
ck::Tuple<Row, Row>,
Row,
BF16,
I8,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances<
ck::Tuple<Row, Row>,
ck::Tuple<BF16, BF16>,
MultiplyAddFastGelu>{});
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<
ck::Tuple<Row, Row>,
ck::Tuple<BF16, BF16>,
MultiplyAddFastGelu>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,39 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_tile_loop_multiply_fastgelu_bf16_i8_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
ck::Tuple<Row>,
Row,
BF16,
I8,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
MultiplyFastGelu>{});
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
MultiplyFastGelu>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,347 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iomanip>
#include "ck/ck.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multiply.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace ck {
namespace profiler {
template <typename ADataType,
typename BDataType,
typename DDataType,
typename EDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename DLayout,
typename ELayout>
bool profile_grouped_gemm_multiply_tile_loop_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideDs,
const std::vector<int>& StrideEs,
int n_warmup = 10,
int n_iter = 50)
{
using CDataType = EDataType;
bool pass = true;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
std::size_t group_count = Ms.size();
if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() &&
group_count == StrideBs.size() && group_count == StrideEs.size()))
{
throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideA/B/Cs size\n");
}
std::vector<Tensor<ADataType>> a_m_k;
std::vector<Tensor<BDataType>> b_k_n;
std::vector<Tensor<DDataType>> d_m_n;
std::vector<Tensor<CDataType>> e_m_n_host_results;
std::vector<Tensor<CDataType>> e_m_n_device_results;
for(std::size_t i = 0; i < group_count; i++)
{
a_m_k.push_back(
Tensor<ADataType>(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{})));
b_k_n.push_back(
Tensor<BDataType>(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{})));
d_m_n.push_back(
Tensor<DDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideDs[i], DLayout{})));
e_m_n_device_results.push_back(
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideEs[i], ELayout{})));
e_m_n_host_results.push_back(
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideEs[i], ELayout{})));
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n["
<< i << "]:" << b_k_n[i].mDesc << ", e_m_n_device_results[" << i
<< "]:" << e_m_n_device_results[i].mDesc << std::endl;
}
switch(init_method)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5, 5}(a_m_k[i]);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5, 5}(b_k_n[i]);
ck::utils::FillUniformDistributionIntegerValue<DDataType>{-5, 5}(d_m_n[i]);
break;
case 2:
ck::utils::FillUniformDistribution<ADataType>{.0, 1.}(a_m_k[i]);
ck::utils::FillUniformDistribution<BDataType>{-0.5, 0.5}(b_k_n[i]);
ck::utils::FillUniformDistribution<DDataType>{-0.5, 0.5}(d_m_n[i]);
break;
default:
ck::utils::FillConstant<ADataType>{1}(a_m_k[i]);
ck::utils::FillConstant<BDataType>{1}(b_k_n[i]);
ck::utils::FillConstant<DDataType>{1}(d_m_n[i]);
}
}
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Multiply;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};
const auto cde_element_op = CDEElementOp{};
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_device_buf, b_device_buf, d_device_buf, e_device_buf;
a_device_buf.reserve(group_count);
b_device_buf.reserve(group_count);
d_device_buf.reserve(group_count);
e_device_buf.reserve(group_count);
std::vector<const void*> p_a, p_b, p_d;
constexpr ck::index_t NumDTensor = 1;
auto p_ds = std::vector<std::array<const void*, NumDTensor>>{};
std::vector<void*> p_e;
p_a.reserve(group_count);
p_b.reserve(group_count);
p_ds.reserve(group_count);
p_e.reserve(group_count);
using KernelArguments =
ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments<NumDTensor>;
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<KernelArguments> gemm_kargs;
gemm_descs.reserve(group_count);
gemm_kargs.reserve(group_count);
for(std::size_t i = 0; i < group_count; i++)
{
a_device_buf.emplace_back(
std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpaceSize()));
b_device_buf.emplace_back(
std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize()));
d_device_buf.emplace_back(
std::make_unique<DeviceMem>(sizeof(DDataType) * d_m_n[i].mDesc.GetElementSpaceSize()));
e_device_buf.emplace_back(std::make_unique<DeviceMem>(
sizeof(CDataType) * e_m_n_device_results[i].mDesc.GetElementSpaceSize()));
a_device_buf[i]->ToDevice(a_m_k[i].mData.data());
b_device_buf[i]->ToDevice(b_k_n[i].mData.data());
d_device_buf[i]->ToDevice(d_m_n[i].mData.data());
e_device_buf[i]->SetZero();
p_a.push_back(a_device_buf[i]->GetDeviceBuffer());
p_b.push_back(b_device_buf[i]->GetDeviceBuffer());
p_ds.push_back({d_device_buf[i]->GetDeviceBuffer()});
p_e.push_back(e_device_buf[i]->GetDeviceBuffer());
gemm_descs.push_back(
{0, Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideEs[i], {StrideDs[i]}});
gemm_kargs.push_back({a_device_buf[i]->GetDeviceBuffer(),
b_device_buf[i]->GetDeviceBuffer(),
{d_device_buf[i]->GetDeviceBuffer()},
e_device_buf[i]->GetDeviceBuffer(),
Ms[i],
Ns[i],
Ks[i],
StrideAs[i],
StrideBs[i],
{StrideDs[i]},
StrideEs[i]});
}
using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmTileLoop<ALayout,
BLayout,
ck::Tuple<DLayout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<DDataType>,
EDataType,
AElementOp,
BElementOp,
CDEElementOp>;
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
if(op_ptrs.size() <= 0)
{
throw std::runtime_error("wrong! no device GEMM instance found");
}
std::string best_gemm_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
if(do_verification)
{
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
Tensor<CDataType> c_m_n({Ms[i], Ns[i]});
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k[i], b_k_n[i], c_m_n, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
for(int m = 0; m < Ms[i]; ++m)
{
for(int n = 0; n < Ns[i]; ++n)
{
cde_element_op(e_m_n_host_results[i](m, n), c_m_n(m, n), d_m_n[i](m, n));
}
}
}
}
// profile device GEMM instances
for(auto& gemm_ptr : op_ptrs)
{
auto argument_ptr =
gemm_ptr->MakeArgumentPointer(p_a,
p_b,
p_ds,
p_e,
gemm_descs,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
cde_element_op);
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
std::string gemm_name = gemm_ptr->GetTypeString();
DeviceMem gemm_arg_dev_mem(gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()));
hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(),
gemm_kargs.data(),
gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()),
hipMemcpyHostToDevice));
gemm_ptr->SetDeviceKernelArgs(argument_ptr.get(), gemm_arg_dev_mem.GetDeviceBuffer());
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false, 0, n_warmup, n_iter});
if(do_verification)
{
bool instance_pass = true;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
e_device_buf[i]->FromDevice(e_m_n_device_results[i].mData.data());
instance_pass = instance_pass && ck::utils::check_err(e_m_n_device_results[i],
e_m_n_host_results[i]);
if(do_log)
{
LogRangeAsType<float>(std::cout << "a : ", a_m_k[i].mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n[i].mData, ",") << std::endl;
LogRangeAsType<float>(
std::cout << "e_device: ", e_m_n_device_results[i].mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "e_host : ", e_m_n_host_results[i].mData, ",")
<< std::endl;
}
}
std::cout << "Instance: " << gemm_name << " verification "
<< (instance_pass ? "SUCCEED" : "FAILED") << std::endl;
pass = pass && instance_pass;
}
if(time_kernel)
{
float ave_time = invoker_ptr->Run(
argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter});
std::size_t flop = 0, num_btype = 0;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i];
num_btype += sizeof(ADataType) * Ms[i] * Ks[i] +
sizeof(BDataType) * Ks[i] * Ns[i] +
sizeof(EDataType) * Ms[i] * Ns[i] + // D matrix
sizeof(EDataType) * Ms[i] * Ns[i];
}
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
<< " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << std::endl;
if(tflops > best_tflops)
{
best_gemm_name = gemm_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
}
else
{
std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem"
<< std::endl;
}
}
if(time_kernel)
{
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
}
return pass;
}
} // namespace profiler
} // namespace ck

View File

@@ -43,6 +43,7 @@ if(GPU_TARGETS MATCHES "gfx9")
list(APPEND PROFILER_SOURCES profile_grouped_gemm_two_stage.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_tile_loop.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp)
endif()
list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp)
list(APPEND PROFILER_SOURCES profile_batched_gemm.cpp)

View File

@@ -0,0 +1,133 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include "profiler/profile_grouped_gemm_multiply_tile_loop_impl.hpp"
#include "profiler_operation_registry.hpp"
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
};
enum struct GemmDataType
{
BF16_INT8_BF16_BF16, // 0
};
#define OP_NAME "grouped_gemm_multiply_tile_loop"
#define OP_DESC "Grouped GEMM Multiply Multiple D Tile Loop"
namespace {
std::vector<int> argToIntArray(char* input)
{
std::vector<int> out;
std::istringstream in(input);
std::string item;
while(std::getline(in, item, ','))
{
out.push_back(std::stoi(item));
}
return out;
}
int profile_grouped_gemm_tile_loop(int argc, char* argv[])
{
if(argc < 14)
{
std::cout
<< "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
<< "arg2: data type (0: bf16@int8)\n"
<< "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n]);\n"
<< "arg4: verification (0: no; 1: yes)\n"
<< "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"
<< "arg6: print tensor value (0: no; 1: yes)\n"
<< "arg7: time kernel (0=n0, 1=yes)\n"
<< "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
"64,64 64,64 128,128)\n"
<< "optional:\n"
<< "arg14: number of warm-up cycles (default 1)\n"
<< "arg15: number of iterations (default 10)\n"
<< std::endl;
exit(1);
}
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]);
const int init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]);
const bool time_kernel = std::stoi(argv[7]);
const auto Ms = argToIntArray(argv[8]);
const auto Ns = argToIntArray(argv[9]);
const auto Ks = argToIntArray(argv[10]);
auto StrideAs = argToIntArray(argv[11]);
auto StrideBs = argToIntArray(argv[12]);
auto StrideCs = argToIntArray(argv[13]);
const int DefaultStrideA = Ks[0];
const int DefaultStrideB = Ns[0];
const int DefaultStrideC = Ns[0];
for(size_t i = 0; i < Ms.size(); ++i)
{
StrideAs[i] = StrideAs[i] == -1 ? DefaultStrideA : StrideAs[i];
StrideBs[i] = StrideBs[i] == -1 ? DefaultStrideB : StrideBs[i];
StrideCs[i] = StrideCs[i] == -1 ? DefaultStrideC : StrideCs[i];
}
std::vector<int> StrideDs(StrideCs);
int n_warmup = 10;
int n_iter = 50;
if(argc == 16)
{
n_warmup = std::stoi(argv[14]);
n_iter = std::stoi(argv[15]);
}
if(data_type == GemmDataType::BF16_INT8_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_multiply_tile_loop_impl<
ck::bhalf_t,
int8_t,
ck::bhalf_t,
ck::bhalf_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideDs,
StrideCs,
n_warmup,
n_iter);
}
else
{
throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
}
return 0;
}
} // anonymous namespace
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_grouped_gemm_tile_loop);