Refactor the BlockGemm changes; isolate them to the V2 implementation.

This commit is contained in:
aska-0096
2025-08-12 14:25:50 +00:00
parent fd1eb323af
commit 0810799e25
8 changed files with 514 additions and 194 deletions

View File

@@ -1038,7 +1038,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<2, 1>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
@@ -1096,7 +1096,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<2, 1>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
@@ -1190,7 +1190,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<2, 1>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto kt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
@@ -1249,7 +1249,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<2, 1>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
@@ -1344,7 +1344,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<2, 1>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto qt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
@@ -1379,7 +1379,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<2, 1>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto dst_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
@@ -1490,7 +1490,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<2, 1>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto do_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
@@ -1589,7 +1589,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<2, 1>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto dot_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
@@ -1623,7 +1623,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<2, 1>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto pt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
@@ -1667,7 +1667,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<2, 1>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto ds_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
@@ -1718,7 +1718,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
pt_out.set_y_sliced_thread_data(
merge_sequences(sequence<kIter, mIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
pt_warp_tensor.get_thread_buffer());
});
@@ -1768,7 +1768,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
dst_out.set_y_sliced_thread_data(
merge_sequences(sequence<kIter, mIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
dst_warp_tensor.get_thread_buffer());
});

View File

@@ -5,13 +5,11 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp"
// can remove all bank conflicts, but drop the performance for some cases
// Probably it is limited by compiler optimization.
@@ -512,8 +510,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>,
GemmLoopOrder::MNK>;
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::QDataType,
@@ -525,13 +522,14 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
true>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::QDataType,
BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
WarpGemm>;
WarpGemm,
GemmLoopOrder::MNK>;
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
return BlockGemmARegBRegCRegV2<GemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
@@ -546,8 +544,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
Problem::BlockFmhaShape::kN1,
Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>,
GemmLoopOrder::KMN>;
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::PDataType,
@@ -567,13 +564,14 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
: WGAttrNumAccessEnum::Single>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::PDataType,
BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
WarpGemm>;
WarpGemm,
GemmLoopOrder::KMN>;
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
return BlockGemmARegBRegCRegV2<GemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>

View File

@@ -42,8 +42,6 @@ struct BlockGemmARegBRegCRegV1
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
static constexpr auto BlockGemmLoopOrder = Problem::BlockGemmLoopOrder;
static constexpr index_t KPack = WarpGemm::kKPerThread;
};
@@ -54,9 +52,8 @@ struct BlockGemmARegBRegCRegV1
using Traits = GemmTraits_<Problem, Policy>;
using WarpGemm = typename Traits::WarpGemm;
using BlockGemmShape = typename Traits::BlockGemmShape;
static constexpr auto BlockGemmLoopOrder = Traits::BlockGemmLoopOrder;
using WarpGemm = typename Traits::WarpGemm;
using BlockGemmShape = typename Traits::BlockGemmShape;
using ADataType = remove_cvref_t<typename Traits::ADataType>;
using BDataType = remove_cvref_t<typename Traits::BDataType>;
@@ -89,36 +86,17 @@ struct BlockGemmARegBRegCRegV1
}
else
{
if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN)
{
constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<2, 1>,
sequence<0, 0>>{};
constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
return a_block_dstr_encode;
}
else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
{
constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
return a_block_dstr_encode;
}
return a_block_dstr_encode;
}
}
@@ -140,33 +118,17 @@ struct BlockGemmARegBRegCRegV1
}
else
{
if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN)
{
constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<2, 1>,
sequence<0, 0>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
return b_block_dstr_encode;
}
else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
{
constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
return b_block_dstr_encode;
}
return b_block_dstr_encode;
}
}
@@ -251,82 +213,40 @@ struct BlockGemmARegBRegCRegV1
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN)
{
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A Block window
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<kIter, mIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<kIter, nIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor
using c_iter_idx = std::conditional_t<TransposeC,
sequence<nIter, mIter>,
sequence<mIter, nIter>>;
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
{
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A Block window
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
// read A warp tensor from A Block window
AWarpTensor a_warp_tensor;
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
// read C warp tensor from C block tensor
using c_iter_idx = std::
conditional_t<TransposeC, sequence<nIter, mIter>, sequence<mIter, nIter>>;
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
// warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
}
});
}
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()

