mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Switch to universal gemm in grouped gemm tile loop (#1335)
* switch to universal gemm in grouped gemm tile loop * minor fixes * add reviewers comments --------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
@@ -13,7 +13,7 @@
|
||||
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multply.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multiply.hpp"
|
||||
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multply.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multiply.hpp"
|
||||
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ using DeviceGemmInstance =
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<4,4,4>>;
|
||||
// clang-format on
|
||||
|
||||
struct ProblemSize final
|
||||
|
||||
@@ -144,12 +144,12 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
|
||||
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
ignore = num_loop;
|
||||
return TailNumber::Full;
|
||||
@@ -446,12 +446,12 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
static constexpr index_t PrefetchStages = 1;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
ignore = num_loop;
|
||||
return TailNumber::Full;
|
||||
|
||||
@@ -153,12 +153,12 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = PrefetchStages;
|
||||
|
||||
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
if(num_loop % PrefetchStages == 1)
|
||||
{
|
||||
@@ -646,12 +646,12 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = PrefetchStages;
|
||||
|
||||
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
if(num_loop % PrefetchStages == 1)
|
||||
{
|
||||
|
||||
@@ -146,12 +146,12 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
|
||||
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
ignore = num_loop;
|
||||
return TailNumber::Full;
|
||||
|
||||
@@ -147,12 +147,12 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
|
||||
static constexpr index_t GlobalBufferNum = 2;
|
||||
static constexpr index_t HotloopUnroll = 2;
|
||||
|
||||
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
if(num_loop % HotloopUnroll == 1)
|
||||
{
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp" // stare wywalic
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
|
||||
|
||||
@@ -42,16 +43,22 @@ namespace device {
|
||||
template <typename GridwiseGemm,
|
||||
typename GemmDesc,
|
||||
GemmSpecialization GemmSpec,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
index_t KPerBlock,
|
||||
typename OffsettedBlockToCTileMap,
|
||||
typename LocalBlock2ETileMap,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
typename CDEElementwiseOperation,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
@@ -67,6 +74,7 @@ __global__ void
|
||||
|
||||
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
|
||||
__shared__ uint8_t p_shared[shared_size];
|
||||
__shared__ uint8_t p_shared1[shared_size];
|
||||
|
||||
const auto gemm_desc_ptr =
|
||||
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
|
||||
@@ -81,27 +89,8 @@ __global__ void
|
||||
index_t gemm_tile_id_start = 0;
|
||||
index_t gemm_tile_id_end = 0;
|
||||
|
||||
using AGridDescMK =
|
||||
remove_cvref_t<decltype(GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(
|
||||
1, 1, 1))>;
|
||||
using BGridDescNK =
|
||||
remove_cvref_t<decltype(GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(
|
||||
1, 1, 1))>;
|
||||
using EGridDescMN =
|
||||
remove_cvref_t<decltype(GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
|
||||
1, 1, 1))>;
|
||||
using DsGridDescMN =
|
||||
remove_cvref_t<decltype(GridwiseGemm::template MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>(
|
||||
{}, {}, {}))>;
|
||||
|
||||
index_t M = 0, N = 0, K = 0;
|
||||
index_t StrideA, StrideB, StrideE;
|
||||
std::array<index_t, NumDTensor> StrideDs;
|
||||
|
||||
AGridDescMK a_grid_desc_mk;
|
||||
BGridDescNK b_grid_desc_nk;
|
||||
EGridDescMN e_grid_desc_mn;
|
||||
DsGridDescMN ds_grid_desc_mn;
|
||||
auto b2c_tile_map = OffsettedBlockToCTileMap(LocalBlock2ETileMap(1, 1), 1, 1);
|
||||
|
||||
do
|
||||
@@ -127,31 +116,13 @@ __global__ void
|
||||
}
|
||||
|
||||
b2c_tile_map =
|
||||
OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N), group_offset, tile_offset);
|
||||
OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N, 4), group_offset, tile_offset);
|
||||
grid_size_grp = b2c_tile_map.CalculateGridSize(M, N);
|
||||
|
||||
gemm_tile_id_start = group_offset;
|
||||
gemm_tile_id_end = group_offset + grid_size_grp;
|
||||
}
|
||||
|
||||
StrideA = gemm_desc_ptr[group_id].StrideA;
|
||||
StrideB = gemm_desc_ptr[group_id].StrideB;
|
||||
StrideDs = gemm_desc_ptr[group_id].StrideDs;
|
||||
StrideE = gemm_desc_ptr[group_id].StrideE;
|
||||
|
||||
a_grid_desc_mk =
|
||||
GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
|
||||
b_grid_desc_nk =
|
||||
GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(K, N, StrideB);
|
||||
e_grid_desc_mn =
|
||||
GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto j) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
|
||||
ds_grid_desc_mn(j) = GridwiseGemm::template MakeEGridDescriptor_M_N<DLayout, GemmSpec>(
|
||||
M, N, StrideDs[j]);
|
||||
});
|
||||
|
||||
using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
|
||||
DsGridPointer p_ds_grid;
|
||||
|
||||
@@ -160,42 +131,268 @@ __global__ void
|
||||
p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
|
||||
});
|
||||
|
||||
bool has_main_kblock_loop =
|
||||
GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_mk.GetLength(Number<1>{}));
|
||||
static constexpr index_t kbatch = 1;
|
||||
static constexpr index_t k_grain = kbatch * KPerBlock;
|
||||
index_t K_split = (K + k_grain - 1) / k_grain * KPerBlock;
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
// Update tile offset if we have moved within group
|
||||
b2c_tile_map.UpdateTileOffset(tile_offset);
|
||||
|
||||
if(has_main_kblock_loop)
|
||||
using Problem = typename GridwiseGemm::Problem;
|
||||
auto problem = Problem(gemm_desc_ptr[group_id].M,
|
||||
gemm_desc_ptr[group_id].N,
|
||||
gemm_desc_ptr[group_id].K,
|
||||
gemm_desc_ptr[group_id].StrideA,
|
||||
gemm_desc_ptr[group_id].StrideB,
|
||||
gemm_desc_ptr[group_id].StrideDs,
|
||||
gemm_desc_ptr[group_id].StrideE,
|
||||
kbatch);
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
GridwiseGemm::template Run<true>(gemm_desc_ptr[group_id].p_a_grid,
|
||||
gemm_desc_ptr[group_id].p_b_grid,
|
||||
p_ds_grid,
|
||||
gemm_desc_ptr[group_id].p_e_grid,
|
||||
static_cast<void*>(p_shared),
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_mk,
|
||||
b_grid_desc_nk,
|
||||
ds_grid_desc_mn,
|
||||
e_grid_desc_mn,
|
||||
b2c_tile_map);
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
TailNumber::Full>(
|
||||
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
|
||||
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
|
||||
p_ds_grid,
|
||||
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
|
||||
static_cast<void*>(p_shared),
|
||||
problem,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
b2c_tile_map);
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
|
||||
{
|
||||
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
TailNumber::One>(
|
||||
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
|
||||
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
|
||||
p_ds_grid,
|
||||
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
|
||||
static_cast<void*>(p_shared),
|
||||
problem,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
b2c_tile_map);
|
||||
}
|
||||
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
|
||||
{
|
||||
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
TailNumber::Full>(
|
||||
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
|
||||
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
|
||||
p_ds_grid,
|
||||
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
|
||||
static_cast<void*>(p_shared),
|
||||
problem,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
b2c_tile_map);
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
|
||||
{
|
||||
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
TailNumber::Two>(
|
||||
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
|
||||
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
|
||||
p_ds_grid,
|
||||
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
|
||||
static_cast<void*>(p_shared),
|
||||
problem,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
b2c_tile_map);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three)
|
||||
{
|
||||
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
TailNumber::Three>(
|
||||
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
|
||||
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
|
||||
p_ds_grid,
|
||||
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
|
||||
static_cast<void*>(p_shared),
|
||||
problem,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
b2c_tile_map);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four)
|
||||
{
|
||||
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
TailNumber::Four>(
|
||||
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
|
||||
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
|
||||
p_ds_grid,
|
||||
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
|
||||
static_cast<void*>(p_shared),
|
||||
problem,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
b2c_tile_map);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five)
|
||||
{
|
||||
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
TailNumber::Five>(
|
||||
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
|
||||
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
|
||||
p_ds_grid,
|
||||
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
|
||||
static_cast<void*>(p_shared),
|
||||
problem,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
b2c_tile_map);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
|
||||
{
|
||||
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
TailNumber::Six>(
|
||||
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
|
||||
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
|
||||
p_ds_grid,
|
||||
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
|
||||
static_cast<void*>(p_shared),
|
||||
problem,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
b2c_tile_map);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven)
|
||||
{
|
||||
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
TailNumber::Seven>(
|
||||
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
|
||||
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
|
||||
p_ds_grid,
|
||||
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
|
||||
static_cast<void*>(p_shared),
|
||||
problem,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
b2c_tile_map);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Tail number could be Odd or Even
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
GridwiseGemm::template Run_2Lds<OffsettedBlockToCTileMap,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
TailNumber::Odd>(
|
||||
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
|
||||
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
|
||||
p_ds_grid,
|
||||
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
|
||||
static_cast<void*>(p_shared),
|
||||
static_cast<void*>(p_shared1),
|
||||
problem,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
b2c_tile_map);
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseGemm::template Run_2Lds<OffsettedBlockToCTileMap,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
TailNumber::Even>(
|
||||
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
|
||||
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
|
||||
p_ds_grid,
|
||||
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
|
||||
static_cast<void*>(p_shared),
|
||||
static_cast<void*>(p_shared1),
|
||||
problem,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
b2c_tile_map);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseGemm::template Run<false>(gemm_desc_ptr[group_id].p_a_grid,
|
||||
gemm_desc_ptr[group_id].p_b_grid,
|
||||
p_ds_grid,
|
||||
gemm_desc_ptr[group_id].p_e_grid,
|
||||
static_cast<void*>(p_shared),
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_mk,
|
||||
b_grid_desc_nk,
|
||||
ds_grid_desc_mn,
|
||||
e_grid_desc_mn,
|
||||
b2c_tile_map);
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
TailNumber::Full>(
|
||||
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
|
||||
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
|
||||
p_ds_grid,
|
||||
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
|
||||
static_cast<void*>(p_shared),
|
||||
problem,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
b2c_tile_map);
|
||||
}
|
||||
}
|
||||
|
||||
tile_id += get_grid_size();
|
||||
@@ -253,10 +450,12 @@ template <typename ALayout,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
typename ComputeDataType = EDataType>
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
typename ComputeTypeA = EDataType,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
|
||||
struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
: public DeviceGroupedGemmTileLoop<ALayout,
|
||||
BLayout,
|
||||
@@ -273,10 +472,13 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
using DeviceOp = DeviceGroupedGemmMultipleDXdlCShuffleTileLoop;
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3<
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
@@ -284,8 +486,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
NumGemmKPrefetchStage,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
@@ -315,58 +516,15 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched,
|
||||
PipelineVer>;
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
template <typename UnderlyingBlockToCTileMap>
|
||||
struct OffsettedBlockToCTileMap
|
||||
{
|
||||
using underlying_type = UnderlyingBlockToCTileMap;
|
||||
|
||||
__host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map,
|
||||
index_t group_offset,
|
||||
index_t tile_offset)
|
||||
: block_to_ctile_map_{block_to_ctile_map},
|
||||
group_offset_{group_offset},
|
||||
tile_offset_{tile_offset}
|
||||
{
|
||||
}
|
||||
|
||||
template <typename TopIdx>
|
||||
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
|
||||
{
|
||||
return block_to_ctile_map_.CalculateBottomIndex(
|
||||
make_multi_index(idx_top[Number<0>{}] + tile_offset_ - group_offset_));
|
||||
}
|
||||
|
||||
template <typename CTileIdx, typename CTileDim>
|
||||
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
|
||||
const CTileDim& c_tile_dim) const
|
||||
{
|
||||
return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
|
||||
{
|
||||
return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
|
||||
{
|
||||
return block_to_ctile_map_.CalculateGridSize(M, N);
|
||||
}
|
||||
|
||||
__device__ void UpdateTileOffset(index_t offset) { tile_offset_ = offset; }
|
||||
UnderlyingBlockToCTileMap block_to_ctile_map_;
|
||||
index_t group_offset_;
|
||||
index_t tile_offset_;
|
||||
};
|
||||
|
||||
using KernelArguments = GroupedGemmTileLoopKernelArguments<NumDTensor>;
|
||||
using Block2ETileMap = BlockToCTileMap_N00_M0_N01Adapt<MPerBlock, NPerBlock>;
|
||||
using OffsetedLocalBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMap>;
|
||||
using KernelArguments = GroupedGemmTileLoopKernelArguments<NumDTensor>;
|
||||
using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
|
||||
using OffsettedLocalBlock2ETileMap = OffsettedBlockToCTileMap2<Block2ETileMap>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
@@ -403,7 +561,6 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
const void* p_dev_gemm_args_;
|
||||
int occupancy_num_blocks_;
|
||||
int gpu_cu_count_;
|
||||
|
||||
const std::vector<GemmDesc>& gemm_descs_;
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
@@ -496,16 +653,22 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
|
||||
KernelArguments,
|
||||
GemmSpec,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
OffsetedLocalBlock2ETileMap,
|
||||
KPerBlock,
|
||||
OffsettedLocalBlock2ETileMap,
|
||||
Block2ETileMap,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>;
|
||||
CDEElementwiseOperation,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer>;
|
||||
return LaunchKernel(kernel, arg, dev_gemm_args, stream_config);
|
||||
}
|
||||
|
||||
@@ -546,6 +709,8 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// run multiple kernels
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
@@ -572,63 +737,41 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
return false;
|
||||
}
|
||||
|
||||
using DsGridDescMN = remove_cvref_t<
|
||||
decltype(GridwiseGemm::template MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>(
|
||||
{}, {}, {}))>;
|
||||
|
||||
bool supported = true;
|
||||
|
||||
for(const auto& gdesc : arg.gemm_descs_)
|
||||
constexpr index_t k_batch = 1;
|
||||
for(index_t i = 0; i < arg.group_count_; ++i)
|
||||
{
|
||||
const auto M = gdesc.M_;
|
||||
const auto N = gdesc.N_;
|
||||
const auto K = gdesc.K_;
|
||||
std::array<const void*, NumDTensor> placeholder_p_ds_grid{};
|
||||
std::array<index_t, NumDTensor> stride_Ds;
|
||||
std::copy_n(arg.gemm_descs_[i].stride_Ds_.begin(), NumDTensor, stride_Ds.begin());
|
||||
using GridArg = typename GridwiseGemm::Argument;
|
||||
GridArg gridwise_arg(nullptr, // p_a_grid,
|
||||
nullptr, // p_b_grid,
|
||||
placeholder_p_ds_grid, // p_ds_grid,
|
||||
nullptr, // p_e_grid ,
|
||||
arg.gemm_descs_[i].M_,
|
||||
arg.gemm_descs_[i].N_,
|
||||
arg.gemm_descs_[i].K_,
|
||||
arg.gemm_descs_[i].stride_A_,
|
||||
arg.gemm_descs_[i].stride_B_,
|
||||
stride_Ds,
|
||||
arg.gemm_descs_[i].stride_C_,
|
||||
k_batch,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_);
|
||||
|
||||
const auto StrideA = gdesc.stride_A_;
|
||||
const auto StrideB = gdesc.stride_B_;
|
||||
const auto StrideE = gdesc.stride_C_;
|
||||
const auto& StrideDs = gdesc.stride_Ds_;
|
||||
|
||||
// If M dimension is unknown at launch time then validate just NK.
|
||||
// If N or K dim is zero (or unknown) then the vector loads responsibility lies on
|
||||
// the user.
|
||||
if(N * K == 0)
|
||||
continue;
|
||||
|
||||
const auto a_grid_desc_mk =
|
||||
GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
|
||||
const auto b_grid_desc_nk =
|
||||
GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(K, N, StrideB);
|
||||
const auto e_grid_desc_mn =
|
||||
GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
|
||||
|
||||
DsGridDescMN ds_grid_desc_mn;
|
||||
static_for<0, NumDTensor, 1>{}([&](auto j) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
|
||||
ds_grid_desc_mn(j) =
|
||||
GridwiseGemm::template MakeEGridDescriptor_M_N<DLayout, GemmSpec>(
|
||||
M, N, StrideDs[j]);
|
||||
});
|
||||
|
||||
const auto b2c_tile_map = Block2ETileMap(M, N);
|
||||
|
||||
if(!(GridwiseGemm::template CheckValidity(a_grid_desc_mk,
|
||||
b_grid_desc_nk,
|
||||
ds_grid_desc_mn,
|
||||
e_grid_desc_mn,
|
||||
b2c_tile_map) &&
|
||||
GridwiseGemm::template CheckTensorTransfersValidity<ALayout, BLayout, ELayout>(
|
||||
M, N, K)))
|
||||
if((arg.gemm_descs_[i].K_ % AK1 != 0 || arg.gemm_descs_[i].K_ % BK1 != 0) &&
|
||||
!(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding ||
|
||||
GemmSpec == GemmSpecialization::KPadding))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "The provided GEMM problem size (M,N,K) [" << M << "," << N << ","
|
||||
<< K << "] are not supported by current template parameters!"
|
||||
<< " In " << __FILE__ << ":" << __LINE__
|
||||
<< ", in function: " << __func__;
|
||||
}
|
||||
supported = false;
|
||||
return false;
|
||||
}
|
||||
|
||||
supported = supported && GridwiseGemm::CheckValidity(gridwise_arg);
|
||||
}
|
||||
|
||||
return supported;
|
||||
@@ -651,16 +794,22 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
|
||||
KernelArguments,
|
||||
GemmSpec,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
OffsetedLocalBlock2ETileMap,
|
||||
KPerBlock,
|
||||
OffsettedLocalBlock2ETileMap,
|
||||
Block2ETileMap,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>;
|
||||
CDEElementwiseOperation,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer>;
|
||||
int occupancy, num_cu;
|
||||
hip_check_error(
|
||||
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
|
||||
@@ -696,16 +845,22 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
|
||||
KernelArguments,
|
||||
GemmSpec,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
OffsetedLocalBlock2ETileMap,
|
||||
KPerBlock,
|
||||
OffsettedLocalBlock2ETileMap,
|
||||
Block2ETileMap,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>;
|
||||
CDEElementwiseOperation,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer>;
|
||||
int occupancy, num_cu;
|
||||
hip_check_error(
|
||||
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
|
||||
@@ -739,6 +894,17 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
{
|
||||
auto str = std::ostringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"},
|
||||
{BlockGemmPipelineVersion::v4, "v4"},
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGroupedGemmMultipleDXdlCShuffleTileLoop"
|
||||
<< "<"
|
||||
@@ -760,8 +926,10 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
<< CShuffleMXdlPerWavePerShuffle << ", "
|
||||
<< CShuffleNXdlPerWavePerShuffle << ", "
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< PipelineVer << ", "
|
||||
<< LoopSched
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer]
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -908,6 +908,51 @@ struct OffsettedBlockToCTileMap
|
||||
UnderlyingBlockToCTileMap block_to_ctile_map_;
|
||||
index_t block_start_;
|
||||
};
|
||||
// second version with 2 offsets
|
||||
template <typename UnderlyingBlockToCTileMap>
|
||||
struct OffsettedBlockToCTileMap2
|
||||
{
|
||||
using underlying_type = UnderlyingBlockToCTileMap;
|
||||
|
||||
__host__ __device__ OffsettedBlockToCTileMap2(UnderlyingBlockToCTileMap block_to_ctile_map,
|
||||
index_t group_offset,
|
||||
index_t tile_offset)
|
||||
: block_to_ctile_map_{block_to_ctile_map},
|
||||
group_offset_{group_offset},
|
||||
tile_offset_{tile_offset}
|
||||
{
|
||||
}
|
||||
|
||||
template <typename TopIdx>
|
||||
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
|
||||
{
|
||||
return block_to_ctile_map_.CalculateBottomIndex(
|
||||
make_multi_index(idx_top[Number<0>{}] + tile_offset_ - group_offset_));
|
||||
}
|
||||
|
||||
template <typename CTileIdx, typename CTileDim>
|
||||
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
|
||||
const CTileDim& c_tile_dim) const
|
||||
{
|
||||
return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
|
||||
{
|
||||
return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
|
||||
{
|
||||
return block_to_ctile_map_.CalculateGridSize(M, N);
|
||||
}
|
||||
|
||||
__device__ void UpdateTileOffset(index_t offset) { tile_offset_ = offset; }
|
||||
UnderlyingBlockToCTileMap block_to_ctile_map_;
|
||||
index_t group_offset_;
|
||||
index_t tile_offset_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Simple tile mapping which creates 3D grid of block of threads.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -189,55 +189,55 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
|
||||
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
|
||||
{
|
||||
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
|
||||
return std::make_tuple(Block2CTileMapDefault::CalculateGridSize(M, N), 1, KBatch);
|
||||
}
|
||||
|
||||
__host__ static auto CalculateMPadded(index_t M)
|
||||
__host__ __device__ static auto CalculateMPadded(index_t M)
|
||||
{
|
||||
return math::integer_least_multiple(M, MPerBlock);
|
||||
}
|
||||
|
||||
__host__ static auto CalculateNPadded(index_t N)
|
||||
__host__ __device__ static auto CalculateNPadded(index_t N)
|
||||
{
|
||||
return math::integer_least_multiple(N, NPerBlock);
|
||||
}
|
||||
|
||||
__host__ static auto CalculateKPadded(index_t K)
|
||||
__host__ __device__ static auto CalculateKPadded(index_t K)
|
||||
{
|
||||
return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
|
||||
}
|
||||
|
||||
__host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
|
||||
__host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
|
||||
{
|
||||
auto K_t = K_Batch * KPerBlock;
|
||||
return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
|
||||
}
|
||||
|
||||
__host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
|
||||
__host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
|
||||
{
|
||||
auto K_t = K_Batch * KPerBlock;
|
||||
return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
|
||||
}
|
||||
|
||||
__host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
|
||||
__host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
|
||||
{
|
||||
auto K_t = K_Batch * KPerBlock;
|
||||
return (K + K_t - 1) / K_t * KPerBlock;
|
||||
}
|
||||
|
||||
__host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
|
||||
__host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
|
||||
{
|
||||
constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
|
||||
auto K_t = K_Batch * KReadVec;
|
||||
return (K + K_t - 1) / K_t * KReadVec;
|
||||
}
|
||||
|
||||
__host__ static auto CalculateMBlock(index_t M)
|
||||
__host__ __device__ static auto CalculateMBlock(index_t M)
|
||||
{
|
||||
return math::integer_divide_ceil(M, MPerBlock);
|
||||
}
|
||||
|
||||
__host__ static auto CalculateNBlock(index_t N)
|
||||
__host__ __device__ static auto CalculateNBlock(index_t N)
|
||||
{
|
||||
return math::integer_divide_ceil(N, NPerBlock);
|
||||
}
|
||||
@@ -520,14 +520,14 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
|
||||
struct Problem
|
||||
{
|
||||
__host__ Problem(index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideC_,
|
||||
index_t KBatch_)
|
||||
__host__ __device__ Problem(index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideC_,
|
||||
index_t KBatch_)
|
||||
: M{M_},
|
||||
N{N_},
|
||||
K{K_},
|
||||
@@ -1180,14 +1180,14 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
return true;
|
||||
}
|
||||
|
||||
__host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
|
||||
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
|
||||
{
|
||||
const index_t num_loop = K / KPerBlock;
|
||||
|
||||
return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
|
||||
}
|
||||
|
||||
__host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
|
||||
__host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
|
||||
{
|
||||
const index_t num_loop = K / KPerBlock;
|
||||
|
||||
@@ -1210,8 +1210,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
|
||||
// return block_id to C matrix tile idx (m0, n0) mapping
|
||||
// if arch = gfx942
|
||||
using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
|
||||
// using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
|
||||
using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
@@ -1225,6 +1224,35 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
|
||||
Run<Block2CTileMapDefault, HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_c_grid,
|
||||
p_shared,
|
||||
problem,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
block_2_ctile_map);
|
||||
}
|
||||
|
||||
template <typename Block2CTileMap,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
__device__ static void Run(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
DsGridPointer& p_ds_grid,
|
||||
CDataType* p_c_grid,
|
||||
void* p_shared,
|
||||
const Problem& problem,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
const Block2CTileMap& block_2_ctile_map)
|
||||
{
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
|
||||
@@ -1244,9 +1272,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
|
||||
|
||||
const auto block_work_idx =
|
||||
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
@@ -1653,6 +1678,38 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
// divide block work by [M, N]
|
||||
const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
|
||||
Run_2Lds<Block2CTileMapDefault, HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_c_grid,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
problem,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
block_2_ctile_map);
|
||||
}
|
||||
|
||||
template <typename Block2CTileMap,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
__device__ static void Run_2Lds(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
DsGridPointer& p_ds_grid,
|
||||
CDataType* p_c_grid,
|
||||
void* p_shared_0,
|
||||
void* p_shared_1,
|
||||
const Problem& problem,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
const Block2CTileMap& block_2_ctile_map)
|
||||
{
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
|
||||
@@ -1672,9 +1729,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
|
||||
|
||||
const auto block_work_idx =
|
||||
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
|
||||
@@ -17,7 +17,150 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
BF16_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Multiply>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
BF16_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Multiply>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
BF16_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Multiply>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
BF16_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Multiply>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
BF16_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Multiply>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
BF16_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Multiply>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
BF16_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Multiply>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
BF16_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Multiply>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
BF16_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Multiply>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
BF16_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Multiply>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
BF16_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Multiply>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
@@ -67,14 +210,35 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
// fp16_output
|
||||
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, int8_t> &&
|
||||
is_same_v<EDataType, bhalf_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
@@ -132,7 +296,6 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
// fp16_output
|
||||
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, int8_t> &&
|
||||
is_same_v<EDataType, bhalf_t>)
|
||||
{
|
||||
@@ -199,7 +362,6 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
// fp16_output
|
||||
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, int8_t> &&
|
||||
is_same_v<EDataType, bhalf_t>)
|
||||
{
|
||||
@@ -266,7 +428,6 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
// fp16_output
|
||||
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, int8_t> &&
|
||||
is_same_v<EDataType, bhalf_t>)
|
||||
{
|
||||
@@ -5,8 +5,22 @@ set(GROUPED_GEMM_TILE_LOOP_INSTANCES)
|
||||
list(APPEND GROUPED_GEMM_TILE_LOOP_INSTANCES
|
||||
device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_instance.cpp
|
||||
|
||||
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp
|
||||
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp
|
||||
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instance.cpp
|
||||
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp
|
||||
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instance.cpp
|
||||
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_default_instance.cpp
|
||||
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp
|
||||
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp
|
||||
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnpadding_instance.cpp
|
||||
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instance.cpp
|
||||
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp
|
||||
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
|
||||
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnpadding_instance.cpp
|
||||
device_grouped_gemm_xdl_tile_loop_multiply_bias_bf16_i8_bf16_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_xdl_tile_loop_multiply_bias_fastgelu_bf16_i8_bf16_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_xdl_tile_loop_multiply_fastgelu_bf16_i8_bf16_mk_kn_mn_instance.cpp
|
||||
)
|
||||
|
||||
add_instance_library(device_grouped_gemm_tile_loop_instance ${GROUPED_GEMM_TILE_LOOP_INSTANCES})
|
||||
|
||||
@@ -38,16 +38,16 @@ using device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_irregular_tile_inst
|
||||
//###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8>>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8>>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, S<8>>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8>>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, S<8>>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8>>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, S<8>>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8>>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, S<8>>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8>>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
|
||||
@@ -37,19 +37,19 @@ using device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_irregular_tile_inst
|
||||
//###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 8, 32, 32, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 64, 8, 8, 32, 32, 4, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 64, 8, 8, 32, 32, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 8, 32, 32, 1, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 8, 32, 32, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8>>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8>>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8>>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8>>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 64, 8, 8, 32, 32, 4, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8>>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8>>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8>>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8>>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8>>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 64, 8, 8, 32, 32, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8>>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8>>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8>>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 8, 32, 32, 1, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8>>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using I8 = int8_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Multiply = ck::tensor_operation::element_wise::Multiply;
|
||||
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
|
||||
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
|
||||
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
|
||||
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
template <typename DsLayout,
|
||||
typename DsDataType,
|
||||
typename CDEElementwiseOp,
|
||||
GemmSpecialization GemmSpec = GemmMNKPadding>
|
||||
using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances = std::tuple<
|
||||
// clang-format off
|
||||
//###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S<C,D0...,D_N|
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 224, 256, 64, 8, 4, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 2, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename DsLayout,
|
||||
typename DsDataType,
|
||||
typename CDEElementwiseOp,
|
||||
GemmSpecialization GemmSpec = GemmMNKPadding,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave>
|
||||
using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S<C,D0...,D_N|
|
||||
// Latency friendly
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
// Memory friendly
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 64, 128, 8, 4, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 32, 64, 128, 8, 4, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,35 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Multiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
Multiply,
|
||||
GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,35 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Multiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
Multiply,
|
||||
GemmKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,35 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Multiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
Multiply,
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,35 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Multiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
Multiply,
|
||||
GemmMNPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -31,51 +31,63 @@ using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastG
|
||||
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
|
||||
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
|
||||
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
|
||||
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
template <typename DsLayout, typename DsDataType, typename CDEElementwiseOp>
|
||||
using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_irregular_tile_instances = std::tuple<
|
||||
// clang-format off
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
template <typename DsLayout,
|
||||
typename DsDataType,
|
||||
typename CDEElementwiseOp,
|
||||
GemmSpecialization GemmSpec = GemmMNKPadding>
|
||||
using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances = std::tuple<
|
||||
// clang-format off
|
||||
//###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
#if 1
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>
|
||||
#endif
|
||||
#if 0
|
||||
//comp
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
//###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S<C,D0...,D_N|
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 224, 256, 64, 8, 4, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 2, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
|
||||
|
||||
//latency
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4>,
|
||||
|
||||
//mem
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 16, 64, 128, 8, 4, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 32, 64, 128, 8, 4, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, 4>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, 8>
|
||||
#endif
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename DsLayout,
|
||||
typename DsDataType,
|
||||
typename CDEElementwiseOp,
|
||||
GemmSpecialization GemmSpec = GemmMNKPadding,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave>
|
||||
using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S<C,D0...,D_N|
|
||||
// Latency friendly
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
// Memory friendly
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 64, 128, 8, 4, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 32, 64, 128, 8, 4, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
@@ -89,33 +101,89 @@ void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instan
|
||||
PassThrough,
|
||||
Multiply>>>& instances)
|
||||
{
|
||||
// comp
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_irregular_tile_instances<
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
Multiply>{});
|
||||
}
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
Multiply,
|
||||
GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
Multiply,
|
||||
GemmMNKPadding>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
Multiply,
|
||||
GemmMNPadding>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
Multiply,
|
||||
GemmKPadding>{});
|
||||
// mem
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
Multiply,
|
||||
GemmDefault,
|
||||
Intrawave>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
Multiply,
|
||||
GemmMNKPadding,
|
||||
Intrawave>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
Multiply,
|
||||
GemmMNPadding,
|
||||
Intrawave>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
Multiply,
|
||||
GemmKPadding,
|
||||
Intrawave>{});
|
||||
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bias_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
ck::Tuple<Row, Row>,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
ck::Tuple<BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAdd>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_irregular_tile_instances<
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<BF16, BF16>,
|
||||
MultiplyAdd>{});
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
Multiply,
|
||||
GemmDefault,
|
||||
Interwave>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
Multiply,
|
||||
GemmMNKPadding,
|
||||
Interwave>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
Multiply,
|
||||
GemmMNPadding,
|
||||
Interwave>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
Multiply,
|
||||
GemmKPadding,
|
||||
Interwave>{});
|
||||
}
|
||||
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bias_fastgelu_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Multiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
Multiply,
|
||||
GemmDefault,
|
||||
Intrawave>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,36 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Multiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
Multiply,
|
||||
GemmKPadding,
|
||||
Intrawave>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,36 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Multiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
Multiply,
|
||||
GemmMNKPadding,
|
||||
Intrawave>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,36 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Multiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
Multiply,
|
||||
GemmMNPadding,
|
||||
Intrawave>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,36 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Multiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
Multiply,
|
||||
GemmDefault,
|
||||
Interwave>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,36 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Multiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
Multiply,
|
||||
GemmKPadding,
|
||||
Interwave>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,36 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Multiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
Multiply,
|
||||
GemmMNKPadding,
|
||||
Interwave>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,36 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Multiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
Multiply,
|
||||
GemmMNPadding,
|
||||
Interwave>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bias_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
ck::Tuple<Row, Row>,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
ck::Tuple<BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAdd>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances<
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<BF16, BF16>,
|
||||
MultiplyAdd>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row, Row>,
|
||||
ck::Tuple<BF16, BF16>,
|
||||
MultiplyAdd>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bias_fastgelu_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
ck::Tuple<Row, Row>,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
ck::Tuple<BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAddFastGelu>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances<
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<BF16, BF16>,
|
||||
MultiplyAddFastGelu>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<BF16, BF16>,
|
||||
MultiplyAddFastGelu>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,39 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_fastgelu_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyFastGelu>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
MultiplyFastGelu>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
MultiplyFastGelu>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,347 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iomanip>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multiply.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DLayout,
|
||||
typename ELayout>
|
||||
bool profile_grouped_gemm_multiply_tile_loop_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
const std::vector<int>& Ks,
|
||||
const std::vector<int>& StrideAs,
|
||||
const std::vector<int>& StrideBs,
|
||||
const std::vector<int>& StrideDs,
|
||||
const std::vector<int>& StrideEs,
|
||||
int n_warmup = 10,
|
||||
int n_iter = 50)
|
||||
{
|
||||
using CDataType = EDataType;
|
||||
bool pass = true;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
std::size_t group_count = Ms.size();
|
||||
|
||||
if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() &&
|
||||
group_count == StrideBs.size() && group_count == StrideEs.size()))
|
||||
{
|
||||
throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideA/B/Cs size\n");
|
||||
}
|
||||
|
||||
std::vector<Tensor<ADataType>> a_m_k;
|
||||
std::vector<Tensor<BDataType>> b_k_n;
|
||||
std::vector<Tensor<DDataType>> d_m_n;
|
||||
std::vector<Tensor<CDataType>> e_m_n_host_results;
|
||||
std::vector<Tensor<CDataType>> e_m_n_device_results;
|
||||
|
||||
for(std::size_t i = 0; i < group_count; i++)
|
||||
{
|
||||
a_m_k.push_back(
|
||||
Tensor<ADataType>(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{})));
|
||||
b_k_n.push_back(
|
||||
Tensor<BDataType>(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{})));
|
||||
d_m_n.push_back(
|
||||
Tensor<DDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideDs[i], DLayout{})));
|
||||
e_m_n_device_results.push_back(
|
||||
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideEs[i], ELayout{})));
|
||||
e_m_n_host_results.push_back(
|
||||
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideEs[i], ELayout{})));
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n["
|
||||
<< i << "]:" << b_k_n[i].mDesc << ", e_m_n_device_results[" << i
|
||||
<< "]:" << e_m_n_device_results[i].mDesc << std::endl;
|
||||
}
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5, 5}(a_m_k[i]);
|
||||
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5, 5}(b_k_n[i]);
|
||||
ck::utils::FillUniformDistributionIntegerValue<DDataType>{-5, 5}(d_m_n[i]);
|
||||
break;
|
||||
case 2:
|
||||
ck::utils::FillUniformDistribution<ADataType>{.0, 1.}(a_m_k[i]);
|
||||
ck::utils::FillUniformDistribution<BDataType>{-0.5, 0.5}(b_k_n[i]);
|
||||
ck::utils::FillUniformDistribution<DDataType>{-0.5, 0.5}(d_m_n[i]);
|
||||
break;
|
||||
default:
|
||||
ck::utils::FillConstant<ADataType>{1}(a_m_k[i]);
|
||||
ck::utils::FillConstant<BDataType>{1}(b_k_n[i]);
|
||||
ck::utils::FillConstant<DDataType>{1}(d_m_n[i]);
|
||||
}
|
||||
}
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CDEElementOp = ck::tensor_operation::element_wise::Multiply;
|
||||
|
||||
const auto a_element_op = AElementOp{};
|
||||
const auto b_element_op = BElementOp{};
|
||||
const auto c_element_op = CElementOp{};
|
||||
const auto cde_element_op = CDEElementOp{};
|
||||
|
||||
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
|
||||
std::vector<DeviceMemPtr> a_device_buf, b_device_buf, d_device_buf, e_device_buf;
|
||||
|
||||
a_device_buf.reserve(group_count);
|
||||
b_device_buf.reserve(group_count);
|
||||
d_device_buf.reserve(group_count);
|
||||
e_device_buf.reserve(group_count);
|
||||
|
||||
std::vector<const void*> p_a, p_b, p_d;
|
||||
constexpr ck::index_t NumDTensor = 1;
|
||||
auto p_ds = std::vector<std::array<const void*, NumDTensor>>{};
|
||||
std::vector<void*> p_e;
|
||||
|
||||
p_a.reserve(group_count);
|
||||
p_b.reserve(group_count);
|
||||
p_ds.reserve(group_count);
|
||||
p_e.reserve(group_count);
|
||||
|
||||
using KernelArguments =
|
||||
ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments<NumDTensor>;
|
||||
|
||||
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
|
||||
std::vector<KernelArguments> gemm_kargs;
|
||||
|
||||
gemm_descs.reserve(group_count);
|
||||
gemm_kargs.reserve(group_count);
|
||||
|
||||
for(std::size_t i = 0; i < group_count; i++)
|
||||
{
|
||||
a_device_buf.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpaceSize()));
|
||||
b_device_buf.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize()));
|
||||
d_device_buf.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(DDataType) * d_m_n[i].mDesc.GetElementSpaceSize()));
|
||||
e_device_buf.emplace_back(std::make_unique<DeviceMem>(
|
||||
sizeof(CDataType) * e_m_n_device_results[i].mDesc.GetElementSpaceSize()));
|
||||
|
||||
a_device_buf[i]->ToDevice(a_m_k[i].mData.data());
|
||||
b_device_buf[i]->ToDevice(b_k_n[i].mData.data());
|
||||
d_device_buf[i]->ToDevice(d_m_n[i].mData.data());
|
||||
e_device_buf[i]->SetZero();
|
||||
|
||||
p_a.push_back(a_device_buf[i]->GetDeviceBuffer());
|
||||
p_b.push_back(b_device_buf[i]->GetDeviceBuffer());
|
||||
p_ds.push_back({d_device_buf[i]->GetDeviceBuffer()});
|
||||
p_e.push_back(e_device_buf[i]->GetDeviceBuffer());
|
||||
|
||||
gemm_descs.push_back(
|
||||
{0, Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideEs[i], {StrideDs[i]}});
|
||||
gemm_kargs.push_back({a_device_buf[i]->GetDeviceBuffer(),
|
||||
b_device_buf[i]->GetDeviceBuffer(),
|
||||
{d_device_buf[i]->GetDeviceBuffer()},
|
||||
e_device_buf[i]->GetDeviceBuffer(),
|
||||
Ms[i],
|
||||
Ns[i],
|
||||
Ks[i],
|
||||
StrideAs[i],
|
||||
StrideBs[i],
|
||||
{StrideDs[i]},
|
||||
StrideEs[i]});
|
||||
}
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmTileLoop<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<DLayout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<DDataType>,
|
||||
EDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
if(op_ptrs.size() <= 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! no device GEMM instance found");
|
||||
}
|
||||
|
||||
std::string best_gemm_name;
|
||||
float best_ave_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
Tensor<CDataType> c_m_n({Ms[i], Ns[i]});
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_m_k[i], b_k_n[i], c_m_n, a_element_op, b_element_op, c_element_op);
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
for(int m = 0; m < Ms[i]; ++m)
|
||||
{
|
||||
for(int n = 0; n < Ns[i]; ++n)
|
||||
{
|
||||
cde_element_op(e_m_n_host_results[i](m, n), c_m_n(m, n), d_m_n[i](m, n));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// profile device GEMM instances
|
||||
for(auto& gemm_ptr : op_ptrs)
|
||||
{
|
||||
auto argument_ptr =
|
||||
gemm_ptr->MakeArgumentPointer(p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
gemm_descs,
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
cde_element_op);
|
||||
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
|
||||
std::string gemm_name = gemm_ptr->GetTypeString();
|
||||
|
||||
DeviceMem gemm_arg_dev_mem(gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()));
|
||||
hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(),
|
||||
gemm_kargs.data(),
|
||||
gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()),
|
||||
hipMemcpyHostToDevice));
|
||||
gemm_ptr->SetDeviceKernelArgs(argument_ptr.get(), gemm_arg_dev_mem.GetDeviceBuffer());
|
||||
|
||||
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false, 0, n_warmup, n_iter});
|
||||
if(do_verification)
|
||||
{
|
||||
bool instance_pass = true;
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
e_device_buf[i]->FromDevice(e_m_n_device_results[i].mData.data());
|
||||
instance_pass = instance_pass && ck::utils::check_err(e_m_n_device_results[i],
|
||||
e_m_n_host_results[i]);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "a : ", a_m_k[i].mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(std::cout << "b: ", b_k_n[i].mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "e_device: ", e_m_n_device_results[i].mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "e_host : ", e_m_n_host_results[i].mData, ",")
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Instance: " << gemm_name << " verification "
|
||||
<< (instance_pass ? "SUCCEED" : "FAILED") << std::endl;
|
||||
|
||||
pass = pass && instance_pass;
|
||||
}
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
float ave_time = invoker_ptr->Run(
|
||||
argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter});
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i];
|
||||
|
||||
num_btype += sizeof(ADataType) * Ms[i] * Ks[i] +
|
||||
sizeof(BDataType) * Ks[i] * Ns[i] +
|
||||
sizeof(EDataType) * Ms[i] * Ns[i] + // D matrix
|
||||
sizeof(EDataType) * Ms[i] * Ns[i];
|
||||
}
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
|
||||
<< " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
best_gemm_name = gemm_name;
|
||||
best_tflops = tflops;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
|
||||
<< best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace ck
|
||||
@@ -43,6 +43,7 @@ if(GPU_TARGETS MATCHES "gfx9")
|
||||
list(APPEND PROFILER_SOURCES profile_grouped_gemm_two_stage.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_grouped_gemm_tile_loop.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp)
|
||||
endif()
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_batched_gemm.cpp)
|
||||
|
||||
133
profiler/src/profile_grouped_gemm_multiply_tile_loop.cpp
Normal file
133
profiler/src/profile_grouped_gemm_multiply_tile_loop.cpp
Normal file
@@ -0,0 +1,133 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "profiler/profile_grouped_gemm_multiply_tile_loop_impl.hpp"
|
||||
#include "profiler_operation_registry.hpp"
|
||||
|
||||
enum struct GemmMatrixLayout
|
||||
{
|
||||
MK_KN_MN, // 0
|
||||
};
|
||||
|
||||
enum struct GemmDataType
|
||||
{
|
||||
BF16_INT8_BF16_BF16, // 0
|
||||
};
|
||||
|
||||
#define OP_NAME "grouped_gemm_multiply_tile_loop"
|
||||
#define OP_DESC "Grouped GEMM Multiply Multiple D Tile Loop"
|
||||
|
||||
namespace {
|
||||
|
||||
std::vector<int> argToIntArray(char* input)
|
||||
{
|
||||
std::vector<int> out;
|
||||
std::istringstream in(input);
|
||||
std::string item;
|
||||
|
||||
while(std::getline(in, item, ','))
|
||||
{
|
||||
out.push_back(std::stoi(item));
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
int profile_grouped_gemm_tile_loop(int argc, char* argv[])
|
||||
{
|
||||
if(argc < 14)
|
||||
{
|
||||
std::cout
|
||||
<< "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
|
||||
<< "arg2: data type (0: bf16@int8)\n"
|
||||
<< "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n]);\n"
|
||||
<< "arg4: verification (0: no; 1: yes)\n"
|
||||
<< "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"
|
||||
<< "arg6: print tensor value (0: no; 1: yes)\n"
|
||||
<< "arg7: time kernel (0=n0, 1=yes)\n"
|
||||
<< "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
|
||||
"64,64 64,64 128,128)\n"
|
||||
<< "optional:\n"
|
||||
<< "arg14: number of warm-up cycles (default 1)\n"
|
||||
<< "arg15: number of iterations (default 10)\n"
|
||||
<< std::endl;
|
||||
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
|
||||
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
|
||||
const bool do_verification = std::stoi(argv[4]);
|
||||
const int init_method = std::stoi(argv[5]);
|
||||
const bool do_log = std::stoi(argv[6]);
|
||||
const bool time_kernel = std::stoi(argv[7]);
|
||||
|
||||
const auto Ms = argToIntArray(argv[8]);
|
||||
const auto Ns = argToIntArray(argv[9]);
|
||||
const auto Ks = argToIntArray(argv[10]);
|
||||
|
||||
auto StrideAs = argToIntArray(argv[11]);
|
||||
auto StrideBs = argToIntArray(argv[12]);
|
||||
auto StrideCs = argToIntArray(argv[13]);
|
||||
|
||||
const int DefaultStrideA = Ks[0];
|
||||
const int DefaultStrideB = Ns[0];
|
||||
const int DefaultStrideC = Ns[0];
|
||||
|
||||
for(size_t i = 0; i < Ms.size(); ++i)
|
||||
{
|
||||
StrideAs[i] = StrideAs[i] == -1 ? DefaultStrideA : StrideAs[i];
|
||||
StrideBs[i] = StrideBs[i] == -1 ? DefaultStrideB : StrideBs[i];
|
||||
StrideCs[i] = StrideCs[i] == -1 ? DefaultStrideC : StrideCs[i];
|
||||
}
|
||||
|
||||
std::vector<int> StrideDs(StrideCs);
|
||||
|
||||
int n_warmup = 10;
|
||||
int n_iter = 50;
|
||||
if(argc == 16)
|
||||
{
|
||||
n_warmup = std::stoi(argv[14]);
|
||||
n_iter = std::stoi(argv[15]);
|
||||
}
|
||||
|
||||
if(data_type == GemmDataType::BF16_INT8_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
ck::profiler::profile_grouped_gemm_multiply_tile_loop_impl<
|
||||
ck::bhalf_t,
|
||||
int8_t,
|
||||
ck::bhalf_t,
|
||||
ck::bhalf_t,
|
||||
float,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
Ms,
|
||||
Ns,
|
||||
Ks,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideDs,
|
||||
StrideCs,
|
||||
n_warmup,
|
||||
n_iter);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_grouped_gemm_tile_loop);
|
||||
Reference in New Issue
Block a user