mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Switch to universal gemm in grouped gemm tile loop (#1335)
* switch to universal gemm in grouped gemm tile loop * minor fixes * add reviewers comments --------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
@@ -19,6 +19,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp" // stare wywalic
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
|
||||
|
||||
@@ -42,16 +43,22 @@ namespace device {
|
||||
template <typename GridwiseGemm,
|
||||
typename GemmDesc,
|
||||
GemmSpecialization GemmSpec,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
index_t KPerBlock,
|
||||
typename OffsettedBlockToCTileMap,
|
||||
typename LocalBlock2ETileMap,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
typename CDEElementwiseOperation,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
@@ -67,6 +74,7 @@ __global__ void
|
||||
|
||||
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
|
||||
__shared__ uint8_t p_shared[shared_size];
|
||||
__shared__ uint8_t p_shared1[shared_size];
|
||||
|
||||
const auto gemm_desc_ptr =
|
||||
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
|
||||
@@ -81,27 +89,8 @@ __global__ void
|
||||
index_t gemm_tile_id_start = 0;
|
||||
index_t gemm_tile_id_end = 0;
|
||||
|
||||
using AGridDescMK =
|
||||
remove_cvref_t<decltype(GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(
|
||||
1, 1, 1))>;
|
||||
using BGridDescNK =
|
||||
remove_cvref_t<decltype(GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(
|
||||
1, 1, 1))>;
|
||||
using EGridDescMN =
|
||||
remove_cvref_t<decltype(GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
|
||||
1, 1, 1))>;
|
||||
using DsGridDescMN =
|
||||
remove_cvref_t<decltype(GridwiseGemm::template MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>(
|
||||
{}, {}, {}))>;
|
||||
|
||||
index_t M = 0, N = 0, K = 0;
|
||||
index_t StrideA, StrideB, StrideE;
|
||||
std::array<index_t, NumDTensor> StrideDs;
|
||||
|
||||
AGridDescMK a_grid_desc_mk;
|
||||
BGridDescNK b_grid_desc_nk;
|
||||
EGridDescMN e_grid_desc_mn;
|
||||
DsGridDescMN ds_grid_desc_mn;
|
||||
auto b2c_tile_map = OffsettedBlockToCTileMap(LocalBlock2ETileMap(1, 1), 1, 1);
|
||||
|
||||
do
|
||||
@@ -127,31 +116,13 @@ __global__ void
|
||||
}
|
||||
|
||||
b2c_tile_map =
|
||||
OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N), group_offset, tile_offset);
|
||||
OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N, 4), group_offset, tile_offset);
|
||||
grid_size_grp = b2c_tile_map.CalculateGridSize(M, N);
|
||||
|
||||
gemm_tile_id_start = group_offset;
|
||||
gemm_tile_id_end = group_offset + grid_size_grp;
|
||||
}
|
||||
|
||||
StrideA = gemm_desc_ptr[group_id].StrideA;
|
||||
StrideB = gemm_desc_ptr[group_id].StrideB;
|
||||
StrideDs = gemm_desc_ptr[group_id].StrideDs;
|
||||
StrideE = gemm_desc_ptr[group_id].StrideE;
|
||||
|
||||
a_grid_desc_mk =
|
||||
GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
|
||||
b_grid_desc_nk =
|
||||
GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(K, N, StrideB);
|
||||
e_grid_desc_mn =
|
||||
GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto j) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
|
||||
ds_grid_desc_mn(j) = GridwiseGemm::template MakeEGridDescriptor_M_N<DLayout, GemmSpec>(
|
||||
M, N, StrideDs[j]);
|
||||
});
|
||||
|
||||
using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
|
||||
DsGridPointer p_ds_grid;
|
||||
|
||||
@@ -160,42 +131,268 @@ __global__ void
|
||||
p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
|
||||
});
|
||||
|
||||
bool has_main_kblock_loop =
|
||||
GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_mk.GetLength(Number<1>{}));
|
||||
static constexpr index_t kbatch = 1;
|
||||
static constexpr index_t k_grain = kbatch * KPerBlock;
|
||||
index_t K_split = (K + k_grain - 1) / k_grain * KPerBlock;
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
// Update tile offset if we have moved within group
|
||||
b2c_tile_map.UpdateTileOffset(tile_offset);
|
||||
|
||||
if(has_main_kblock_loop)
|
||||
using Problem = typename GridwiseGemm::Problem;
|
||||
auto problem = Problem(gemm_desc_ptr[group_id].M,
|
||||
gemm_desc_ptr[group_id].N,
|
||||
gemm_desc_ptr[group_id].K,
|
||||
gemm_desc_ptr[group_id].StrideA,
|
||||
gemm_desc_ptr[group_id].StrideB,
|
||||
gemm_desc_ptr[group_id].StrideDs,
|
||||
gemm_desc_ptr[group_id].StrideE,
|
||||
kbatch);
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
GridwiseGemm::template Run<true>(gemm_desc_ptr[group_id].p_a_grid,
|
||||
gemm_desc_ptr[group_id].p_b_grid,
|
||||
p_ds_grid,
|
||||
gemm_desc_ptr[group_id].p_e_grid,
|
||||
static_cast<void*>(p_shared),
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_mk,
|
||||
b_grid_desc_nk,
|
||||
ds_grid_desc_mn,
|
||||
e_grid_desc_mn,
|
||||
b2c_tile_map);
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
TailNumber::Full>(
|
||||
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
|
||||
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
|
||||
p_ds_grid,
|
||||
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
|
||||
static_cast<void*>(p_shared),
|
||||
problem,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
b2c_tile_map);
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
|
||||
{
|
||||
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
TailNumber::One>(
|
||||
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
|
||||
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
|
||||
p_ds_grid,
|
||||
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
|
||||
static_cast<void*>(p_shared),
|
||||
problem,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
b2c_tile_map);
|
||||
}
|
||||
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
|
||||
{
|
||||
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
TailNumber::Full>(
|
||||
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
|
||||
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
|
||||
p_ds_grid,
|
||||
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
|
||||
static_cast<void*>(p_shared),
|
||||
problem,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
b2c_tile_map);
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
|
||||
{
|
||||
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
TailNumber::Two>(
|
||||
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
|
||||
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
|
||||
p_ds_grid,
|
||||
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
|
||||
static_cast<void*>(p_shared),
|
||||
problem,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
b2c_tile_map);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three)
|
||||
{
|
||||
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
TailNumber::Three>(
|
||||
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
|
||||
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
|
||||
p_ds_grid,
|
||||
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
|
||||
static_cast<void*>(p_shared),
|
||||
problem,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
b2c_tile_map);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four)
|
||||
{
|
||||
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
TailNumber::Four>(
|
||||
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
|
||||
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
|
||||
p_ds_grid,
|
||||
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
|
||||
static_cast<void*>(p_shared),
|
||||
problem,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
b2c_tile_map);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five)
|
||||
{
|
||||
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
TailNumber::Five>(
|
||||
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
|
||||
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
|
||||
p_ds_grid,
|
||||
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
|
||||
static_cast<void*>(p_shared),
|
||||
problem,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
b2c_tile_map);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
|
||||
{
|
||||
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
TailNumber::Six>(
|
||||
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
|
||||
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
|
||||
p_ds_grid,
|
||||
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
|
||||
static_cast<void*>(p_shared),
|
||||
problem,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
b2c_tile_map);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven)
|
||||
{
|
||||
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
TailNumber::Seven>(
|
||||
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
|
||||
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
|
||||
p_ds_grid,
|
||||
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
|
||||
static_cast<void*>(p_shared),
|
||||
problem,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
b2c_tile_map);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Tail number could be Odd or Even
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
GridwiseGemm::template Run_2Lds<OffsettedBlockToCTileMap,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
TailNumber::Odd>(
|
||||
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
|
||||
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
|
||||
p_ds_grid,
|
||||
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
|
||||
static_cast<void*>(p_shared),
|
||||
static_cast<void*>(p_shared1),
|
||||
problem,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
b2c_tile_map);
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseGemm::template Run_2Lds<OffsettedBlockToCTileMap,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
TailNumber::Even>(
|
||||
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
|
||||
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
|
||||
p_ds_grid,
|
||||
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
|
||||
static_cast<void*>(p_shared),
|
||||
static_cast<void*>(p_shared1),
|
||||
problem,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
b2c_tile_map);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseGemm::template Run<false>(gemm_desc_ptr[group_id].p_a_grid,
|
||||
gemm_desc_ptr[group_id].p_b_grid,
|
||||
p_ds_grid,
|
||||
gemm_desc_ptr[group_id].p_e_grid,
|
||||
static_cast<void*>(p_shared),
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_mk,
|
||||
b_grid_desc_nk,
|
||||
ds_grid_desc_mn,
|
||||
e_grid_desc_mn,
|
||||
b2c_tile_map);
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
GridwiseGemm::template Run<OffsettedBlockToCTileMap,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
TailNumber::Full>(
|
||||
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
|
||||
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
|
||||
p_ds_grid,
|
||||
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
|
||||
static_cast<void*>(p_shared),
|
||||
problem,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
b2c_tile_map);
|
||||
}
|
||||
}
|
||||
|
||||
tile_id += get_grid_size();
|
||||
@@ -253,10 +450,12 @@ template <typename ALayout,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
typename ComputeDataType = EDataType>
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
typename ComputeTypeA = EDataType,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
|
||||
struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
: public DeviceGroupedGemmTileLoop<ALayout,
|
||||
BLayout,
|
||||
@@ -273,10 +472,13 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
using DeviceOp = DeviceGroupedGemmMultipleDXdlCShuffleTileLoop;
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3<
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
@@ -284,8 +486,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
NumGemmKPrefetchStage,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
@@ -315,58 +516,15 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched,
|
||||
PipelineVer>;
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
template <typename UnderlyingBlockToCTileMap>
|
||||
struct OffsettedBlockToCTileMap
|
||||
{
|
||||
using underlying_type = UnderlyingBlockToCTileMap;
|
||||
|
||||
__host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map,
|
||||
index_t group_offset,
|
||||
index_t tile_offset)
|
||||
: block_to_ctile_map_{block_to_ctile_map},
|
||||
group_offset_{group_offset},
|
||||
tile_offset_{tile_offset}
|
||||
{
|
||||
}
|
||||
|
||||
template <typename TopIdx>
|
||||
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
|
||||
{
|
||||
return block_to_ctile_map_.CalculateBottomIndex(
|
||||
make_multi_index(idx_top[Number<0>{}] + tile_offset_ - group_offset_));
|
||||
}
|
||||
|
||||
template <typename CTileIdx, typename CTileDim>
|
||||
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
|
||||
const CTileDim& c_tile_dim) const
|
||||
{
|
||||
return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
|
||||
{
|
||||
return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
|
||||
{
|
||||
return block_to_ctile_map_.CalculateGridSize(M, N);
|
||||
}
|
||||
|
||||
__device__ void UpdateTileOffset(index_t offset) { tile_offset_ = offset; }
|
||||
UnderlyingBlockToCTileMap block_to_ctile_map_;
|
||||
index_t group_offset_;
|
||||
index_t tile_offset_;
|
||||
};
|
||||
|
||||
using KernelArguments = GroupedGemmTileLoopKernelArguments<NumDTensor>;
|
||||
using Block2ETileMap = BlockToCTileMap_N00_M0_N01Adapt<MPerBlock, NPerBlock>;
|
||||
using OffsetedLocalBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMap>;
|
||||
using KernelArguments = GroupedGemmTileLoopKernelArguments<NumDTensor>;
|
||||
using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
|
||||
using OffsettedLocalBlock2ETileMap = OffsettedBlockToCTileMap2<Block2ETileMap>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
@@ -403,7 +561,6 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
const void* p_dev_gemm_args_;
|
||||
int occupancy_num_blocks_;
|
||||
int gpu_cu_count_;
|
||||
|
||||
const std::vector<GemmDesc>& gemm_descs_;
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
@@ -496,16 +653,22 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
|
||||
KernelArguments,
|
||||
GemmSpec,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
OffsetedLocalBlock2ETileMap,
|
||||
KPerBlock,
|
||||
OffsettedLocalBlock2ETileMap,
|
||||
Block2ETileMap,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>;
|
||||
CDEElementwiseOperation,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer>;
|
||||
return LaunchKernel(kernel, arg, dev_gemm_args, stream_config);
|
||||
}
|
||||
|
||||
@@ -546,6 +709,8 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// run multiple kernels
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
@@ -572,63 +737,41 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
return false;
|
||||
}
|
||||
|
||||
using DsGridDescMN = remove_cvref_t<
|
||||
decltype(GridwiseGemm::template MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>(
|
||||
{}, {}, {}))>;
|
||||
|
||||
bool supported = true;
|
||||
|
||||
for(const auto& gdesc : arg.gemm_descs_)
|
||||
constexpr index_t k_batch = 1;
|
||||
for(index_t i = 0; i < arg.group_count_; ++i)
|
||||
{
|
||||
const auto M = gdesc.M_;
|
||||
const auto N = gdesc.N_;
|
||||
const auto K = gdesc.K_;
|
||||
std::array<const void*, NumDTensor> placeholder_p_ds_grid{};
|
||||
std::array<index_t, NumDTensor> stride_Ds;
|
||||
std::copy_n(arg.gemm_descs_[i].stride_Ds_.begin(), NumDTensor, stride_Ds.begin());
|
||||
using GridArg = typename GridwiseGemm::Argument;
|
||||
GridArg gridwise_arg(nullptr, // p_a_grid,
|
||||
nullptr, // p_b_grid,
|
||||
placeholder_p_ds_grid, // p_ds_grid,
|
||||
nullptr, // p_e_grid ,
|
||||
arg.gemm_descs_[i].M_,
|
||||
arg.gemm_descs_[i].N_,
|
||||
arg.gemm_descs_[i].K_,
|
||||
arg.gemm_descs_[i].stride_A_,
|
||||
arg.gemm_descs_[i].stride_B_,
|
||||
stride_Ds,
|
||||
arg.gemm_descs_[i].stride_C_,
|
||||
k_batch,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_);
|
||||
|
||||
const auto StrideA = gdesc.stride_A_;
|
||||
const auto StrideB = gdesc.stride_B_;
|
||||
const auto StrideE = gdesc.stride_C_;
|
||||
const auto& StrideDs = gdesc.stride_Ds_;
|
||||
|
||||
// If M dimension is unknown at launch time then validate just NK.
|
||||
// If N or K dim is zero (or unknown) then the vector loads responsibility lies on
|
||||
// the user.
|
||||
if(N * K == 0)
|
||||
continue;
|
||||
|
||||
const auto a_grid_desc_mk =
|
||||
GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
|
||||
const auto b_grid_desc_nk =
|
||||
GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(K, N, StrideB);
|
||||
const auto e_grid_desc_mn =
|
||||
GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
|
||||
|
||||
DsGridDescMN ds_grid_desc_mn;
|
||||
static_for<0, NumDTensor, 1>{}([&](auto j) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
|
||||
ds_grid_desc_mn(j) =
|
||||
GridwiseGemm::template MakeEGridDescriptor_M_N<DLayout, GemmSpec>(
|
||||
M, N, StrideDs[j]);
|
||||
});
|
||||
|
||||
const auto b2c_tile_map = Block2ETileMap(M, N);
|
||||
|
||||
if(!(GridwiseGemm::template CheckValidity(a_grid_desc_mk,
|
||||
b_grid_desc_nk,
|
||||
ds_grid_desc_mn,
|
||||
e_grid_desc_mn,
|
||||
b2c_tile_map) &&
|
||||
GridwiseGemm::template CheckTensorTransfersValidity<ALayout, BLayout, ELayout>(
|
||||
M, N, K)))
|
||||
if((arg.gemm_descs_[i].K_ % AK1 != 0 || arg.gemm_descs_[i].K_ % BK1 != 0) &&
|
||||
!(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding ||
|
||||
GemmSpec == GemmSpecialization::KPadding))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "The provided GEMM problem size (M,N,K) [" << M << "," << N << ","
|
||||
<< K << "] are not supported by current template parameters!"
|
||||
<< " In " << __FILE__ << ":" << __LINE__
|
||||
<< ", in function: " << __func__;
|
||||
}
|
||||
supported = false;
|
||||
return false;
|
||||
}
|
||||
|
||||
supported = supported && GridwiseGemm::CheckValidity(gridwise_arg);
|
||||
}
|
||||
|
||||
return supported;
|
||||
@@ -651,16 +794,22 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
|
||||
KernelArguments,
|
||||
GemmSpec,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
OffsetedLocalBlock2ETileMap,
|
||||
KPerBlock,
|
||||
OffsettedLocalBlock2ETileMap,
|
||||
Block2ETileMap,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>;
|
||||
CDEElementwiseOperation,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer>;
|
||||
int occupancy, num_cu;
|
||||
hip_check_error(
|
||||
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
|
||||
@@ -696,16 +845,22 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
|
||||
KernelArguments,
|
||||
GemmSpec,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
OffsetedLocalBlock2ETileMap,
|
||||
KPerBlock,
|
||||
OffsettedLocalBlock2ETileMap,
|
||||
Block2ETileMap,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>;
|
||||
CDEElementwiseOperation,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer>;
|
||||
int occupancy, num_cu;
|
||||
hip_check_error(
|
||||
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
|
||||
@@ -739,6 +894,17 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
{
|
||||
auto str = std::ostringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"},
|
||||
{BlockGemmPipelineVersion::v4, "v4"},
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGroupedGemmMultipleDXdlCShuffleTileLoop"
|
||||
<< "<"
|
||||
@@ -760,8 +926,10 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
<< CShuffleMXdlPerWavePerShuffle << ", "
|
||||
<< CShuffleNXdlPerWavePerShuffle << ", "
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< PipelineVer << ", "
|
||||
<< LoopSched
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer]
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
|
||||
Reference in New Issue
Block a user