View File

@@ -0,0 +1,372 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_custom_policy.hpp"
namespace ck_tile {
// Block-level GEMM (C += A * B) with explicit control over the instruction
// issue order of the hot loop (selected via GemmLoopOrder in the policy).
// A is block distributed tensor
// B is block distributed tensor
// C is block distributed tensor
template <typename Problem_, typename Policy_>
struct BlockGemmARegBRegCRegV2
{
private:
// Compile-time traits derived from the Problem/Policy pair: data types,
// block shape, warp layout, and per-warp iteration counts.
template <typename PipelineProblem_, typename GemmPolicy_>
struct GemmTraits_
{
using Problem = remove_cvref_t<PipelineProblem_>;
using Policy = remove_cvref_t<GemmPolicy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
// Policy supplies the warp-level GEMM plus the MxN warp grid.
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
static constexpr index_t MWarp = config.template at<1>();
static constexpr index_t NWarp = config.template at<2>();
// Number of warp-GEMM invocations each warp performs along M/N/K.
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
// NOTE: in V2 the loop order comes from the Policy (V1 read it from the Problem).
static constexpr auto BlockGemmLoopOrder = Policy::BlockGemmLoopOrder;
// Per-thread K packing reported by the warp GEMM (WarpGemm::kKPerThread).
static constexpr index_t KPack = WarpGemm::kKPerThread;
};
public:
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using Traits = GemmTraits_<Problem, Policy>;
using WarpGemm = typename Traits::WarpGemm;
using BlockGemmShape = typename Traits::BlockGemmShape;
static constexpr auto BlockGemmLoopOrder = Traits::BlockGemmLoopOrder;
using ADataType = remove_cvref_t<typename Traits::ADataType>;
using BDataType = remove_cvref_t<typename Traits::BDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;
static constexpr index_t MWarp = Traits::MWarp;
static constexpr index_t NWarp = Traits::NWarp;
// More than one wave group selects the "default scheduler" encodings below.
static constexpr bool UseDefaultScheduler = (Problem::NumWaveGroups != 1);
// Builds the tile-distribution encoding for the A block tensor.
// The KMN and MNK branches differ only in the ys-to-rhs-major order
// (sequence<2, 1> vs sequence<1, 2>), i.e. whether the K-iter or the
// M-iter dimension is the faster-varying y dimension.
// No trailing else is needed: GemmLoopOrder has exactly these two values.
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
{
if constexpr(UseDefaultScheduler)
{
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp>, sequence<KIterPerWarp>>,
tuple<>,
tuple<>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
return a_block_dstr_encode;
}
else
{
if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN)
{
constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<2, 1>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
return a_block_dstr_encode;
}
else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
{
constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
return a_block_dstr_encode;
}
}
}
// Builds the tile-distribution encoding for the B block tensor.
// Mirrors MakeABlockDistributionEncode with M/N roles swapped; again only
// the y-major order differs between the KMN and MNK branches.
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
{
if constexpr(UseDefaultScheduler)
{
constexpr auto b_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp>, sequence<KIterPerWarp>>,
tuple<>,
tuple<>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
return b_block_dstr_encode;
}
else
{
if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN)
{
constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<2, 1>,
sequence<0, 0>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
return b_block_dstr_encode;
}
else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
{
constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
return b_block_dstr_encode;
}
}
}
// Builds the tile-distribution encoding for the C block tensor.
// C's layout does not depend on BlockGemmLoopOrder — only on the scheduler.
// NOTE(review): MakeCBlockTile duplicates these encodings; keep both in sync.
CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode()
{
if constexpr(UseDefaultScheduler)
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<MWarp>,
tuple<sequence<MIterPerWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<>,
tuple<>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
return c_block_dstr_encode;
}
else
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
return c_block_dstr_encode;
}
}
// C += A * B
// Accumulating block GEMM: slices per-warp tiles out of the block-distributed
// A/B/C tensors and invokes WarpGemm once per (m, n, k) iteration triple.
// The static_asserts verify the caller's tensors carry exactly the
// distributions produced by the Make*BlockDistributionEncode helpers.
template <typename CBlockTensor, typename ABlockTensor, typename BBlockTensor>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ABlockTensor& a_block_tensor,
const BBlockTensor& b_block_tensor) const
{
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
// check ABC-block-distribution
static_assert(
std::is_same_v<remove_cvref_t<decltype(MakeABlockDistributionEncode())>,
remove_cvref_t<decltype(ABlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"A distribution is wrong!");
static_assert(
std::is_same_v<remove_cvref_t<decltype(MakeBBlockDistributionEncode())>,
remove_cvref_t<decltype(BBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"B distribution is wrong!");
static_assert(
std::is_same_v<remove_cvref_t<decltype(MakeCBlockDistributionEncode())>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"C distribution is wrong!");
using AWarpDstr = typename WarpGemm::AWarpDstr;
using BWarpDstr = typename WarpGemm::BWarpDstr;
using CWarpDstr = typename WarpGemm::CWarpDstr;
using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto b_warp_y_lengths =
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
// KMN order: K outermost — each (k, m) A slice is reused across all N
// iterations; A/B slices are indexed <kIter, ...> (K-major, matching the
// sequence<2, 1> encodings above).
// NOTE(review): the V1 counterpart selects the C slice index through a
// TransposeC conditional (sequence<nIter, mIter> vs <mIter, nIter>); this
// V2 KMN path always uses <mIter, nIter> — confirm TransposeC support is
// intentionally dropped here.
if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN)
{
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A Block window
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<kIter, mIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<kIter, nIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
// MNK order: K innermost — the full K accumulation for one (m, n) C tile
// completes before moving on; A/B slices are indexed <mIter/nIter, kIter>
// (matching the sequence<1, 2> encodings above).
else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
{
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
// read A warp tensor from A Block window
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
}
// Creates a C block tensor with the distribution this GEMM expects.
// NOTE(review): the encodings here duplicate MakeCBlockDistributionEncode —
// consider delegating to it so the two cannot drift apart.
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
if constexpr(UseDefaultScheduler)
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<MWarp>,
tuple<sequence<MIterPerWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<>,
tuple<>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
else
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
}
// C = A * B
// Convenience overload: builds the C tile, then accumulates into it.
// NOTE(review): the freshly made C tile is not explicitly cleared here —
// confirm make_static_distributed_tensor zero-initializes, or that callers
// do not rely on C starting at zero.
template <typename ABlockTensor, typename BBlockTensor>
CK_TILE_DEVICE auto operator()(const ABlockTensor& a_block_tensor,
const BBlockTensor& b_block_tensor) const
{
auto c_block_tensor = MakeCBlockTile();
operator()(c_block_tensor, a_block_tensor, b_block_tensor);
return c_block_tensor;
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,45 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// Nesting order of the block-GEMM hot loops (see BlockGemmARegBRegCRegV2):
//   KMN — K outermost, N innermost: an A slice is reused across N iterations.
//   MNK — K innermost: each (M, N) C tile finishes its K accumulation first.
enum struct GemmLoopOrder
{
KMN,
MNK,
};
// Custom policy for BlockGemmARegBRegCRegV2: bundles the A/B/C element types,
// the warp grid (BlockWarps = <M, N, K> warp counts), the warp-level GEMM, and
// the loop nesting order consumed by the V2 block GEMM (defaults to KMN).
template <typename AType_,
typename BType_,
typename CType_,
typename BlockWarps_,
typename WarpGemm_,
GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN>
struct BlockGemmARegBRegCRegV2CustomPolicy
{
using AType = remove_cvref_t<AType_>;
using BType = remove_cvref_t<BType_>;
using CType = remove_cvref_t<CType_>;
using BlockWarps = remove_cvref_t<BlockWarps_>;
static constexpr index_t kMWarps = BlockWarps::at(number<0>{});
static constexpr index_t kNWarps = BlockWarps::at(number<1>{});
// NOTE(review): kKWarps is extracted but not returned by
// GetWarpGemmMWarpNWarp below — confirm whether it is needed at all.
static constexpr index_t kKWarps = BlockWarps::at(number<2>{});
using WarpGemm = remove_cvref_t<WarpGemm_>;
static constexpr auto BlockGemmLoopOrder = BlockGemmLoopOrder_;
// Returns (warp-GEMM instance, M-warp count, N-warp count); the Problem
// parameter is currently unused.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
return make_tuple(WarpGemm{}, kMWarps, kNWarps);
}
};
} // namespace ck_tile

View File

@@ -4,7 +4,6 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace ck_tile {
@@ -14,8 +13,7 @@ template <typename ADataType_,
typename CDataType_,
index_t kBlockSize_,
typename BlockGemmShape_,
GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN,
index_t NumWaveGroups_ = 1>
index_t NumWaveGroups_ = 1>
struct BlockGemmProblem
{
using ADataType = remove_cvref_t<ADataType_>;
@@ -23,9 +21,8 @@ struct BlockGemmProblem
using CDataType = remove_cvref_t<CDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t NumWaveGroups = NumWaveGroups_;
static constexpr GemmLoopOrder BlockGemmLoopOrder = BlockGemmLoopOrder_;
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t NumWaveGroups = NumWaveGroups_;
};
} // namespace ck_tile

View File

@@ -39,12 +39,6 @@ enum struct TailNumber
Full,
};
enum struct GemmLoopOrder
{
KMN,
MNK,
};
} // namespace ck_tile
inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineScheduler& s)

