Joye/revise wp pipeline (#3493)

* [CK_TILE] unify double and single lds implementation (#108)

Unify LDS buffer management API for single and double buffering modes

This change consolidates Local Data Share (LDS) buffer management by:

- Merging the single- and double-buffer LDS APIs into a unified interface
- Implementing ping-pong address calculation in the pipeline when double LDS is enabled
- Computing pong buffer addresses dynamically from base-address offsets, as sketched below
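
A minimal sketch of that ping-pong address calculation (illustrative HIP fragment inside a kernel; the buffer size and names are examples, not from this patch):

// One LDS allocation sized for both buffers; the pong address is derived
// from the ping base plus the single-buffer footprint.
constexpr int single_buffer_bytes = 4096;       // example per-buffer footprint
__shared__ char smem[2 * single_buffer_bytes];  // one 2x allocation
char* ping = smem;                              // buffer 0: base address
char* pong = smem + single_buffer_bytes;        // buffer 1: base + offset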

---------

Co-authored-by: joye <joye@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* update wp_pipeline

* fix a c++17 issue

* update for ci errors

* fix ci issues

* include a header to fix ci errors

* fix some rebase issues

* update with rebase

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
joyeamd authored on 2026-01-06 05:49:26 +08:00, committed by GitHub
parent 1224bc0a82, commit 2b563ad048
13 changed files with 766 additions and 929 deletions

View File

@@ -25,6 +25,7 @@
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
#include "ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp"
#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"

View File

@@ -0,0 +1,212 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp"
namespace ck_tile {
// A is a block window on shared memory
// B is a block window on registers
// C is a block-distributed tensor
template <typename Problem_, typename BlockPolicy_>
struct BlockWeightPreshuffleASmemBRegCReg
{
using Problem = remove_cvref_t<Problem_>;
using BlockPolicy = remove_cvref_t<BlockPolicy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
static constexpr auto I2 = number<2>();
static constexpr auto idxM = I0;
static constexpr auto idxN = I1;
static constexpr auto idxK = I2;
using BlockTile = remove_cvref_t<typename BlockGemmShape::BlockTile>;
using BlockWarps = remove_cvref_t<typename BlockGemmShape::BlockWarps>;
using WarpTile = remove_cvref_t<typename BlockGemmShape::WarpTile>;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
static constexpr index_t MWarp = config.template at<1>();
static constexpr index_t NWarp = config.template at<2>();
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
static constexpr index_t MPerBlockPerIter = MWarp * WarpGemm::kM;
static constexpr index_t KPerBlockPerIter = WarpGemm::kK;
static constexpr index_t DsReadPreload = 2; // default 2: preload 2 ds reads
static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload)
? DsReadPreload
: MIterPerWarp * KIterPerWarp;
using AWarpTensor = typename WarpGemm::AWarpTensor;
statically_indexed_array<AWarpTensor, m_preload> preloaded_a_warp_tensor;
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
{
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<1, MWarp>, sequence<1>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
return a_block_dstr_encode;
}
template <typename SmemBlockWindow>
CK_TILE_DEVICE auto MakeALoadWindows(SmemBlockWindow& a_block_window) const
{
constexpr auto a_load_dstr = make_static_tile_distribution(MakeABlockDistributionEncode());
// create MIterPerWarp × KIterPerWarp tile windows
return generate_tuple(
[&](auto kIter) {
return generate_tuple(
[&](auto mIter) {
return make_tile_window(
get_slice_tile(
a_block_window,
sequence<mIter * MPerBlockPerIter, kIter * KPerBlockPerIter>{},
sequence<(mIter + 1) * MPerBlockPerIter,
(kIter + 1) * KPerBlockPerIter>{}),
a_load_dstr);
},
number<MIterPerWarp>{});
},
number<KIterPerWarp>{});
}
template <typename ALoadWindows>
CK_TILE_DEVICE void LocalPrefetch(const ALoadWindows& a_load_windows)
{
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
load_tile(preloaded_a_warp_tensor(loadIter),
a_load_windows[number<kIter>{}][number<mIter>{}]);
});
}
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
// C += A * B
template <typename CBlockTensor,
typename ALoadWindows,
typename BFlatBlockTensor,
typename BFlatDistribution>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ALoadWindows& a_load_windows,
BFlatBlockTensor& b_block_tensor,
const BFlatDistribution&)
{
constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
using CWarpDstr = typename WarpGemm::CWarpDstr;
using CWarpTensor = typename WarpGemm::CWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor;
constexpr auto b_block_y_lengths =
to_sequence(BFlatDistribution{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto b_block_y_index_zeros =
uniform_sequence_gen_t<BFlatDistribution::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
BWarpTensor b_warp_tensor;
CWarpTensor c_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{},
typename sequence_split<decltype(b_block_y_index_zeros),
2>::right_type{}),
merge_sequences(
sequence<1, 1>{},
typename sequence_split<decltype(b_block_y_lengths), 2>::right_type{}));
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WarpGemm{}(
c_warp_tensor, preloaded_a_warp_tensor(number<AwarpIter>{}), b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
__builtin_amdgcn_sched_barrier(0x7F6);
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
load_tile(preloaded_a_warp_tensor(number<AwarpIter>{}),
a_load_windows[number<AkIter>{}][number<AmIter>{}]);
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
});
}
};
} // namespace ck_tile
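
For orientation, a condensed view of how the pipeline later in this commit drives this functor (window and tensor setup elided):

auto gemm = BlockWeightPreshuffleASmemBRegCReg<Problem, BlockPolicy>{};
auto c_block_tile = gemm.MakeCBlockTile();
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);  // zero the accumulator
auto a_load_windows = gemm.MakeALoadWindows(a_lds_window);     // MIter x KIter LDS slice windows
gemm.LocalPrefetch(a_load_windows);                            // preload first m_preload A warp tiles
gemm(c_block_tile, a_load_windows, b_block_tensor, b_flat_distribution); // C += A * B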

View File

@@ -303,24 +303,15 @@ struct GroupedGemmKernel
CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
__shared__ char smem_ptr[GetSmemSize()];
// TO DO:
// Can we simplify this branching logic?
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
RunGemmWithPipelineSelection2LDS(a_ptr,
b_ptr,
c_ptr,
kargs.ds_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
splitk_batch_offset,
i_m,
i_n);
RunGemmWithPipelineSelection2LDS(
a_ptr, b_ptr, c_ptr, kargs.ds_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
}
else // SingleSmemBuffer
{
@@ -331,7 +322,7 @@ struct GroupedGemmKernel
b_ptr,
kargs.ds_ptr,
c_ptr,
smem_ptr_0,
smem_ptr,
kargs,
splitk_batch_offset,
i_m,
@@ -343,7 +334,7 @@ struct GroupedGemmKernel
{b_ptr},
kargs.ds_ptr,
c_ptr,
smem_ptr_0,
smem_ptr,
kargs,
splitk_batch_offset,
i_m,
@@ -425,9 +416,7 @@ struct GroupedGemmKernel
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param c_ptr output C pointer
* @param ds_ptr input Ds pointer
* @param smem_ptr_0 The starting pointer of 1st shared memory block.
* @param smem_ptr_1 The starting pointer of 2nd shared memory block.
* @param smem_ptr The starting pointer of the shared memory block.
* @param kargs GEMM kernel arguments
* @param splitk_batch_offset Utility structure used to calculate k batch.
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
@@ -439,8 +428,7 @@ struct GroupedGemmKernel
const BDataType* b_ptr,
CDataType* c_ptr,
const std::array<const void*, NumDTensor_>& ds_ptr,
void* __restrict__ smem_ptr_0,
void* __restrict__ smem_ptr_1,
void* __restrict__ smem_ptr,
const UniversalGemmKernelArgs<1, 1, NumDTensor_>& kargs,
const typename Base::SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
@@ -460,8 +448,8 @@ struct GroupedGemmKernel
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
// Run GEMM cooperatively by whole workgroup.
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
const auto& c_block_tile =
GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
// Run Epilogue Pipeline
if(kargs.k_batch == 1)
@@ -469,7 +457,7 @@ struct GroupedGemmKernel
auto c_block_window = Base::template MakeCBlockWindows<memory_operation_enum::set>(
c_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr);
}
else
{
@@ -477,7 +465,7 @@ struct GroupedGemmKernel
Base::template MakeCBlockWindows<memory_operation_enum::atomic_add>(
c_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr);
}
}

View File

@@ -978,7 +978,7 @@ struct UniversalGemmKernel
* @param bs_ptr input Bs pointer
* @param ds_ptr input Ds pointer
* @param e_ptr output E pointer
* @param smem_ptr_0 The start memory pointer of the shared memory block.
* @param smem_ptr The starting pointer of the shared memory block.
* @param kargs GEMM kernel arguments
* @param splitk_batch_offset Utility structure used to calculate k batch.
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
@@ -990,7 +990,7 @@ struct UniversalGemmKernel
const std::array<const BDataType*, NumBTensor>& bs_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
void* smem_ptr_0,
void* smem_ptr,
const KernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
@@ -1008,7 +1008,7 @@ struct UniversalGemmKernel
// Run GEMM cooperatively by whole workgroup.
const auto& c_block_tile = GemmPipeline{}.template operator()(
as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0);
as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr);
const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);
// Run Epilogue Pipeline
@@ -1016,77 +1016,63 @@ struct UniversalGemmKernel
{
auto c_block_window = MakeCBlockWindows<memory_operation_enum::set>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr);
}
else
{
auto c_block_window = MakeCBlockWindows<memory_operation_enum::atomic_add>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr);
}
}
/**
* @brief Runs single GEMM problem cooperatively by whole workgroup.
*
* @note Runs the GEMM with two shared memory buffers using the ping-pong buffer mechanism.
*
* @param as_ptr input As pointer
* @param bs_ptr input Bs pointer
* @param ds_ptr input Ds pointer
* @param e_ptr output E pointer
* @param smem_ptr_0 The starting pointer of 1st shared memory block.
* @param smem_ptr_1 The starting pointer of 2nd shared memory block.
* @param kargs GEMM kernel arguments
* @param splitk_batch_offset Utility structure used to calculate k batch.
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
*/
CK_TILE_DEVICE static void RunGemm2LDS(const std::array<const ADataType*, NumATensor>& as_ptr,
const std::array<const BDataType*, NumBTensor>& bs_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
void* __restrict__ smem_ptr_0,
void* __restrict__ smem_ptr_1,
const KernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
CK_TILE_DEVICE static auto
GetTileCoordinates(const KernelArgs& kargs) -> tuple<index_t, index_t>
{
// Create block windows using specialized methods
const auto& as_block_window =
MakeABlockWindows(as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
const auto& bs_block_window =
MakeBBlockWindows(bs_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
index_t iM, iN;
const index_t num_loop =
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
// Regular launch: use 1D block indexing
const auto blockId = amd_wave_read_first_lane(blockIdx.x);
const auto [tile_m, tile_n] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
iM = tile_m;
iN = tile_n;
// Run GEMM cooperatively by whole workgroup.
const auto& c_block_tile = GemmPipeline{}.template operator()(as_block_window,
AElementWise{},
bs_block_window,
BElementWise{},
num_loop,
smem_ptr_0,
smem_ptr_1);
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
// Run Epilogue Pipeline
if(kargs.k_batch == 1)
return make_tuple(i_m, i_n);
}
// Helper functions
CK_TILE_DEVICE static auto GetBlockId() -> index_t
{
// For 1D regular launch
return amd_wave_read_first_lane(get_block_id());
}
CK_TILE_DEVICE static auto GetGridSize() -> index_t
{
// For 1D regular launch
return amd_wave_read_first_lane(get_grid_size());
}
// Helper to get total number of tiles, handling both dim3 and index_t return types
template <typename... Args>
CK_TILE_HOST_DEVICE static auto GetNumTiles(Args&&... args) -> index_t
{
auto grid_size = TilePartitioner::GridSize(std::forward<Args>(args)...);
using GridSizeType = decltype(grid_size);
if constexpr(std::is_same_v<GridSizeType, dim3>)
{
auto c_block_window = MakeCBlockWindows<memory_operation_enum::set>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
// GridSize returns dim3: compute total tiles as x * y * z
return amd_wave_read_first_lane(grid_size.x * grid_size.y * grid_size.z);
}
else
{
auto c_block_window = MakeCBlockWindows<memory_operation_enum::atomic_add>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
// GridSize returns scalar (index_t): use directly
return amd_wave_read_first_lane(grid_size);
}
}
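// Standalone analogue of the dim3-vs-scalar handling in GetNumTiles above
// (simplified types for illustration; not the CK_TILE implementation):
#include <cstdint>
#include <type_traits>

struct dim3_like { unsigned x, y, z; }; // stand-in for HIP's dim3

template <typename GridSize>
constexpr std::int64_t num_tiles(GridSize grid)
{
    if constexpr(std::is_same_v<GridSize, dim3_like>)
        return std::int64_t{grid.x} * grid.y * grid.z; // dim3: total = x * y * z
    else
        return grid; // already a scalar tile count
}

static_assert(num_tiles(dim3_like{4, 2, 1}) == 8);
static_assert(num_tiles(16) == 16);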
@@ -1123,36 +1109,12 @@ struct UniversalGemmKernel
}
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
__shared__ char smem_ptr[GetSmemSize()];
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
RunGemm2LDS(as_ptr,
bs_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
else
{
constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1);
RunGemm<scheduler_type>(as_ptr,
bs_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr_0,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
constexpr auto scheduler_type =
GemmPipeline::DoubleSmemBuffer || (GemmPipeline::NumWaveGroups == 1);
RunGemm<scheduler_type>(
as_ptr, bs_ptr, kargs.ds_ptr, e_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
}
// Persistent kernel entry point
@@ -1199,34 +1161,19 @@ struct UniversalGemmKernel
}
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
__shared__ char smem_ptr[GetSmemSize()];
// Run the GEMM
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
RunGemm2LDS(as_ptr,
bs_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
else
{
RunGemm(as_ptr,
bs_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr_0,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
RunGemm(as_ptr,
bs_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr,
kargs,
splitk_batch_offset,
i_m,
i_n);
// Advance to the next work item
block_id += grid_size;
if(block_id >= num_work)

View File

@@ -64,12 +64,17 @@ struct GemmPipelineAgBgCrImplBase
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
template <typename DstBlockTile, typename SrcTileWindow, typename DramTileWindowStep>
template <typename SrcDataType = void,
typename DstDataType = void,
index_t UnaryOpSize = 8,
typename DstBlockTile,
typename SrcTileWindow,
typename DramTileWindowStep>
CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile,
SrcTileWindow& dram_tile_window,
const DramTileWindowStep& dram_tile_window_step) const
{
load_tile(dst_block_tile, dram_tile_window);
load_int4_tile<SrcDataType, DstDataType, UnaryOpSize>(dst_block_tile, dram_tile_window);
move_tile_window(dram_tile_window, dram_tile_window_step);
}
@@ -217,22 +222,17 @@ struct GemmPipelineAgBgCrImplBase
return std::move(a_copy_dram_window);
}
template <typename ADramBlockWindowTmp, typename ALdsTensorView, typename ALdsLoadTileDistr>
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const ALdsTensorView& a_lds_block_view,
const ALdsLoadTileDistr&,
const array<index_t, 2>& offset = {0, 0}) const
template <typename ALdsTensorView, typename ALdsLoadTileDistr>
CK_TILE_DEVICE constexpr auto MakeALdsWindows(const ALdsTensorView& a_lds_block_view,
const ALdsLoadTileDistr&) const
{
// A DRAM tile window for load
auto a_copy_dram_window = CopyADramWindow(a_dram_block_window_tmp, offset);
// A LDS tile window for store
auto a_lds_shape = []() {
if constexpr(is_a_load_tr)
return make_tuple(number<KPerBlock>{}, number<MPerBlock>{});
else
return make_tuple(number<MPerBlock>{}, number<KPerBlock>{});
}();
auto a_copy_lds_window = make_tile_window(a_lds_block_view, a_lds_shape, {0, 0});
auto a_lds_load_tile_distr = []() {
@@ -244,32 +244,73 @@ struct GemmPipelineAgBgCrImplBase
else
return ALdsLoadTileDistr{};
}();
auto a_lds_gemm_window =
make_tile_window(a_lds_block_view, a_lds_shape, {0, 0}, a_lds_load_tile_distr);
return make_tuple(std::move(a_copy_lds_window), std::move(a_lds_gemm_window));
}
template <
typename ADramBlockWindowTmp,
typename ALdsTensorView,
typename ALdsLoadTileDistr,
typename std::enable_if_t<!is_detected<is_tuple, ALdsTensorView>::value, bool>* = nullptr>
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const ALdsTensorView& a_lds_block_view,
const ALdsLoadTileDistr& a_lds_load_tile_distr,
const array<index_t, 2>& offset = {0, 0}) const
{
// A DRAM tile window for load
auto a_copy_dram_window = CopyADramWindow(a_dram_block_window_tmp, offset);
// Create LDS windows
auto [a_copy_lds_window, a_lds_gemm_window] =
MakeALdsWindows(a_lds_block_view, a_lds_load_tile_distr);
return make_tuple(std::move(a_copy_dram_window),
std::move(a_copy_lds_window),
std::move(a_lds_gemm_window));
}
template <typename BDramBlockWindowTmp, typename BLdsTensorView, typename BLdsLoadTileDistr>
CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BLdsTensorView& b_lds_block_view,
const BLdsLoadTileDistr&,
// Unified GetAWindows that supports 1, 2, or 3 LDS buffers
template <typename ADramBlockWindowTmp,
typename ALdsTensorViewsTuple,
typename ALdsLoadTileDistr,
typename std::enable_if_t<is_detected<is_tuple, ALdsTensorViewsTuple>::value, bool>* =
nullptr>
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const ALdsTensorViewsTuple& a_lds_block_views_tuple,
const ALdsLoadTileDistr& a_lds_load_tile_distr,
const array<index_t, 2>& offset = {0, 0}) const
{
// A DRAM tile window for load
auto b_copy_dram_window = CopyBDramWindow(b_dram_block_window_tmp, offset);
auto a_copy_dram_window = CopyADramWindow(a_dram_block_window_tmp, offset);
// TODO: Do we really need those two tile windows???
// They're exactly same...
// B LDS tile window for store
// Create LDS windows for each buffer
constexpr index_t num_buffers = ALdsTensorViewsTuple::size();
auto a_lds_windows = generate_tuple(
[&](auto i) {
return MakeALdsWindows(a_lds_block_views_tuple[i], a_lds_load_tile_distr);
},
number<num_buffers>{});
// Return: (dram_window, lds_windows_tuple)
// lds_windows_tuple[i] = (copy_lds_window_i, lds_gemm_window_i)
return make_tuple(std::move(a_copy_dram_window), std::move(a_lds_windows));
}
template <typename BLdsTensorView, typename BLdsLoadTileDistr>
CK_TILE_DEVICE constexpr auto MakeBLdsWindows(const BLdsTensorView& b_lds_block_view,
const BLdsLoadTileDistr&) const
{
auto b_lds_shape = []() {
if constexpr(is_b_load_tr)
return make_tuple(number<KPerBlock>{}, number<NPerBlock>{});
else
return make_tuple(number<NPerBlock>{}, number<KPerBlock>{});
}();
auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0});
using BLdsDataType =
@@ -286,13 +327,61 @@ struct GemmPipelineAgBgCrImplBase
else
return BLdsLoadTileDistr{};
}();
auto b_lds_gemm_window =
make_tile_window(b_lds_block_view, b_lds_shape, {0, 0}, b_lds_load_tile_distr);
return make_tuple(std::move(b_copy_lds_window), std::move(b_lds_gemm_window));
}
template <
typename BDramBlockWindowTmp,
typename BLdsTensorView,
typename BLdsLoadTileDistr,
typename std::enable_if_t<!is_detected<is_tuple, BLdsTensorView>::value, bool>* = nullptr>
CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BLdsTensorView& b_lds_block_view,
const BLdsLoadTileDistr& b_lds_load_tile_distr,
const array<index_t, 2>& offset = {0, 0}) const
{
// B DRAM tile window for load
auto b_copy_dram_window = CopyBDramWindow(b_dram_block_window_tmp, offset);
// Create LDS windows
auto [b_copy_lds_window, b_lds_gemm_window] =
MakeBLdsWindows(b_lds_block_view, b_lds_load_tile_distr);
return make_tuple(std::move(b_copy_dram_window),
std::move(b_copy_lds_window),
std::move(b_lds_gemm_window));
}
// Unified GetBWindows that supports 1, 2, or 3 LDS buffers
template <typename BDramBlockWindowTmp,
typename BLdsTensorViewsTuple,
typename BLdsLoadTileDistr,
typename std::enable_if_t<is_detected<is_tuple, BLdsTensorViewsTuple>::value, bool>* =
nullptr>
CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BLdsTensorViewsTuple& b_lds_block_views_tuple,
const BLdsLoadTileDistr& b_lds_load_tile_distr,
const array<index_t, 2>& offset = {0, 0}) const
{
// B DRAM tile window for load
auto b_copy_dram_window = CopyBDramWindow(b_dram_block_window_tmp, offset);
// Create LDS windows for each buffer
constexpr index_t num_buffers = BLdsTensorViewsTuple::size();
auto b_lds_windows = generate_tuple(
[&](auto i) {
return MakeBLdsWindows(b_lds_block_views_tuple[i], b_lds_load_tile_distr);
},
number<num_buffers>{});
// Return: (dram_window, lds_windows_tuple)
// lds_windows_tuple[i] = (copy_lds_window_i, lds_gemm_window_i)
return make_tuple(std::move(b_copy_dram_window), std::move(b_lds_windows));
}
};
} // namespace ck_tile
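
For reference, the tuple overloads are consumed like this by the preshuffle pipeline later in this commit (setup elided):

auto&& windows_result =
    GetAWindows(a_dram_block_window_tmp, a_lds_block_views_tuple, a_lds_load_tile_distr);
auto&& a_copy_dram_window = windows_result.template get<0>();
auto&& a_lds_windows      = windows_result.template get<1>();
// a_lds_windows[i].template at<0>(): copy (store) window for LDS buffer i
// a_lds_windows[i].template at<1>(): gemm (load) window for LDS buffer i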

View File

@@ -158,6 +158,8 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static_assert(DoubleSmemBuffer == true, "pipeline requires double smem buffer");
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
@@ -172,7 +174,8 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
constexpr index_t smem_size = Policy::template GetSmemSize<Problem>();
return 2 * smem_size;
}
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
@@ -240,8 +243,7 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
void* __restrict__ p_smem) const
{
// TODO support multi-ABD
static_assert(1 == std::tuple_size_v<AsDramBlockWindowTmp>);
@@ -303,8 +305,10 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
number<BsLayout::size()>{});
// this pipeline has a pair of LDS buffers per logical tile
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1);
constexpr index_t smem_size = Policy::template GetSmemSize<Problem>();
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem);
auto&& [a_lds_block1, b_lds_block1] =
Base::GetABLdsTensorViews(static_cast<char*>(p_smem) + smem_size);
// set up LDS tile shapes
constexpr auto a_lds_shape = []() {
@@ -534,21 +538,18 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem_0,
void* p_smem_1) const
void* p_smem) const
{
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem_0,
p_smem_1);
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
@@ -559,8 +560,7 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const index_t num_loop,
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
void* __restrict__ p_smem) const
{
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
@@ -572,8 +572,7 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
num_loop,
p_smem_0,
p_smem_1);
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);

View File

@@ -172,6 +172,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
static_assert(DoubleSmemBuffer == true, "pipeline requires double smem buffer");
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
// clang-format off
@@ -191,7 +193,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
constexpr index_t smem_size = Policy::template GetSmemSize<Problem>();
return 2 * smem_size;
}
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
@@ -281,8 +284,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
void* __restrict__ p_smem) const
{
using ADramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
@@ -324,8 +326,10 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
// global read 0
////////////// LDS desc, window & register /////////////////
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1);
constexpr index_t smem_size = Policy::template GetSmemSize<Problem>();
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem);
auto&& [a_lds_block1, b_lds_block1] =
Base::GetABLdsTensorViews(static_cast<char*>(p_smem) + smem_size);
constexpr auto a_lds_shape = []() {
if constexpr(is_a_load_tr_v())
@@ -680,8 +684,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem_0,
void* p_smem_1) const
void* p_smem) const
{
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
@@ -693,8 +696,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem_0,
p_smem_1);
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
@@ -708,8 +710,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const index_t num_loop,
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
void* __restrict__ p_smem) const
{
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
@@ -721,8 +722,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
b_dram_block_window_tmp,
[](auto& e, const BDataType& b) { e = b; },
num_loop,
p_smem_0,
p_smem_1);
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
@@ -738,8 +738,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
index_t num_loop,
bool has_hot_loop,
TailNumber tail_number,
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
void* __restrict__ p_smem) const
{
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
constexpr bool hot_loop = hot_loop_.value;
@@ -751,8 +750,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
b_dram_block_window_tmp,
PassThrough,
num_loop,
p_smem_0,
p_smem_1);
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
@@ -769,16 +767,14 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem_0,
void* p_smem_1) const
void* p_smem) const
{
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
a_element_func,
ck_tile::make_tuple(b_dram_block_window_tmp),
b_element_func,
num_loop,
p_smem_0,
p_smem_1);
p_smem);
}
template <typename ADramBlockWindowTmp,
@@ -789,14 +785,12 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const index_t num_loop,
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
void* __restrict__ p_smem) const
{
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
ck_tile::make_tuple(b_dram_block_window_tmp),
num_loop,
p_smem_0,
p_smem_1);
p_smem);
}
template <typename ADramBlockWindowTmp,
@@ -809,16 +803,14 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
index_t num_loop,
bool has_hot_loop,
TailNumber tail_number,
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
void* __restrict__ p_smem) const
{
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
ck_tile::make_tuple(b_dram_block_window_tmp),
num_loop,
has_hot_loop,
tail_number,
p_smem_0,
p_smem_1);
p_smem);
}
};
} // namespace ck_tile

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
namespace ck_tile {
@@ -201,6 +202,12 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
{
using TileShape = typename Problem::BlockGemmShape;
constexpr index_t kNPerBlock = TileShape::kN;
constexpr index_t kKPerBlock = TileShape::kK;
constexpr index_t NIterPerWarp =
kNPerBlock / TileShape::BlockWarps::at(I1) / TileShape::WarpTile::at(I1);
constexpr index_t KIterPerWarp = kKPerBlock / TileShape::WarpTile::at(I2);
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
@@ -213,13 +220,13 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
#endif
constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t KRepeat = 1;
constexpr index_t KRepeat = KIterPerWarp;
static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
constexpr index_t NBPerLoad = 1;
constexpr index_t NThdPerWave = 1;
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
constexpr index_t NRepeat = 1;
constexpr index_t NRepeat = NIterPerWarp;
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
return make_static_tile_distribution(
@@ -232,8 +239,8 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
tuple<sequence<0, 1, 2>, sequence<0, 1, 2>>, // which direction
tuple<sequence<0, 1, 1>, sequence<1, 2, 2>>, // which index
// <repeat, vec_load>
sequence<1, 1, 2, 2>,
sequence<0, 3, 0, 3>>{});
sequence<1, 2, 1, 2>,
sequence<0, 0, 3, 3>>{});
}
template <typename Problem>
@@ -307,7 +314,7 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BlockWeightPreshuffleASmemBSmemCRegV1<Problem, BlockWeightPreshufflePolicy>{};
return BlockWeightPreshuffleASmemBRegCReg<Problem, BlockWeightPreshufflePolicy>{};
}
/**
* @brief Get the vector store size for C tensor.
@@ -325,7 +332,7 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{
using BlockGemm = remove_cvref_t<decltype(GetBlockWeightPreshuffle<Problem>())>;
using WG_ = typename BlockGemm::WG;
using WG_ = typename BlockGemm::WarpGemm;
constexpr bool TransposeC = Problem::TransposeC;
using CLayout = typename Problem::CLayout;

View File

@@ -32,19 +32,34 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto
TailHandler(const RunFunction& run_func, bool, TailNumber tail_number)
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
{
if(tail_number == TailNumber::Odd)
if(has_hot_loop)
{
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Odd>{});
if(tail_number == TailNumber::Odd)
{
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Odd>{});
}
else // Even tail number
{
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Even>{});
}
}
else // Even tail number
else
{
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Even>{});
if(tail_number == TailNumber::Odd)
{
return run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::Odd>{});
}
else // Even tail number
{
return run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::Even>{});
}
}
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
}
};
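// Standalone analogue of the TailHandler dispatch above (simplified; not the
// CK_TILE implementation): the fix forwards has_hot_loop instead of
// hard-coding true, enumerating every (has_hot_loop, tail_number) pair to
// lift both runtime flags into compile-time template arguments.
#include <type_traits>

enum class TailNumberKind { Odd, Even };

template <typename RunFunction>
auto tail_handler(const RunFunction& run, bool has_hot_loop, TailNumberKind tail)
{
    using Odd  = std::integral_constant<TailNumberKind, TailNumberKind::Odd>;
    using Even = std::integral_constant<TailNumberKind, TailNumberKind::Even>;
    if(has_hot_loop)
    {
        if(tail == TailNumberKind::Odd)
            return run(std::bool_constant<true>{}, Odd{});
        return run(std::bool_constant<true>{}, Even{});
    }
    if(tail == TailNumberKind::Odd)
        return run(std::bool_constant<false>{}, Odd{});
    return run(std::bool_constant<false>{}, Even{});
}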
@@ -52,7 +67,8 @@ template <typename Problem, typename PipelinePolicy = UniversalWeightPreshuffleP
struct WeightPreshufflePipelineAGmemBGmemCRegV2
: public BaseWeightPreshufflePipelineAGmemBGmemCRegV2<Problem>
{
using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV2<Problem>;
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, PipelinePolicy>;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
@@ -75,11 +91,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
using BlockWeightPreshuffle =
remove_cvref_t<decltype(PipelinePolicy::template GetBlockWeightPreshuffle<Problem>())>;
static constexpr auto config =
BlockWeightPreshuffle::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
static constexpr index_t DsWritePreIssue = 3; // default 3, ds write at MIter - 2
static constexpr index_t DsReadPreload = 2; // default 2: preload 2 ds reads
@@ -95,6 +106,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t kflatKPerBlock = BlockGemmShape::flatKPerBlock;
static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp;
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
@@ -131,12 +144,16 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
using BlockWarps = remove_cvref_t<typename BlockGemmShape::BlockWarps>;
using WarpTile = remove_cvref_t<typename BlockGemmShape::WarpTile>;
static constexpr index_t MWarp = config.template at<1>();
static constexpr index_t NWarp = config.template at<2>();
static constexpr index_t MWarp = BlockWarps::at(I0);
static constexpr index_t NWarp = BlockWarps::at(I1);
static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
static constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
static constexpr index_t WarpTileM = WarpTile::at(I0);
static constexpr index_t WarpTileN = WarpTile::at(I1);
static constexpr index_t WarpTileK = WarpTile::at(I2);
static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpTileM);
static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpTileN);
static constexpr index_t KIterPerWarp = kKPerBlock / WarpTileK;
static constexpr index_t KFlatPerBlockPerIter = flatKPerWarp;
static constexpr index_t NFlatPerBlockPerIter = flatNPerWarp;
@@ -154,20 +171,20 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
#else
static constexpr index_t mfma_per_wg = 1;
#endif
static constexpr index_t dsread_per_wg =
max(index_t(WG::kM * WG::kK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize), 1);
static constexpr index_t dsread_per_wg = max(
index_t(WarpTileM * WarpTileK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize), 1);
#if defined(__HIP_DEVICE_COMPILE__)
static_assert((WG::kM * WG::kK * sizeof(ADataType) * MIterPerWarp / WaveSize) %
static_assert((WarpTileM * WarpTileK * sizeof(ADataType) * MIterPerWarp / WaveSize) %
Problem::VectorLoadSize ==
0);
#endif
static constexpr index_t dsread_num_perK =
WG::kM * WG::kK * sizeof(ADataType) * MIterPerWarp / WaveSize / Problem::VectorLoadSize;
static constexpr index_t dsread_num_perK = WarpTileM * WarpTileK * sizeof(ADataType) *
MIterPerWarp / WaveSize / Problem::VectorLoadSize;
static constexpr index_t dswrite_num_perK = dsread_num_perK / (MWarp * NWarp);
static constexpr index_t dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp;
static constexpr index_t Aload_num_perK = dswrite_num_perK;
static constexpr index_t Aload_rep = dswrite_rep;
static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / K1 / WaveSize;
static constexpr index_t Bload_num_perK = kNPerBlock * WarpTileK / NWarp / K1 / WaveSize;
static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
@@ -187,7 +204,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
// clang-format off
return concat('_', "pipeline_AGmemBGmemCRegV2",
concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize),
concat('x', WG::kM, WG::kN, WG::kK),
concat('x', WarpTileM, WarpTileN, WarpTileK),
concat('x', GetVectorSizeA(), GetVectorSizeB()),
concat('x', kPadM, kPadN, kPadK));
@@ -195,14 +212,16 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
}
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr index_t Preshuffle = Problem::Preshuffle;
using Base::UsePersistentKernel;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return PipelinePolicy::template GetSmemSize<Problem>();
constexpr index_t smem_size = PipelinePolicy::template GetSmemSize<Problem>();
return DoubleSmemBuffer ? 2 * smem_size : smem_size;
}
// dsread_perM: how many LDS reads to issue in this M-iter
@@ -515,515 +534,184 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
// __builtin_amdgcn_sched_barrier(0);
}
template <TailNumber TailNum,
typename ADramBlockWindowTmp,
typename BFlatBlockWindowTmp,
typename AElementFunction,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BFlatBlockWindowTmp>::value,
bool>* = nullptr,
index_t UnaryOpSize_ = 8>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
index_t num_loop,
void* p_smem_ping,
void* p_smem_pong) const
struct PipelineImpl : public PipelineImplBase
{
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>>,
"wrong!");
using Base = PipelineImplBase;
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
"wrong!");
static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
const index_t iMWarp = get_warp_id() / NWarp;
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
__builtin_amdgcn_sched_barrier(0);
// A tile in LDS
ADataType* p_a_lds_ping = static_cast<ADataType*>(p_smem_ping);
ADataType* p_a_lds_pong = static_cast<ADataType*>(p_smem_pong);
constexpr auto a_lds_block_desc =
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block_ping =
make_tensor_view<address_space_enum::lds>(p_a_lds_ping, a_lds_block_desc);
auto a_lds_block_pong =
make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
PipelinePolicy::template MakeADramTileDistribution<Problem>());
auto a_copy_lds_window_ping =
make_tile_window(a_lds_block_ping,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
PipelinePolicy::template MakeADramTileDistribution<Problem>());
auto a_copy_lds_window_pong =
make_tile_window(a_lds_block_pong,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
PipelinePolicy::template MakeADramTileDistribution<Problem>());
// ping-pong window for A LDS
auto a_warp_window_ping_tmp =
make_tile_window(a_lds_block_ping,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
auto a_warp_window_pong_tmp =
make_tile_window(a_lds_block_pong,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_ping_tmp), KIterPerWarp>,
MIterPerWarp>
a_warp_windows_ping;
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_pong_tmp), KIterPerWarp>,
MIterPerWarp>
a_warp_windows_pong;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
move_tile_window(a_warp_windows_ping(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
move_tile_window(a_warp_windows_pong(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
// Block GEMM
auto block_weight_preshuffle = BlockWeightPreshuffle();
// Acc register tile
auto c_block_tile = block_weight_preshuffle.MakeCBlockTile();
// B flat DRAM window for load
auto b_flat_distribution =
PipelinePolicy::template MakeBFlatDramTileDistribution<Problem>();
auto b_flat_dram_window = // tile_window_with_static_distribution
make_tile_window(
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
b_flat_dram_block_window_tmp.get_window_origin(),
b_flat_distribution);
// pingpong buffer for B
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using BTileType = decltype(make_static_distributed_tensor<BTypeToUse>(b_flat_distribution));
statically_indexed_array<
statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
NIterPerWarp>
b_flat_dram_windows;
statically_indexed_array<statically_indexed_array<BTileType, KIterPerWarp>, NIterPerWarp>
b_warp_tensor_ping;
statically_indexed_array<statically_indexed_array<BTileType, KIterPerWarp>, NIterPerWarp>
b_warp_tensor_pong;
// Prefetch A0
auto a_block_tile = load_tile(a_copy_dram_window);
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// prefetch B
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
// Prefill A0
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
__builtin_amdgcn_sched_barrier(0);
// Prefetch A1
a_block_tile = load_tile(a_copy_dram_window);
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
block_sync_lds();
// preload A00,A10 from lds
statically_indexed_array<decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{}))),
m_preload>
a_warp_tensor;
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
});
__builtin_amdgcn_sched_barrier(0);
// MAIN LOOP
index_t iCounter = (num_loop - 1) / 2;
while(iCounter > 0)
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
typename BFlatBlockWindowTmp,
typename AElementFunction,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BFlatBlockWindowTmp>::value,
bool>* = nullptr,
index_t UnaryOpSize_ = 8>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
[[maybe_unused]] const AElementFunction& a_element_func,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
// prefetch B(2i+1)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>>,
"wrong!");
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
"wrong!");
static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});
// A tile in LDS
constexpr index_t smem_size = PipelinePolicy::template GetSmemSize<Problem>();
// Prefill A(2i+1)
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
constexpr auto a_lds_block_desc =
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
// Prefetch A(2i+2)
a_block_tile = load_tile(a_copy_dram_window);
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
auto a_lds_blocks = generate_tuple(
[&](auto i) {
ADataType* p_a_lds = static_cast<ADataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + smem_size * i.value));
return make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
},
number<2>{});
// GEMM 2i
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
constexpr auto a_lds_load_tile_distr = make_static_tile_distribution(
BlockWeightPreshuffle::MakeABlockDistributionEncode());
auto&& windows_result =
Base::GetAWindows(a_dram_block_window_tmp, a_lds_blocks, a_lds_load_tile_distr);
auto&& a_copy_dram_window = windows_result.template get<0>();
auto&& a_lds_windows = windows_result.template get<1>();
auto a_copy_lds_windows = generate_tuple(
[&](auto i) -> decltype(auto) { return a_lds_windows[i].template at<0>(); },
number<2>{});
// Block GEMM
auto block_weight_preshuffle = BlockWeightPreshuffle();
// Acc register tile
auto c_block_tile = block_weight_preshuffle.MakeCBlockTile();
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
auto a_load_windows = generate_tuple(
[&](auto i) -> decltype(auto) {
return block_weight_preshuffle.MakeALoadWindows(a_copy_lds_windows[i]);
},
number<2>{});
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter)(kIter));
// B flat DRAM window for load
auto b_flat_distribution =
PipelinePolicy::template MakeBFlatDramTileDistribution<Problem>();
auto b_flat_dram_window = // tile_window_with_static_distribution
make_tile_window(b_flat_dram_block_window_tmp
.get_bottom_tensor_view(), // from kernel gemm_pad_views
make_tuple(number<flatNPerWarp * NIterPerWarp>{},
number<flatKPerWarp * KIterPerWarp>{}),
b_flat_dram_block_window_tmp.get_window_origin(),
b_flat_distribution);
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BFlatBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, kKPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, kflatKPerBlock);
__builtin_amdgcn_sched_barrier(0x7F6);
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using ABlockTile =
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using BBlockTile =
decltype(make_static_distributed_tensor<BTypeToUse>(b_flat_distribution));
ABlockTile a_global_tile;
BBlockTile b_global_tile[2];
// Prefetch A0
Base::GlobalPrefetch(a_global_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::template GlobalPrefetch<BDataType, BTypeToUse, UnaryOpSize_>(
b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step);
// Prefill A0
Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile);
// Prefetch A1
Base::GlobalPrefetch(a_global_tile, a_copy_dram_window, a_dram_tile_window_step);
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
block_sync_lds();
// preload A00,A10 from lds
block_weight_preshuffle.LocalPrefetch(a_load_windows[I0]);
__builtin_amdgcn_sched_barrier(0);
// MAIN LOOP
if constexpr(HasHotLoop)
{
index_t i_global_read = amd_wave_read_first_lane(2);
do
{
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
Base::template GlobalPrefetch<BDataType, BTypeToUse, UnaryOpSize_>(
b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step);
Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile);
Base::GlobalPrefetch(
a_global_tile, a_copy_dram_window, a_dram_tile_window_step);
block_weight_preshuffle(c_block_tile,
a_load_windows[I0],
b_global_tile[0],
b_flat_distribution);
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
block_weight_preshuffle.LocalPrefetch(a_load_windows[I1]);
HotLoopScheduler();
}
{
block_sync_lds();
Base::template GlobalPrefetch<BDataType, BTypeToUse, UnaryOpSize_>(
b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step);
Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile);
Base::GlobalPrefetch(
a_global_tile, a_copy_dram_window, a_dram_tile_window_step);
block_weight_preshuffle(c_block_tile,
a_load_windows[I1],
b_global_tile[1],
b_flat_distribution);
block_weight_preshuffle.LocalPrefetch(a_load_windows[I0]);
HotLoopScheduler();
}
});
});
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
i_global_read += 2;
} while(i_global_read < num_loop);
}
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
});
HotLoopScheduler();
// tail
if constexpr(TailNum == TailNumber::Even)
{
{
Base::template GlobalPrefetch<BDataType, BTypeToUse, UnaryOpSize_>(
b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step);
Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile);
block_weight_preshuffle(
c_block_tile, a_load_windows[I0], b_global_tile[0], b_flat_distribution);
block_sync_lds();
block_weight_preshuffle.LocalPrefetch(a_load_windows[I1]);
Last2ndHotLoopScheduler();
}
{
block_weight_preshuffle(
c_block_tile, a_load_windows[I1], b_global_tile[1], b_flat_distribution);
LastHotLoopScheduler();
}
}
else if constexpr(TailNum == TailNumber::Odd)
{
block_weight_preshuffle(
c_block_tile, a_load_windows[I0], b_global_tile[0], b_flat_distribution);
LastHotLoopScheduler();
}
// Next K
// prefetch B(2i+2)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});
// Prefill A(2i+2)
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
// Prefetch A(2i+3)
a_block_tile = load_tile(a_copy_dram_window);
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// GEMM 2i+1
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
__builtin_amdgcn_sched_barrier(0x7F6);
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
});
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
});
HotLoopScheduler();
iCounter--;
return c_block_tile;
}
// tail
if constexpr(TailNum == TailNumber::Even)
{
// __builtin_amdgcn_sched_barrier(0);
// prefetch B(loopK)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});
// Prefill A(loopK)
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
// GEMM loopK-1
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
__builtin_amdgcn_sched_barrier(0x7F6);
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
});
// TailHotLoopScheduler();
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
});
Last2ndHotLoopScheduler();
// GEMM loopK
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
});
LastHotLoopScheduler();
}
else if constexpr(TailNum == TailNumber::Odd)
{
// GEMM loopK
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
__builtin_amdgcn_sched_barrier(0x7F6);
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
});
LastHotLoopScheduler();
}
return c_block_tile;
}
};
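Taken together, the prologue and hot loop above form a two-deep software pipeline: two K-blocks are in flight before the do-while begins, and each trip retires two more, so one LDS buffer is always refilling while the other feeds the warp GEMMs. A minimal stand-alone sketch of that schedule, assuming hypothetical helpers fill_lds and run_block_gemm in place of GlobalPrefetch/LocalPrefill and block_weight_preshuffle:

// Hedged sketch only; buffer indices 0/1 play the roles of ping/pong.
void fill_lds(int buf);       // stand-in for GlobalPrefetch + LocalPrefill
void run_block_gemm(int buf); // stand-in for block_weight_preshuffle(...)

void hot_loop_sketch(int num_loop)
{
    int i_global_read = 2; // prologue already fetched and prefilled two K-blocks
    do
    {
        fill_lds(1);        // refill pong...
        run_block_gemm(0);  // ...while ping feeds the MFMAs
        fill_lds(0);        // refill ping...
        run_block_gemm(1);  // ...while pong feeds the MFMAs
        i_global_read += 2; // two K-blocks retired per trip
    } while(i_global_read < num_loop);
}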
// called from universal gemm kernel
template <typename ADramBlockWindowTmp,
@@ -1038,23 +726,20 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
[[maybe_unused]] const BElementFunction& b_element_func,
index_t num_loop,
- void* p_smem_ping,
- void* p_smem_pong) const
+ void* p_smem) const
{
- const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
+ const auto has_hot_loop = Base::BlockHasHotloop(num_loop);
+ const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
- const auto RunPipeline = [&](auto bool_val, auto tail_num_) {
- (void)bool_val; // Suppress unused parameter warning
- constexpr auto tail_num = tail_num_.value;
- constexpr auto PassThrough = [](const ADataType& a) { return a; };
- return operator()<tail_num>(a_dram_block_window_tmp[number<0>{}],
- PassThrough,
- b_flat_dram_block_window_tmp[number<0>{}],
- num_loop,
- p_smem_ping,
- p_smem_pong);
+ const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
+ return PipelineImpl{}.template operator()<hot_loop_.value, tail_num_.value>(
+ a_dram_block_window_tmp[number<0>{}],
+ a_element_func,
+ b_flat_dram_block_window_tmp[number<0>{}],
+ num_loop,
+ p_smem);
};
- return Base::TailHandler(RunPipeline, true, tail_number);
+ return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
// called from general gemm kernel
@@ -1066,23 +751,21 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
index_t num_loop,
- void* p_smem_ping,
- void* p_smem_pong) const
+ void* p_smem) const
{
- const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
+ const auto has_hot_loop = Base::BlockHasHotloop(num_loop);
+ const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
- const auto RunPipeline = [&](auto bool_val, auto tail_num_) {
- (void)bool_val; // Suppress unused parameter warning
- constexpr auto tail_num = tail_num_.value;
+ const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
constexpr auto PassThrough = [](const ADataType& a) { return a; };
- return operator()<tail_num>(a_dram_block_window_tmp,
- PassThrough,
- b_flat_dram_block_window_tmp,
- num_loop,
- p_smem_ping,
- p_smem_pong);
+ return PipelineImpl{}.template operator()<hot_loop_.value, tail_num_.value>(
+ a_dram_block_window_tmp,
+ PassThrough,
+ b_flat_dram_block_window_tmp,
+ num_loop,
+ p_smem);
};
- return Base::TailHandler(RunPipeline, true, tail_number);
+ return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
// called from grouped gemm kernel
@@ -1095,21 +778,19 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
index_t num_loop,
TailNumber tail_number,
- void* __restrict__ p_smem_0,
- void* __restrict__ p_smem_1) const
+ void* __restrict__ p_smem) const
{
- const auto RunPipeline = [&](auto bool_val, auto tail_num_) {
- (void)bool_val; // Suppress unused parameter warning
- constexpr auto tail_num = tail_num_.value;
+ const auto has_hot_loop = Base::BlockHasHotloop(num_loop);
+ const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
constexpr auto PassThrough = [](const auto& x) { return x; };
- return operator()<tail_num>(a_dram_block_window_tmp,
- PassThrough,
- b_flat_dram_block_window_tmp,
- num_loop,
- p_smem_0,
- p_smem_1);
+ return PipelineImpl{}.template operator()<hot_loop_.value, tail_num_.value>(
+ a_dram_block_window_tmp,
+ PassThrough,
+ b_flat_dram_block_window_tmp,
+ num_loop,
+ p_smem);
};
- return Base::TailHandler(RunPipeline, true, tail_number);
+ return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
};
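All three entry points above now share one shape: compute has_hot_loop and tail_number at runtime, then let Base::TailHandler lift them into compile-time template arguments for PipelineImpl. A simplified sketch of that dispatch pattern (tail_handler_sketch and the two-case TailNumber are illustrative; the real TailHandler covers the full set of tail cases):

#include <type_traits>

enum class TailNumber { Odd, Even }; // reduced to two cases for the sketch

// Stand-in for Base::TailHandler: lift runtime (has_hot_loop, tail) into
// compile-time constants so the callee can specialize on hot_loop_.value
// and tail_num_.value, exactly like RunPipeline above.
template <typename F>
auto tail_handler_sketch(F run, bool has_hot_loop, TailNumber tail)
{
    using even_t = std::integral_constant<TailNumber, TailNumber::Even>;
    using odd_t  = std::integral_constant<TailNumber, TailNumber::Odd>;
    if(has_hot_loop)
        return tail == TailNumber::Even ? run(std::true_type{}, even_t{})
                                        : run(std::true_type{}, odd_t{});
    return tail == TailNumber::Even ? run(std::false_type{}, even_t{})
                                    : run(std::false_type{}, odd_t{});
}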

View File

@@ -1723,7 +1723,7 @@ struct QuantGemmKernel
* @param aq_ptr input AQ pointer
* @param bq_ptr input BQ pointer
* @param c_ptr output C pointer
- * @param smem_ptr_0 The start memory pointer of the shared memory block.
+ * @param smem_ptr The start memory pointer of the shared memory block.
* @param kargs GEMM kernel arguments
* @param splitk_batch_offset Utility structure used to calculate k batch.
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
@@ -1735,7 +1735,7 @@ struct QuantGemmKernel
const AQDataType* aq_ptr,
const BQDataType* bq_ptr,
CDataType* c_ptr,
- void* smem_ptr_0,
+ void* smem_ptr,
const QuantGemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
@@ -1762,7 +1762,7 @@ struct QuantGemmKernel
m = kargs.M;
}
return GemmPipeline{}.template operator()(
- a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr_0, m);
+ a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr, m);
}
else if constexpr(kQuantType == QuantType::BQuantGrouped)
{
@@ -1772,7 +1772,7 @@ struct QuantGemmKernel
n = kargs.N;
}
return GemmPipeline{}.template operator()(
- a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0, n);
+ a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr, n);
}
else if constexpr(kQuantType == QuantType::ABQuantGrouped)
{
@@ -1788,7 +1788,7 @@ struct QuantGemmKernel
aq_block_window,
bq_block_window,
num_loop,
- smem_ptr_0,
+ smem_ptr,
m,
n);
}
@@ -1796,7 +1796,7 @@ struct QuantGemmKernel
kQuantType == QuantType::TensorQuant)
{
return GemmPipeline{}.template operator()(
- a_block_window, b_block_window, num_loop, smem_ptr_0);
+ a_block_window, b_block_window, num_loop, smem_ptr);
}
}();
@@ -1812,14 +1812,14 @@ struct QuantGemmKernel
kQuantType == QuantType::AQuantGrouped ||
kQuantType == QuantType::BQuantGrouped)
{
- EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
+ EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
}
else if constexpr(kQuantType == QuantType::RowColQuant)
{
EpiloguePipeline{}(c_block_window,
c_block_tile,
c_block_window,
- smem_ptr_0,
+ smem_ptr,
aq_block_window,
bq_block_window);
}
@@ -1828,7 +1828,7 @@ struct QuantGemmKernel
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
EpiloguePipeline{}(
- c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
+ c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale);
}
}
else
@@ -1840,14 +1840,14 @@ struct QuantGemmKernel
kQuantType == QuantType::AQuantGrouped ||
kQuantType == QuantType::BQuantGrouped)
{
- EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
+ EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
}
else if constexpr(kQuantType == QuantType::RowColQuant)
{
EpiloguePipeline{}(c_block_window,
c_block_tile,
c_block_window,
- smem_ptr_0,
+ smem_ptr,
aq_block_window,
bq_block_window);
}
@@ -1856,89 +1856,7 @@ struct QuantGemmKernel
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
EpiloguePipeline{}(
- c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
}
}
}
/**
* @brief Runs single GEMM problem cooperatively by whole workgroup.
*
* @note RunGemm2LDS runs with two shared memory buffers using the ping-pong buffer mechanism.
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param aq_ptr input AQ pointer
* @param bq_ptr input BQ pointer
* @param c_ptr output C pointer
* @param smem_ptr_0 The starting pointer of 1st shared memory block.
* @param smem_ptr_1 The starting pointer of 2nd shared memory block.
* @param kargs GEMM kernel arguments
* @param splitk_batch_offset Utility structure used to calculate k batch.
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
*/
CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr,
const BDataType* b_ptr,
[[maybe_unused]] const AQDataType* aq_ptr,
const BQDataType* bq_ptr,
CDataType* c_ptr,
void* __restrict__ smem_ptr_0,
void* __restrict__ smem_ptr_1,
const QuantGemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create block windows using specialized methods
const auto& a_block_window =
MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
const auto& b_block_window =
MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
const auto& bq_block_window = MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);
const index_t num_loop =
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
// Run GEMM cooperatively by whole workgroup.
const auto& c_block_tile = [&]() {
if constexpr(kQuantType == QuantType::BQuantGrouped)
{
index_t n = 0;
if constexpr(PreshuffleQuant)
{
n = kargs.N;
}
return GemmPipeline{}.template operator()(a_block_window,
b_block_window,
bq_block_window,
num_loop,
smem_ptr_0,
smem_ptr_1,
n);
}
else
{
return nullptr;
}
}();
const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);
// Run Epilogue Pipeline with k_batch dispatch
if constexpr(kQuantType == QuantType::BQuantGrouped)
{
if(k_batch == 1)
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
c_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
}
else
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
c_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
+ c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale);
}
}
}
@@ -1961,37 +1879,10 @@ struct QuantGemmKernel
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
// allocate LDS
- __shared__ char smem_ptr_0[GetSmemSize()];
+ __shared__ char smem_ptr[GetSmemSize()];
- if constexpr(GemmPipeline::DoubleSmemBuffer == true)
- {
- __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
- RunGemm2LDS(a_ptr,
- b_ptr,
- aq_ptr,
- bq_ptr,
- c_ptr,
- smem_ptr_0,
- smem_ptr_1,
- kargs,
- splitk_batch_offset,
- i_m,
- i_n);
- }
- else
- {
- RunGemm(a_ptr,
- b_ptr,
- aq_ptr,
- bq_ptr,
- c_ptr,
- smem_ptr_0,
- kargs,
- splitk_batch_offset,
- i_m,
- i_n);
- }
+ RunGemm(
+ a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
}
};
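Note that the kernel above now makes a single __shared__ allocation even when the pipeline double-buffers: the sizing contract moves into GetSmemSize(), and the pipeline derives the second buffer from the base address. A hedged sketch of that contract (total_smem_size is hypothetical; the real kernel's GetSmemSize() must cover both buffers, since the pipeline finds the pong address at base + per-buffer size):

#include <cstddef>

// Hedged sketch of the sizing contract behind the single smem_ptr allocation.
template <bool DoubleSmemBuffer>
constexpr std::size_t total_smem_size(std::size_t per_buffer_size)
{
    return DoubleSmemBuffer ? 2 * per_buffer_size : per_buffer_size;
}

static_assert(total_smem_size<true>(1024) == 2048);  // ping + pong
static_assert(total_smem_size<false>(1024) == 1024); // single buffer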

View File

@@ -318,21 +318,18 @@ struct QuantGroupedGemmKernel
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
// allocate LDS
- __shared__ char smem_ptr_0[GetSmemSize()];
+ __shared__ char smem_ptr[GetSmemSize()];
// Only for BQuantGrouped DoubleSmemBuffer is supported
if constexpr(GemmPipeline::DoubleSmemBuffer == true &&
kQuantType == QuantType::BQuantGrouped)
{
- __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
RunGemmWithPipelineSelection2LDS(a_ptr,
b_ptr,
aq_ptr,
bq_ptr,
c_ptr,
- smem_ptr_0,
- smem_ptr_1,
+ smem_ptr,
kargs,
splitk_batch_offset,
i_m,
@@ -348,7 +345,7 @@ struct QuantGroupedGemmKernel
aq_ptr,
bq_ptr,
c_ptr,
- smem_ptr_0,
+ smem_ptr,
kargs,
splitk_batch_offset,
i_m,
@@ -361,7 +358,7 @@ struct QuantGroupedGemmKernel
aq_ptr,
bq_ptr,
c_ptr,
- smem_ptr_0,
+ smem_ptr,
kargs,
splitk_batch_offset,
i_m,
@@ -377,8 +374,7 @@ struct QuantGroupedGemmKernel
[[maybe_unused]] const AQDataType* aq_ptr,
const BQDataType* bq_ptr,
CDataType* c_ptr,
- void* smem_ptr_0,
- void* smem_ptr_1,
+ void* smem_ptr,
const QuantGroupedGemmKernelArgs& kargs,
const typename Base::SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
@@ -399,27 +395,22 @@ struct QuantGroupedGemmKernel
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
// Run GEMM cooperatively by whole workgroup
- const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
- b_block_window,
- bq_block_window,
- num_loop,
- tail_num,
- smem_ptr_0,
- smem_ptr_1);
+ const auto& c_block_tile = GemmPipeline{}.template operator()(
+ a_block_window, b_block_window, bq_block_window, num_loop, tail_num, smem_ptr);
// Run Epilogue Pipeline with split_k dispatch
if(kargs.k_batch == 1)
{
auto c_block_window = Base::template MakeCBlockWindow<memory_operation_enum::set>(
c_ptr, kargs, block_idx_m, block_idx_n);
- EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
+ EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
}
else
{
auto c_block_window =
Base::template MakeCBlockWindow<memory_operation_enum::atomic_add>(
c_ptr, kargs, block_idx_m, block_idx_n);
- EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
+ EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
}
}
@@ -435,7 +426,7 @@ struct QuantGroupedGemmKernel
* @param aq_ptr input AQ pointer
* @param bq_ptr input BQ pointer
* @param c_ptr output C pointer
- * @param smem_ptr_0 The start memory pointer of the shared memory block.
+ * @param smem_ptr The start memory pointer of the shared memory block.
* @param kargs GEMM kernel arguments
* @param splitk_batch_offset Utility structure used to calculate k
* batch.
@@ -449,7 +440,7 @@ struct QuantGroupedGemmKernel
const AQDataType* aq_ptr,
const BQDataType* bq_ptr,
CDataType* c_ptr,
- void* smem_ptr_0,
+ void* smem_ptr,
const QuantGroupedGemmKernelArgs& kargs,
const typename Base::SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
@@ -481,7 +472,7 @@ struct QuantGroupedGemmKernel
num_loop,
has_hot_loop,
tail_num,
- smem_ptr_0);
+ smem_ptr);
}
else if constexpr(kQuantType == QuantType::BQuantGrouped)
{
@@ -491,13 +482,13 @@ struct QuantGroupedGemmKernel
num_loop,
has_hot_loop,
tail_num,
- smem_ptr_0);
+ smem_ptr);
}
else if constexpr(kQuantType == QuantType::RowColQuant ||
kQuantType == QuantType::TensorQuant)
{
return GemmPipeline{}.template operator()(
- a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
+ a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr);
}
}();
@@ -510,14 +501,14 @@ struct QuantGroupedGemmKernel
if constexpr(kQuantType == QuantType::AQuantGrouped ||
kQuantType == QuantType::BQuantGrouped)
{
- EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
+ EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
}
else if constexpr(kQuantType == QuantType::RowColQuant)
{
EpiloguePipeline{}(c_block_window,
c_block_tile,
c_block_window,
- smem_ptr_0,
+ smem_ptr,
aq_block_window,
bq_block_window);
}
@@ -526,7 +517,7 @@ struct QuantGroupedGemmKernel
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
EpiloguePipeline{}(
- c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
+ c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale);
}
}
else
@@ -538,14 +529,14 @@ struct QuantGroupedGemmKernel
if constexpr(kQuantType == QuantType::AQuantGrouped ||
kQuantType == QuantType::BQuantGrouped)
{
- EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
+ EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
}
else if constexpr(kQuantType == QuantType::RowColQuant)
{
EpiloguePipeline{}(c_block_window,
c_block_tile,
c_block_window,
- smem_ptr_0,
+ smem_ptr,
aq_block_window,
bq_block_window);
}
@@ -554,7 +545,7 @@ struct QuantGroupedGemmKernel
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
EpiloguePipeline{}(
- c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
+ c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale);
}
}
}
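The k_batch dispatch above is the usual split-K epilogue choice: with one K-batch the workgroup owns its C tile outright and can store with memory_operation_enum::set, while with several K-batches partial sums from different workgroups target the same tile and must be combined atomically. A minimal HIP-style sketch (store_c_sketch is illustrative, not the EpiloguePipeline interface):

#include <hip/hip_runtime.h>

// Hedged sketch of the split-K store policy dispatched on k_batch above.
__device__ void store_c_sketch(float* c, int idx, float acc, int k_batch)
{
    if(k_batch == 1)
        c[idx] = acc;            // sole writer: memory_operation_enum::set
    else
        atomicAdd(&c[idx], acc); // many writers: memory_operation_enum::atomic_add
}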

View File

@@ -29,6 +29,48 @@ struct GemmWPQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipelin
return GemmBQuantPipelineAgBgCrDefaultPolicy::MakeBQDramTileDistribution<Problem>();
}
// Since UniversalWeightPreshufflePipelineAgBgCrPolicy's MakeBFlatDramTileDistribution has
// changed, the original implementation is kept here temporarily.
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
#if defined(__gfx11__)
constexpr index_t KRepeatInWave = 2;
#else
constexpr index_t KRepeatInWave = 1;
#endif
constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // thread count in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t KRepeat = 1;
static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
constexpr index_t NBPerLoad = 1;
constexpr index_t NThdPerWave = 1;
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
constexpr index_t NRepeat = 1;
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<WaveRepeat, KRepeatInWave>, // replication (R) dims
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
sequence<KRepeat, KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
// wave in blk, // thd in wave
// <M, K> // <M, K>
tuple<sequence<0, 1, 2>, sequence<0, 1, 2>>, // which direction
tuple<sequence<0, 1, 1>, sequence<1, 2, 2>>, // which index
// <repeat, vec_load>
sequence<1, 1, 2, 2>,
sequence<0, 3, 0, 3>>{});
}
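For intuition about the constants above, here is the arithmetic with assumed example values (WaveSize = 64, KRepeatInWave = 1, KBPerLoad = 8; the real KBPerLoad comes from GetKBPerLoad<Problem>()): 64 lanes each loading 8 contiguous K elements give a 512-element flat-K footprint per wave, which is what the static_assert ties flatKPerWarp to.

// Illustration only; the concrete values are assumptions, not the policy's.
constexpr int wave_size        = 64;
constexpr int k_repeat_in_wave = 1; // 2 on gfx11 per the #if above
constexpr int kb_per_load      = 8;
constexpr int k_thd_per_wave   = wave_size / k_repeat_in_wave; // lanes along K
static_assert(k_thd_per_wave * kb_per_load == 512, "example flatKPerWarp");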
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWeightPreshuffleBQuant()
{

View File

@@ -184,8 +184,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
index_t n,
index_t num_loop,
- void* p_smem_ping,
- void* p_smem_pong) const
+ void* p_smem) const
{
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
@@ -210,8 +209,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
__builtin_amdgcn_sched_barrier(0);
// A tile in LDS
- ADataType* p_a_lds_ping = static_cast<ADataType*>(p_smem_ping);
- ADataType* p_a_lds_pong = static_cast<ADataType*>(p_smem_pong);
+ constexpr index_t smem_size = PipelinePolicy::template GetSmemSize<Problem>();
+ ADataType* p_a_lds_ping = static_cast<ADataType*>(p_smem);
+ ADataType* p_a_lds_pong =
+     reinterpret_cast<ADataType*>(static_cast<char*>(p_smem) + smem_size);
constexpr auto a_lds_block_desc =
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
@@ -561,9 +562,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
index_t num_loop,
- void* p_smem_ping,
- void* p_smem_pong,
- index_t n = 0) const // Default value for non-preshuffle case
+ void* p_smem,
+ index_t n = 0) const
{
return operator()<TailNum>(
a_dram_block_window_tmp,
@@ -572,8 +572,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
bq_dram_block_window_tmp,
n,
num_loop,
- p_smem_ping,
- p_smem_pong);
+ p_smem);
}
template <typename ADramBlockWindowTmp,
@@ -584,8 +583,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
index_t num_loop,
TailNumber tail_number,
- void* p_smem_ping,
- void* p_smem_pong,
+ void* p_smem,
index_t n = 0) const
{
const auto RunPipeline = [&](auto bool_val, auto tail_num_) {
@@ -598,8 +596,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
bq_dram_block_window_tmp,
n, // dummy value, won't be used
num_loop,
- p_smem_ping,
- p_smem_pong);
+ p_smem);
};
return Base::TailHandler(RunPipeline, true, tail_number);
}
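Finally, the ping/pong pointer computation above is the essence of the unified API: callers hand the pipeline one buffer, and the pong address is derived from the base plus the per-buffer size rather than passed in. A stand-alone sketch (float stands in for ADataType, and smem_size for PipelinePolicy::GetSmemSize<Problem>()):

#include <cstddef>

// Hedged sketch of the unified ping/pong addressing: one allocation in,
// two LDS views out, the pong view offset by the per-buffer size.
inline void split_ping_pong(void* p_smem, std::size_t smem_size,
                            float*& ping, float*& pong)
{
    ping = static_cast<float*>(p_smem);
    pong = reinterpret_cast<float*>(static_cast<char*>(p_smem) + smem_size);
}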