mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
Tile distribution changes to make single warp read tile_window and perform GEMM independently in the 2-warp-thread-block
This commit is contained in:
@@ -263,9 +263,10 @@ int run_gemm_example_with_layouts(int argc,
|
||||
// set 1 column in A and 1 Row in B to perform outer product.
|
||||
// and test the results.
|
||||
//const ck_tile::index_t K_len = a_m_k.get_length(1);
|
||||
const ck_tile::index_t M_len = a_m_k.get_length(0);
|
||||
const ck_tile::index_t N_len = b_k_n.get_length(1);
|
||||
//const ck_tile::index_t M_len = a_m_k.get_length(0);
|
||||
//const ck_tile::index_t N_len = b_k_n.get_length(1);
|
||||
|
||||
/*
|
||||
// Fill 0th column in A
|
||||
ck_tile::half_t dd = 1;
|
||||
for(int i = 0; i < M_len; i++)
|
||||
@@ -273,11 +274,7 @@ int run_gemm_example_with_layouts(int argc,
|
||||
int j = 0;
|
||||
{
|
||||
a_m_k(i, j) = dd;
|
||||
}
|
||||
int k = 8;
|
||||
{
|
||||
a_m_k(i, k) = dd++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fill 0th row in B
|
||||
@@ -289,15 +286,7 @@ int run_gemm_example_with_layouts(int argc,
|
||||
b_k_n(i, j) = dd;
|
||||
}
|
||||
}
|
||||
|
||||
i = 8;
|
||||
{
|
||||
for(int j=0; j < N_len; j++)
|
||||
{
|
||||
b_k_n(i, j) = dd;
|
||||
}
|
||||
}
|
||||
|
||||
*/
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
@@ -441,5 +430,17 @@ int run_gemm_example_with_layouts(int argc,
|
||||
std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
const ck_tile::index_t M_len = c_m_n_dev_result.get_length(0);
|
||||
const ck_tile::index_t N_len = c_m_n_dev_result.get_length(1);
|
||||
|
||||
for (int i = 0; i < M_len; i++)
|
||||
{
|
||||
for (int j = 0; j < N_len; j++)
|
||||
{
|
||||
std::cout << std::setw(6) << ck_tile::type_convert<float>(c_m_n_dev_result(i, j));
|
||||
}
|
||||
std::cout<<std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
@@ -75,7 +75,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
|
||||
/*
|
||||
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
|
||||
ck_tile::DefaultGemm2DEpilogueProblem<AccDataType,
|
||||
CDataType,
|
||||
@@ -86,9 +86,9 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
*/
|
||||
|
||||
|
||||
/*
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
@@ -104,7 +104,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
*/
|
||||
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
|
||||
@@ -56,17 +56,19 @@ template <index_t BlockSize,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t VecSize,
|
||||
index_t NumWarpGroups,
|
||||
tile_distribution_pattern DistributionPattern>
|
||||
struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern
|
||||
{
|
||||
};
|
||||
|
||||
// Thread raked
|
||||
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
|
||||
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize, index_t NumWarpGroups>
|
||||
struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
NumWarpGroups,
|
||||
tile_distribution_pattern::thread_raked>
|
||||
: public TileDistributionEncodingPattern
|
||||
{
|
||||
@@ -90,6 +92,14 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
static_assert(X0 * Y1 * Y0 == BlockSize, "X0 * warp_ys * Y0 must cover whole workgroup!");
|
||||
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
|
||||
|
||||
static constexpr index_t PX0 = X0;
|
||||
static constexpr index_t PX1 = X1;
|
||||
|
||||
static constexpr index_t PY0 = num_warps / NumWarpGroups;
|
||||
static constexpr index_t PY1 = Y1;
|
||||
static constexpr index_t PY2 = YPerTile / (PY1 * PY0); // only for the active warps
|
||||
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
@@ -111,6 +121,50 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 2>>{});
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakePingPong2DStaticTileDistribution()
|
||||
{
|
||||
// X0, X1
|
||||
// X0 - One Row/Column needs X1 no. of instructions to read/write.
|
||||
// X1 - VecSize - The read instruction size.
|
||||
// X is always the fastest changing dimension of the input matrix.
|
||||
|
||||
// Y0, Y1, Y2
|
||||
// Y0 - Total number of warps in a thread group.
|
||||
// Y1 - WarpSize / no-of-threads-in-N-dimension.
|
||||
// - No. of threads needed in the M dimension
|
||||
// Y2 - YPerTile / (Y1 * Y0)
|
||||
// - Y size / (no. of threads on Y dimensions * no. of warps)
|
||||
// - Total no. of iterations needed by all the warps in the thread group to cover the
|
||||
// - entire tile window.
|
||||
|
||||
// (2, 0) = PY0 -- Number of warps in the threadblock
|
||||
// (2, 1) * (1, 0) = PY1 * PX0, (M threads) * (N Threads)
|
||||
|
||||
static_assert(NumWarpGroups == 2);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<PX0, PX1>, sequence<PY0, PY1, PY2>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 2>>{}
|
||||
);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakePingPongShuffled2DStaticTileDistribution()
|
||||
{
|
||||
static_assert(NumWarpGroups == 2);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<PX0, PX1>, sequence<PY0, PY1, PY2>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 2>>{}
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
// Warp raked
|
||||
@@ -119,6 +173,7 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
1,
|
||||
tile_distribution_pattern::warp_raked>
|
||||
: public TileDistributionEncodingPattern
|
||||
{
|
||||
@@ -167,6 +222,7 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
1,
|
||||
tile_distribution_pattern::block_raked>
|
||||
: public TileDistributionEncodingPattern
|
||||
{
|
||||
|
||||
@@ -151,6 +151,7 @@ struct CShuffleEpilogue
|
||||
kMPerIteration,
|
||||
kNPerIteration,
|
||||
GetVectorSizeC(),
|
||||
1,
|
||||
tile_distribution_pattern::thread_raked>;
|
||||
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
|
||||
|
||||
@@ -30,8 +30,8 @@
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
|
||||
|
||||
@@ -11,11 +11,11 @@ namespace ck_tile {
|
||||
// A is block distributed tensor
|
||||
// B is block distributed tensor
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy>
|
||||
template <typename Problem_, typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy, index_t NumWarpGroups = 1>
|
||||
struct BlockGemmARegBRegCRegV1
|
||||
{
|
||||
private:
|
||||
template <typename PipelineProblem_, typename GemmPolicy_>
|
||||
template <typename PipelineProblem_, typename GemmPolicy_, index_t NumWarpGroups_>
|
||||
struct GemmTraits_
|
||||
{
|
||||
using Problem = remove_cvref_t<PipelineProblem_>;
|
||||
@@ -34,7 +34,7 @@ struct BlockGemmARegBRegCRegV1
|
||||
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
static constexpr index_t MWarp = config.template at<1>();
|
||||
static constexpr index_t MWarp = (config.template at<1>()) / NumWarpGroups;
|
||||
static constexpr index_t NWarp = config.template at<2>();
|
||||
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
@@ -47,7 +47,7 @@ struct BlockGemmARegBRegCRegV1
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
|
||||
using Traits = GemmTraits_<Problem, Policy>;
|
||||
using Traits = GemmTraits_<Problem, Policy, NumWarpGroups>;
|
||||
|
||||
using WarpGemm = typename Traits::WarpGemm;
|
||||
using BlockGemmShape = typename Traits::BlockGemmShape;
|
||||
@@ -65,6 +65,8 @@ struct BlockGemmARegBRegCRegV1
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
|
||||
{
|
||||
static_assert( MWarp == 1);
|
||||
static_assert(MIterPerWarp == 2);
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
@@ -95,6 +97,7 @@ struct BlockGemmARegBRegCRegV1
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode()
|
||||
{
|
||||
static_assert(MWarp * NWarp == 1);
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
|
||||
@@ -67,7 +67,7 @@ struct GemmPipelineAgBgCrCompV4DefaultPolicy
|
||||
return b_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem, index_t NumWarpGroups=1>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
using AccDataType = float;
|
||||
@@ -86,7 +86,9 @@ struct GemmPipelineAgBgCrCompV4DefaultPolicy
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
|
||||
return BlockGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
|
||||
static_assert(NumWarpGroups == 2);
|
||||
|
||||
return BlockGemmARegBRegCRegV1<Problem, BlockGemmPolicy, NumWarpGroups>{};
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -14,7 +14,7 @@ struct BaseGemmPipelineAgBgCrCompV5
|
||||
{
|
||||
static constexpr index_t PrefetchStages = 1;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 2;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
@@ -25,14 +25,11 @@ struct BaseGemmPipelineAgBgCrCompV5
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
if(num_loop > PrefetchStages)
|
||||
if(num_loop > 0)
|
||||
{
|
||||
return TailNumber::One;
|
||||
}
|
||||
else
|
||||
{
|
||||
return TailNumber::Two;
|
||||
}
|
||||
return TailNumber::One;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -51,7 +48,9 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
static constexpr index_t NumWarpGroups = 2;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem, NumWarpGroups>())>;
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
using I2 = number<2>;
|
||||
@@ -102,6 +101,255 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
{
|
||||
using Base = PipelineImplBase;
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
index_t NumWarpGroups,
|
||||
typename ADramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename BElementFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* __restrict__ p_smem_0) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"Data Type conflict on A and B matrix input data type.");
|
||||
|
||||
static_assert(
|
||||
KPerBlock % ((NumWarps / 2) * KTileSize) == 0,
|
||||
"Ping Pong Warps, TileSize and Block Size for K dimensions does not match.");
|
||||
|
||||
constexpr bool is_a_col_major =
|
||||
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
static_assert(is_a_col_major
|
||||
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"A block window has incorrect lengths for defined ALayout!");
|
||||
static_assert(is_b_row_major
|
||||
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
|
||||
constexpr index_t num_stages_ = 2; // mem-read and GEMM
|
||||
index_t group_id = __builtin_amdgcn_readfirstlane(get_warp_id() % NumWarpGroups); // warp-id (0, 1) for warp specific data in this pipeline
|
||||
index_t op_id = __builtin_amdgcn_readfirstlane(get_warp_id() % NumWarpGroups); // operation to perform (mem-read (0) or GEMM(1))
|
||||
|
||||
// global memory structures here.
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
|
||||
a_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakePingPongADramTileDistribution<Problem, NumWarpGroups>());
|
||||
|
||||
// B DRAM tile window for load
|
||||
auto b_copy_dram_window =
|
||||
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
|
||||
b_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakePingPongBDramTileDistribution<Problem, NumWarpGroups>());
|
||||
|
||||
// DRAM window steps.
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step =
|
||||
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
|
||||
constexpr auto AGemmTileDistr = decltype(make_static_tile_distribution(
|
||||
BlockGemm::MakeABlockDistributionEncode())){};
|
||||
constexpr auto BGemmTileDistr = decltype(make_static_tile_distribution(
|
||||
BlockGemm::MakeBBlockDistributionEncode())){};
|
||||
|
||||
using AGemmTile = decltype(make_static_distributed_tensor<ADataType>(AGemmTileDistr));
|
||||
using BGemmTile = decltype(make_static_distributed_tensor<BDataType>(BGemmTileDistr));
|
||||
AGemmTile a_tile_0, a_tile_1; // Gemm Tiles in registers.
|
||||
BGemmTile b_tile_0, b_tile_1;
|
||||
|
||||
// Register tile for A and B.
|
||||
constexpr auto ABlockTileDistr = a_copy_dram_window.get_tile_distribution();
|
||||
constexpr auto BBlockTileDistr = b_copy_dram_window.get_tile_distribution();
|
||||
using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr));
|
||||
using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr));
|
||||
ABlockTile a_global_load_tile;
|
||||
BBlockTile b_global_load_tile;
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
auto c_block_tile_0 = block_gemm.MakeCBlockTile(); // Gemm distribution.
|
||||
auto c_block_tile_1 = block_gemm.MakeCBlockTile();
|
||||
|
||||
|
||||
// Not needed
|
||||
auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem_0);
|
||||
|
||||
auto a_copy_lds_window = make_tile_window(
|
||||
a_lds_block, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
|
||||
auto b_copy_lds_window = make_tile_window(
|
||||
b_lds_block, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
|
||||
|
||||
auto a_lds_window =
|
||||
make_tile_window(a_lds_block,
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
|
||||
{0, 0},
|
||||
AGemmTileDistr);
|
||||
auto b_lds_window =
|
||||
make_tile_window(b_lds_block,
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
|
||||
{0, 0},
|
||||
BGemmTileDistr);
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile_0);
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile_1);
|
||||
|
||||
// define ping, pong steps here as lambda functions.
|
||||
auto MemoryOpsStep = [&](auto idx) {
|
||||
|
||||
// Memory read half here.
|
||||
Base::GlobalPrefetch(
|
||||
a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(
|
||||
b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakePingPongShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_global_load_tile, a_element_func);
|
||||
}
|
||||
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakePingPongShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, b_global_load_tile, b_element_func);
|
||||
}
|
||||
|
||||
if (idx == 0)
|
||||
{
|
||||
Base::LocalPrefetch(a_tile_0, a_lds_window);
|
||||
Base::LocalPrefetch(b_tile_0, b_lds_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefetch(a_tile_1, a_lds_window);
|
||||
Base::LocalPrefetch(b_tile_1, b_lds_window);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
auto ComputeStep = [&](auto idx) {
|
||||
if (idx == 0)
|
||||
{
|
||||
block_gemm(c_block_tile_0, a_tile_0, b_tile_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
block_gemm(c_block_tile_1, a_tile_1, b_tile_1);
|
||||
}
|
||||
};
|
||||
|
||||
if (op_id == 0)
|
||||
{
|
||||
MemoryOpsStep(group_id);
|
||||
ComputeStep(group_id);
|
||||
}
|
||||
|
||||
// start the main loop.
|
||||
index_t num_compute_steps = __builtin_amdgcn_readfirstlane(num_loop);
|
||||
while(num_compute_steps > 10)
|
||||
{
|
||||
block_sync_lds();
|
||||
op_id = (op_id + 1) % num_stages_;
|
||||
|
||||
if(op_id == 0)
|
||||
{
|
||||
MemoryOpsStep(group_id);
|
||||
}
|
||||
else
|
||||
{
|
||||
ComputeStep(group_id);
|
||||
}
|
||||
num_compute_steps -= 1;
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
return c_block_tile_0;
|
||||
|
||||
/*
|
||||
if(op_id == 0)
|
||||
{
|
||||
MemoryOpsStep(group_id);
|
||||
}
|
||||
|
||||
// start the main loop.
|
||||
index_t num_compute_steps = __builtin_amdgcn_readfirstlane(num_loop);
|
||||
while(num_compute_steps > 1)
|
||||
{
|
||||
block_sync_lds();
|
||||
op_id = (op_id + 1) % num_stages_;
|
||||
|
||||
if(op_id == 0)
|
||||
{
|
||||
MemoryOpsStep(group_id);
|
||||
}
|
||||
else
|
||||
{
|
||||
ComputeStep(group_id);
|
||||
}
|
||||
num_compute_steps -= 1;
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
// Handle Tail Number here.
|
||||
if(op_id == 0)
|
||||
{
|
||||
ComputeStep(group_id);
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
// Add both the tiles and return the result.
|
||||
|
||||
|
||||
if (group_id == 0)
|
||||
{
|
||||
constexpr auto s_spans = decltype(c_block_tile_0)::get_distributed_spans();
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
auto idx2 = make_tuple(idx0, idx1);
|
||||
c_block_tile_0(idx2) += c_block_tile_1(idx2);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
return c_block_tile_0;
|
||||
*/
|
||||
}
|
||||
|
||||
/*
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename ADramBlockWindowTmp,
|
||||
@@ -143,26 +391,19 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
|
||||
static constexpr index_t num_stages_ = 2;
|
||||
// This is used to identify the register tile on which a warp always operates on.
|
||||
// For instance, warp-0 always uses a_block_tile_0 for reading in one cycle
|
||||
// and execution in the next cycle.
|
||||
index_t group_id = get_warp_id() % num_stages_;
|
||||
|
||||
// op_id indicated one of the steps (0 - Read, 1 - Gemm Execution)
|
||||
// Each warp performs read in one cycle and in the next cycles performs GEMM operation
|
||||
// on the same block_tile that it has read in the previous cycle.
|
||||
index_t op_id = get_warp_id() % num_stages_;
|
||||
index_t group_id = __builtin_amdgcn_readfirstlane(get_warp_id() % num_stages_);
|
||||
index_t op_id = __builtin_amdgcn_readfirstlane(get_warp_id() % num_stages_);
|
||||
|
||||
// global memory structures here.
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window_linear(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
|
||||
a_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
// B DRAM tile window for load
|
||||
auto b_copy_dram_window =
|
||||
make_tile_window_linear(b_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
|
||||
b_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
@@ -184,8 +425,6 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
|
||||
|
||||
// array<ALdsTile, num_stages_> a_tiles;
|
||||
// array<BLdsTile, num_stages_> b_tiles;
|
||||
ALdsTile a_tile_0, a_tile_1;
|
||||
BLdsTile b_tile_0, b_tile_1;
|
||||
|
||||
@@ -199,16 +438,15 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
b_lds_block, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
|
||||
|
||||
auto a_lds_window =
|
||||
make_tile_window_linear(a_lds_block,
|
||||
make_tile_window(a_lds_block,
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
|
||||
{0, 0},
|
||||
ALdsTileDistr);
|
||||
auto b_lds_window =
|
||||
make_tile_window_linear(b_lds_block,
|
||||
make_tile_window(b_lds_block,
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
|
||||
{0, 0},
|
||||
BLdsTileDistr);
|
||||
|
||||
// Register tile for A and B.
|
||||
constexpr auto ABlockTileDistr = a_copy_dram_window.get_tile_distribution();
|
||||
constexpr auto BBlockTileDistr = b_copy_dram_window.get_tile_distribution();
|
||||
@@ -219,13 +457,14 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
auto c_block_tile = block_gemm.MakeCBlockTile();
|
||||
auto c_block_tile_0 = block_gemm.MakeCBlockTile();
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile_0);
|
||||
|
||||
// define ping, pong steps here as lambda functions.
|
||||
auto MemoryOpsStep = [&]() {
|
||||
auto MemoryOpsStep = [&](auto idx) {
|
||||
|
||||
// Memory read half here.
|
||||
Base::GlobalPrefetch(
|
||||
a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
@@ -237,49 +476,52 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_global_load_tile, a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_global_load_tile, a_element_func);
|
||||
}
|
||||
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, a_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, b_global_load_tile, b_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window, b_global_load_tile, a_element_func);
|
||||
}
|
||||
};
|
||||
|
||||
auto ComputeStep = [&](auto idx) {
|
||||
if(idx == 0)
|
||||
if (idx == 0)
|
||||
{
|
||||
//tile_elementwise_inout([&step](auto& c) { c = step; }, c_block_tile_0);
|
||||
Base::LocalPrefetch(a_tile_0, a_lds_window);
|
||||
Base::LocalPrefetch(b_tile_0, b_lds_window);
|
||||
block_gemm(c_block_tile, a_tile_0, b_tile_0);
|
||||
// tile_elementwise_inout([](auto& c) { c += 1; }, c_block_tile);
|
||||
block_gemm(c_block_tile_0, a_tile_0, b_tile_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
//tile_elementwise_inout([&step](auto& c) { c = step; }, c_block_tile_0);
|
||||
Base::LocalPrefetch(a_tile_1, a_lds_window);
|
||||
Base::LocalPrefetch(b_tile_1, b_lds_window);
|
||||
block_gemm(c_block_tile, a_tile_1, b_tile_1);
|
||||
// tile_elementwise_inout([](auto& c) { c += 1; }, c_block_tile);
|
||||
Base::LocalPrefetch(b_tile_1, b_lds_window);
|
||||
block_gemm(c_block_tile_0, a_tile_1, b_tile_1);
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
if(op_id == 0)
|
||||
{
|
||||
MemoryOpsStep();
|
||||
}
|
||||
MemoryOpsStep(group_id);
|
||||
}
|
||||
|
||||
// start the main loop.
|
||||
index_t num_compute_steps = __builtin_amdgcn_readfirstlane(num_loop) * 2 - 1;
|
||||
index_t num_compute_steps = __builtin_amdgcn_readfirstlane(num_loop)*2 - 1;
|
||||
|
||||
while(num_compute_steps > 0)
|
||||
{
|
||||
@@ -287,29 +529,40 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
block_sync_lds();
|
||||
op_id = (op_id + 1) % num_stages_;
|
||||
|
||||
if(op_id == 0)
|
||||
{
|
||||
MemoryOpsStep();
|
||||
if(op_id == 0)
|
||||
{
|
||||
MemoryOpsStep(group_id);
|
||||
}
|
||||
else
|
||||
{
|
||||
ComputeStep(group_id);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
ComputeStep(group_id);
|
||||
}
|
||||
|
||||
num_compute_steps -= 1;
|
||||
}
|
||||
|
||||
|
||||
// Handle Tail Number here.
|
||||
block_sync_lds();
|
||||
|
||||
// Handle Tail Number here.
|
||||
if(op_id == 0)
|
||||
{
|
||||
ComputeStep(group_id);
|
||||
}
|
||||
|
||||
block_sync_lds();
|
||||
return c_block_tile;
|
||||
block_sync_lds();
|
||||
|
||||
|
||||
//constexpr auto s_spans = decltype(c_block_tile_0)::get_distributed_spans();
|
||||
//sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
// sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
// auto idx2 = make_tuple(idx0, idx1);
|
||||
// c_block_tile_2(idx2) = c_block_tile_0(idx2) + c_block_tile_1(idx2);
|
||||
// });
|
||||
//});
|
||||
|
||||
return c_block_tile_0;
|
||||
}
|
||||
*/
|
||||
};
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
@@ -321,15 +574,17 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem_0) const
|
||||
void* p_smem_0
|
||||
) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum, NumWarpGroups>(
|
||||
a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem_0);
|
||||
p_smem_0
|
||||
);
|
||||
}
|
||||
|
||||
public:
|
||||
@@ -337,15 +592,17 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const index_t num_loop,
|
||||
void* __restrict__ p_smem_0) const
|
||||
void* __restrict__ p_smem_0
|
||||
) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum, NumWarpGroups>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
num_loop,
|
||||
p_smem_0);
|
||||
p_smem_0
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -123,7 +123,7 @@ struct UniversalGemmBasePolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem>())>;
|
||||
using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem, 2>())>;
|
||||
using WG = typename BlockGemm::WarpGemm;
|
||||
|
||||
constexpr bool TransposeC = Problem::TransposeC;
|
||||
@@ -182,6 +182,75 @@ struct UniversalGemmBasePolicy
|
||||
return Problem::TransposeC;
|
||||
}
|
||||
|
||||
template <typename Problem, index_t NumWarpGroups = 1>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakePingPongADramTileDistribution()
|
||||
{
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlock,
|
||||
VecLoadSize,
|
||||
NumWarpGroups,
|
||||
ATileAccessPattern>;
|
||||
return TileEncodingPattern::MakePingPong2DStaticTileDistribution();
|
||||
}
|
||||
else
|
||||
{
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
KPerBlock,
|
||||
MPerBlock,
|
||||
VecLoadSize,
|
||||
NumWarpGroups,
|
||||
ATileAccessPattern>;
|
||||
return TileEncodingPattern::MakePingPong2DStaticTileDistribution();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem, index_t NumWarpGroups = 1>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakePingPongBDramTileDistribution()
|
||||
{
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
|
||||
|
||||
// Tile: KPerBlock X NPerBlock
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
NumWarpGroups,
|
||||
BTileAccessPattern
|
||||
>;
|
||||
return TileEncodingPattern::MakePingpPong2DStaticTileDistribution();
|
||||
}
|
||||
// Tile: NPerBlock X KPerBlock
|
||||
else
|
||||
{
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
VecLoadSize,
|
||||
NumWarpGroups,
|
||||
BTileAccessPattern
|
||||
>;
|
||||
return TileEncodingPattern::MakePingPong2DStaticTileDistribution();
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
@@ -199,6 +268,7 @@ struct UniversalGemmBasePolicy
|
||||
MPerBlock,
|
||||
KPerBlock,
|
||||
VecLoadSize,
|
||||
1,
|
||||
ATileAccessPattern>;
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
}
|
||||
@@ -209,6 +279,7 @@ struct UniversalGemmBasePolicy
|
||||
KPerBlock,
|
||||
MPerBlock,
|
||||
VecLoadSize,
|
||||
1,
|
||||
ATileAccessPattern>;
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
}
|
||||
@@ -231,6 +302,7 @@ struct UniversalGemmBasePolicy
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
1,
|
||||
BTileAccessPattern>;
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
}
|
||||
@@ -241,11 +313,50 @@ struct UniversalGemmBasePolicy
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
VecLoadSize,
|
||||
1,
|
||||
BTileAccessPattern>;
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem, index_t NumWarpGroups = 1>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakePingPongShuffledARegTileDistribution()
|
||||
{
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
|
||||
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
KPerBlock,
|
||||
MPerBlock,
|
||||
VecLoadSize,
|
||||
NumWarpGroups,
|
||||
ATileAccessPattern>;
|
||||
return TileEncodingPattern::MakePingPongShuffled2DStaticTileDistribution();
|
||||
}
|
||||
|
||||
template <typename Problem, index_t NumWarpGroups = 1>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakePingPongShuffledBRegTileDistribution()
|
||||
{
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
|
||||
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
NumWarpGroups,
|
||||
BTileAccessPattern>;
|
||||
return TileEncodingPattern::MakePingPongShuffled2DStaticTileDistribution();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegTileDistribution()
|
||||
{
|
||||
@@ -260,6 +371,7 @@ struct UniversalGemmBasePolicy
|
||||
KPerBlock,
|
||||
MPerBlock,
|
||||
VecLoadSize,
|
||||
1,
|
||||
ATileAccessPattern>;
|
||||
return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
|
||||
}
|
||||
@@ -278,6 +390,7 @@ struct UniversalGemmBasePolicy
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
1,
|
||||
BTileAccessPattern>;
|
||||
return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
|
||||
}
|
||||
@@ -285,7 +398,7 @@ struct UniversalGemmBasePolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem>())>;
|
||||
using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem, 2>())>;
|
||||
constexpr index_t KPack = BlockGemm::Traits::KPack;
|
||||
return KPack;
|
||||
}
|
||||
@@ -293,7 +406,7 @@ struct UniversalGemmBasePolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem>())>;
|
||||
using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem, 2>())>;
|
||||
constexpr index_t KPack = BlockGemm::Traits::KPack;
|
||||
return KPack;
|
||||
}
|
||||
@@ -362,7 +475,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
|
||||
constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<MLdsLayer>{})),
|
||||
make_tuple(number<MLdsLayer>{}, number<KPerBlock / KPack>{})),
|
||||
make_pass_through_transform(number<MPerBlock / MLdsLayer>{}),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
@@ -374,7 +487,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
|
||||
make_tuple(number<MPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return a_lds_block_desc;
|
||||
@@ -421,7 +534,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
|
||||
|
||||
constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, number<NLdsLayer>{})),
|
||||
make_tuple(make_unmerge_transform(make_tuple(number<NLdsLayer>{}, BK0)),
|
||||
make_pass_through_transform(number<NPerBlock / NLdsLayer>{}),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
@@ -432,7 +545,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<NPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
|
||||
make_merge_transform_v3_division_mod(make_tuple(BK0, number<KPack>{}))),
|
||||
make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return b_lds_block_desc;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user