View File

@@ -14,11 +14,10 @@ template <typename ADataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
typename ComputeDataType_ = ADataType_,
bool FixedVectorSize_ = false,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1,
GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN>
typename ComputeDataType_ = ADataType_,
bool FixedVectorSize_ = false,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1>
struct GemmPipelineProblemBase
{
using Traits = remove_cvref_t<Traits_>;
@@ -46,10 +45,9 @@ struct GemmPipelineProblemBase
static constexpr bool kPadN = Traits::kPadN;
static constexpr bool kPadK = Traits::kPadK;
static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
static constexpr GemmLoopOrder BlockGemmLoopOrder = BlockGemmLoopOrder_;
static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
// In the base situation, the Preshuffle setting should be false.
static constexpr bool Preshuffle = false;
@@ -169,11 +167,10 @@ template <typename ADataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
typename ComputeDataType_ = ADataType_,
bool FixedVectorSize_ = false,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1,
GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN>
typename ComputeDataType_ = ADataType_,
bool FixedVectorSize_ = false,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1>
using GemmPipelineProblem = GemmPipelineProblemBase<ADataType_,
BDataType_,
CDataType_,
@@ -182,22 +179,20 @@ using GemmPipelineProblem = GemmPipelineProblemBase<ADataType_,
ComputeDataType_,
FixedVectorSize_,
VectorSizeA_,
VectorSizeB_,
BlockGemmLoopOrder_>;
VectorSizeB_>;
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full,
typename ComputeDataType_ = ADataType_,
bool FixedVectorSize_ = false,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1,
GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN>
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full,
typename ComputeDataType_ = ADataType_,
bool FixedVectorSize_ = false,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1>
struct UniversalGemmPipelineProblem
{
using Traits = remove_cvref_t<Traits_>;
@@ -229,9 +224,8 @@ struct UniversalGemmPipelineProblem
static constexpr auto Scheduler = Scheduler_;
static constexpr bool Preshuffle = Traits::Preshuffle;
static constexpr index_t VectorSizeA = VectorSizeA_;
static constexpr index_t VectorSizeB = VectorSizeB_;
static constexpr GemmLoopOrder BlockGemmLoopOrder = BlockGemmLoopOrder_;
static constexpr index_t VectorSizeA = VectorSizeA_;
static constexpr index_t VectorSizeB = VectorSizeB_;
static constexpr auto HasHotLoop = HasHotLoop_;
static constexpr auto TailNum = TailNum_;