Initial re-implementation of the qr_ks_vs_whole_k_prefetch pipeline, looping Gemm0 along the n0 dimension

This commit is contained in:
Qianfeng Zhang
2025-12-04 09:09:47 +00:00
parent 30727c48fc
commit 5fada1ce99
5 changed files with 1391 additions and 816 deletions

View File

@@ -23,6 +23,44 @@
namespace ck_tile {
namespace detail {
// A helper struct for detecting n0loop
template <typename T, typename = void>
struct has_n0loop_flag : std::false_type
{
};
template <typename T>
struct has_n0loop_flag<
T,
std::enable_if_t<std::is_convertible_v<decltype(T::kUseN0Loop), bool> && T::kUseN0Loop>>
: std::true_type
{
};
template <typename T>
static inline constexpr bool is_n0loop_pipeline_v = has_n0loop_flag<T>::value;
// A helper struct for detecting ignore_fast_exp2 flag
template <typename T, typename = void>
struct has_ignore_fast_exp2_flag : std::false_type
{
};
template <typename T>
struct has_ignore_fast_exp2_flag<
T,
std::enable_if_t<std::is_convertible_v<decltype(T::kIgnoreFastExp2), bool> &&
T::kIgnoreFastExp2>> : std::true_type
{
};
template <typename T>
static inline constexpr bool ignore_fast_exp2_v = has_ignore_fast_exp2_flag<T>::value;
}; // namespace detail
template <typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdKernel
{
@@ -402,7 +440,9 @@ struct FmhaFwdKernel
num_head_q,
nhead_ratio_qk,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast<float>(scale_s * ck_tile::log2e_v<>),
detail::ignore_fast_exp2_v<FmhaPipeline>
? scale_s
: static_cast<float>(scale_s * ck_tile::log2e_v<>),
#else
scale_s,
#endif
@@ -741,7 +781,9 @@ struct FmhaFwdKernel
num_head_q,
nhead_ratio_qk,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast<float>(scale_s * ck_tile::log2e_v<>),
detail::ignore_fast_exp2_v<FmhaPipeline>
? scale_s
: static_cast<float>(scale_s * ck_tile::log2e_v<>),
#else
scale_s,
#endif
@@ -1303,10 +1345,21 @@ struct FmhaFwdKernel
number<1>{});
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
if constexpr(detail::is_n0loop_pipeline_v<FmhaPipeline>)
{
return pad_tensor_view(k_dram_naive,
make_tuple(number<FmhaPipeline::kK1>{},
number<FmhaPipeline::kSubQKHeaddim>{}),
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
}
else
{
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
}
}();
const auto v_dram = [&]() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
@@ -1359,10 +1412,22 @@ struct FmhaFwdKernel
}(),
{i_m0, 0});
auto k_dram_window = make_tile_window(
k_dram,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
{0, 0});
auto k_dram_window = [&]() {
if constexpr(detail::is_n0loop_pipeline_v<FmhaPipeline>)
{
return make_tile_window(k_dram,
make_tuple(number<FmhaPipeline::kK1>{},
number<FmhaPipeline::kSubQKHeaddim>{}),
{0, 0});
}
else
{
return make_tile_window(
k_dram,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
{0, 0});
}
}();
auto v_dram_window = make_tile_window(
v_dram,
@@ -1508,7 +1573,10 @@ struct FmhaFwdKernel
*(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
i_batch_ * kargs.alibi_slope_stride + i_nhead_);
#if CK_TILE_FMHA_FWD_FAST_EXP2
slope *= ck_tile::log2e_v<>;
if constexpr(!detail::ignore_fast_exp2_v<FmhaPipeline>)
{
slope *= ck_tile::log2e_v<>;
}
#endif
if constexpr(kHasMask)
{
@@ -2247,7 +2315,10 @@ struct FmhaFwdKernel
*(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
i_batch_ * kargs.alibi_slope_stride + i_nhead_);
#if CK_TILE_FMHA_FWD_FAST_EXP2
slope *= ck_tile::log2e_v<>;
if constexpr(!detail::ignore_fast_exp2_v<FmhaPipeline>)
{
slope *= ck_tile::log2e_v<>;
}
#endif
if constexpr(kHasMask)
{

View File

@@ -4,17 +4,18 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_k.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_n.hpp"
namespace ck_tile {
struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopy = */ false,
/* NumPrefetchK = */ -1,
/* NumPrefetchV = */ 2>
{
static constexpr index_t NumPrefetchV = 2;
static constexpr bool QLoadOnce = true; // needed by the kernel
static constexpr bool AsyncCopy = false; // needed by the kernel
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t IsPreloadWholeNextIterationK()
@@ -23,30 +24,11 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
};
template <typename Problem>
CK_TILE_DEVICE static constexpr auto GetNumKLdsBuffers()
CK_TILE_DEVICE static constexpr auto GetNumKVLdsBuffers()
{
return 2;
return 4;
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto GetNumPrefetchV()
{
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
constexpr index_t kN0 = BlockFmhaShape::kN0;
constexpr index_t kK1 = BlockFmhaShape::kK1;
constexpr index_t k1_loops = kN0 / kK1;
return min(NumPrefetchV, k1_loops);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetNumVLdsBuffers()
{
return 2;
};
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution()
{
@@ -57,49 +39,268 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
Problem::BlockFmhaShape::kQKHeaddim>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetQKWarpGemmKPerThreadSize()
{
    // Per-thread K extent of the warp-gemm chosen for the Q*K (Gemm0) block-gemm.
    using QKGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
    constexpr auto warp_config = QKGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
    using WarpGemm = remove_cvref_t<decltype(warp_config.template at<0>())>;
    return WarpGemm::WarpGemmAttribute::kKPerThread;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetKVWarpGemmKPerThreadSize()
{
    // Per-thread K extent of the warp-gemm chosen for the P*V (Gemm1) block-gemm.
    using KVGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem>())>;
    constexpr auto warp_config = KVGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
    using WarpGemm = remove_cvref_t<decltype(warp_config.template at<0>())>;
    return WarpGemm::WarpGemmAttribute::kKPerThread;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasDramTileDistribution()
{
    // Bias is added onto Gemm0's accumulator, so load it with the same tile
    // distribution as the Q*K block-gemm's C block tile.
    using QKGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
    constexpr index_t kM0 = Problem::BlockFmhaShape::kM0;
    constexpr index_t kN0 = Problem::BlockFmhaShape::kN0;
    return QKGemm::template MakeCBlockTile<kM0, kN0>().get_tile_distribution();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBias()
{
    // Bias vector width follows the per-lane C (M1) extent of the Gemm0 warp-gemm.
    using QKGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
    constexpr auto warp_config = QKGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
    using WarpGemm = remove_cvref_t<decltype(warp_config.template at<0>())>;
    return WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
{
    // Q load vector width: capped at a 16-byte access and at the K-extent a single
    // lane owns in the Gemm0 warp-gemm (WG::kK split across kABKLane lanes).
    constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
    using QKGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
    constexpr auto warp_config = QKGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
    using WarpGemm = remove_cvref_t<decltype(warp_config.template at<0>())>;
    constexpr index_t KPerLane = WarpGemm::kK / WarpGemm::WarpGemmAttribute::Impl::kABKLane;
    return min(MaxVectorSize, KPerLane);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK()
{
    // LDS K-pack for the K tile, mirroring GetSmemKPackV: pack 8 elements when the
    // Gemm0 warp-gemm consumes at least 8 K-elements per thread, otherwise 4.
    //
    // Fix: the previous body kept an unconditional `return 8 / sizeof(KDataType);`
    // in front of this branch (a stale line from the replaced implementation),
    // which made the `if constexpr` unreachable and broke `auto` return-type
    // deduction (`std::size_t` vs `int`).
    if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
        return 8;
    else
        return 4;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK()
{
    // K load vector width for the (kK1 x kSubQKHeaddim) DRAM tile: bounded by a
    // 16-byte access and by the number of tile elements each thread of the block owns.
    using KDataType = remove_cvref_t<typename Problem::KDataType>;
    constexpr index_t BlockSize = Problem::kBlockSize;
    constexpr index_t TileN     = Problem::BlockFmhaShape::kK1;
    constexpr index_t TileK     = Problem::BlockFmhaShape::kSubQKHeaddim;
    constexpr index_t MaxVec    = 16 / sizeof(KDataType);
    constexpr index_t PerThread = (TileN * TileK) / BlockSize;
    return min(MaxVec, PerThread);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
{
    // LDS K-pack for the V tile: 8 elements when the Gemm1 warp-gemm consumes at
    // least 8 K-elements per thread, otherwise 4.
    if constexpr(GetKVWarpGemmKPerThreadSize<Problem>() >= 8)
    {
        return 8;
    }
    else
    {
        return 4;
    }
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
{
    // V load vector width. Prefer the widest load (up to 16 bytes, bounded by the
    // per-thread element count); fall back to a narrower load when the wide one
    // would leave fewer than kMinVecLoad accesses per thread.
    using VDataType = remove_cvref_t<typename Problem::VDataType>;
    constexpr index_t BlockSize   = Problem::kBlockSize;
    constexpr index_t TileN       = Problem::BlockFmhaShape::kN1;
    constexpr index_t TileK       = Problem::BlockFmhaShape::kK1;
    constexpr index_t PerThread   = TileN * TileK / BlockSize;
    constexpr index_t MaxVec      = 16 / sizeof(VDataType);
    constexpr index_t kMaxVecLoad = min(PerThread, MaxVec);
    constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);
    if constexpr((PerThread / kMaxVecLoad) >= kMinVecLoad)
        return kMaxVecLoad;
    else
        return PerThread / kMinVecLoad;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO()
{
    // O store vector width follows the per-lane C (M1) extent of the Gemm1 warp-gemm.
    using KVGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem>())>;
    constexpr auto warp_config = KVGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
    using WarpGemm = remove_cvref_t<decltype(warp_config.template at<0>())>;
    return WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKSingleSmemElementSpaceSize()
{
    // Element count of one K LDS buffer (kK1 x kSubQKHeaddim tile). The narrow
    // warp-gemm layout adds kKPerBlock * kKPack / kKVector extra elements per
    // buffer — presumably row padding for the LDS layout; see
    // MakeKLdsBlockDescriptor for the matching strides.
    constexpr index_t TileN   = Problem::BlockFmhaShape::kK1;
    constexpr index_t TileK   = Problem::BlockFmhaShape::kSubQKHeaddim;
    constexpr index_t KPack   = GetSmemKPackK<Problem>();
    constexpr index_t KVector = GetAlignmentK<Problem>();
    if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
    {
        static_assert(KVector == KPack);
        return TileK * TileN;
    }
    else
    {
        static_assert(KVector % KPack == 0);
        return TileK * TileN + TileK * KPack / KVector;
    }
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVSingleSmemElementSpaceSize()
{
    // Element count of one V LDS buffer: N0 slabs of (N1 x kK1) elements, each slab
    // extended by kKPack elements (padding between slabs).
    constexpr index_t TileN = Problem::BlockFmhaShape::kN1;
    constexpr index_t TileK = Problem::BlockFmhaShape::kK1;
    constexpr index_t N1    = GetAlignmentV<Problem>();
    constexpr index_t N0    = TileN / N1;
    constexpr index_t Pack  = GetKVWarpGemmKPerThreadSize<Problem>();
    return N0 * (N1 * TileK + Pack);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSingleSmemElementSpaceSize()
{
    // One shared K/V LDS buffer must be large enough for either tile, so take the
    // larger of the two single-buffer sizes.
    constexpr auto k_elems = GetKSingleSmemElementSpaceSize<Problem>();
    constexpr auto v_elems = GetVSingleSmemElementSpaceSize<Problem>();
    return max(k_elems, v_elems);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
{
constexpr index_t NumKLdsBuffers = GetNumKLdsBuffers<Problem>();
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t NumKLdsBuffers = GetNumKVLdsBuffers<Problem>();
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t kKPack = GetSmemKPackK<Problem>();
constexpr index_t kKVector = GetAlignmentK<Problem>();
static_assert(kKVector % kKPack == 0);
if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
{
static_assert(kKVector == kKPack);
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumKLdsBuffers>{},
number<kKPerBlock / kKVector>{},
number<kKVector / kKPack>{},
number<kNPerBlock>{},
number<kKPack>{}),
make_tuple(number<kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector>{},
number<kNPerBlock * kKVector + kKPack>{},
number<kNPerBlock * kKPack>{},
number<kKPack>{},
number<1>{}),
number<kKPack>{},
number<1>{});
using KDataType = remove_cvref_t<typename Problem::KDataType>;
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(
make_merge_transform(make_tuple(number<NumKLdsBuffers>{}, number<kNPerBlock>{})),
make_merge_transform(make_tuple(number<kKPerBlock / kKVector>{},
number<kKVector / kKPack>{},
number<kKPack>{}))),
make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr index_t DataTypeSize = sizeof(KDataType);
return k_lds_block_desc;
// 128 contiguous bytes mapped to 32 banks with each bank 4 contiguous bytes
constexpr auto NLdsLayer =
(32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
constexpr auto k_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<NumKLdsBuffers>{},
number<kNPerBlock / NLdsLayer>{},
number<kKPerBlock / kKPack * NLdsLayer>{},
number<kKPack>{}),
make_tuple(number<kKPerBlock * kNPerBlock>{},
number<kKPerBlock * NLdsLayer>{},
number<kKPack>{},
number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto k_lds_block_desc_permuted = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<NumKLdsBuffers>{}),
make_xor_transform(make_tuple(number<kNPerBlock / NLdsLayer>{},
number<kKPerBlock / kKPack * NLdsLayer>{})),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
constexpr auto k_lds_block_desc_k0_nldslayer_n_k1 = transform_tensor_descriptor(
k_lds_block_desc_permuted,
make_tuple(make_pass_through_transform(number<NumKLdsBuffers>{}),
make_pass_through_transform(number<kNPerBlock / NLdsLayer>{}),
make_unmerge_transform(
make_tuple(number<kKPerBlock / kKPack>{}, number<NLdsLayer>{})),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}));
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
k_lds_block_desc_k0_nldslayer_n_k1,
make_tuple(
make_merge_transform_v3_division_mod(
make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
make_merge_transform_v3_division_mod(make_tuple(number<NumKLdsBuffers>{},
number<kKPerBlock / kKPack>{},
number<kKPack>{}))),
make_tuple(sequence<1, 3>{}, sequence<0, 2, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return k_lds_block_desc;
}
else
{
static_assert(kKVector % kKPack == 0);
constexpr index_t KSingleSmemElementSpaceSize =
kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector;
static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize<Problem>());
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
constexpr auto k_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<NumKLdsBuffers>{},
number<kKPerBlock / kKVector>{},
number<kKVector / kKPack>{},
number<kNPerBlock>{},
number<kKPack>{}),
make_tuple(number<SingleSmemElementSpaceSize>{},
number<kNPerBlock * kKVector + kKPack>{},
number<kNPerBlock * kKPack>{},
number<kKPack>{},
number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(make_merge_transform(
make_tuple(number<NumKLdsBuffers>{}, number<kNPerBlock>{})),
make_merge_transform(make_tuple(number<kKPerBlock / kKVector>{},
number<kKVector / kKPack>{},
number<kKPack>{}))),
make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return k_lds_block_desc;
};
}
template <typename Problem>
@@ -108,8 +309,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
using KDataType = remove_cvref_t<typename Problem::KDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
@@ -136,44 +337,45 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
{
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t NumVLdsBuffers = GetNumKVLdsBuffers<Problem>();
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t NumVLdsBuffers = GetNumVLdsBuffers<Problem>();
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t Banks = get_n_lds_banks();
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
constexpr index_t kKPack = GetSmemKPackV<Problem>();
static_assert(PixelsPerRow % kKPack == 0);
constexpr index_t NPerRow = PixelsPerRow / kKPack;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
static_assert(kNPerBlock % NPerRow == 0);
static_assert(kKPerBlock % kKPack == 0);
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t VSingleSmemElementSpaceSize =
(kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
// K2 is the vector size for storing shuffled tile to LDS
constexpr index_t K2 = ElemPerThread / N1;
// GetSmemKPackV() is the vector size for loading from LDS by BlockGemm
constexpr index_t kKPack = GetSmemKPackV<Problem>();
static_assert(kKPack >= K2, "Check failed!");
constexpr index_t VSingleSmemElementSpaceSize = N0 * (N1 * kKPerBlock + kKPack);
static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize<Problem>());
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumVLdsBuffers>{},
number<kKPerBlock / kKPack>{},
number<kNPerBlock / NPerRow>{},
number<NPerRow>{},
number<kKPack>{}),
make_tuple(number<VSingleSmemElementSpaceSize>{},
number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{},
number<PixelsPerRow + kKPack>{},
number<kKPack>{},
make_tuple(number<NumVLdsBuffers>{}, number<N0>{}, number<N1>{}, number<kKPerBlock>{}),
make_tuple(number<SingleSmemElementSpaceSize>{},
number<N1 * kKPerBlock + kKPack>{},
number<kKPerBlock>{},
number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
v_lds_block_desc_0,
make_tuple(
make_merge_transform(make_tuple(
number<NumVLdsBuffers>{}, number<kNPerBlock / NPerRow>{}, number<NPerRow>{})),
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}),
make_tuple(make_merge_transform(
make_tuple(number<NumVLdsBuffers>{}, number<N0>{}, number<N1>{})),
make_pass_through_transform(number<kKPerBlock>{})),
make_tuple(sequence<0, 1, 2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return v_lds_block_desc;
@@ -182,70 +384,55 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution()
{
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1; // P
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
static_assert(ElemPerThread % N1 == 0);
constexpr index_t K3 = ElemPerThread / N1;
constexpr index_t kKPack = GetSmemKPackV<Problem>();
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3;
if constexpr(get_warp_size() % (K2 * N0) == 0)
{
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = kBlockSize / get_warp_size();
static_assert(kKPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
else
{
constexpr index_t K1 = (K2 * N0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
}
else
{
constexpr index_t K1 = GetAlignmentV<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
constexpr index_t N1 = kBlockSize / get_warp_size();
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
constexpr index_t N0 = kNPerBlock / (N2 * N1);
static_assert(N0 != 0);
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
static_assert(ElemPerThread % N1 == 0);
constexpr index_t K2 = ElemPerThread / N1;
constexpr index_t K1 = get_warp_size() / N0;
constexpr index_t K0 = kBlockSize / get_warp_size();
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2>>,
tuple<sequence<2>, sequence<2, 1>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<2, 1>,
sequence<2, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledVRegTileDistribution()
{
    // Register-tile distribution for the shuffled (row-major) V tile; uses the
    // same N0/N1/K0/K1/K2 factorization as the row-major V DRAM load, with the
    // (ys) dimension order flipped to <1,2>/<1,2>.
    constexpr index_t BlockSize = Problem::kBlockSize;
    constexpr index_t TileN     = Problem::BlockFmhaShape::kN1;
    constexpr index_t TileK     = Problem::BlockFmhaShape::kK1;
    constexpr index_t N1        = GetAlignmentV<Problem>();
    constexpr index_t N0        = TileN / N1;
    constexpr index_t PerThread = TileN * TileK / BlockSize;
    static_assert(PerThread % N1 == 0);
    constexpr index_t K2 = PerThread / N1;
    constexpr index_t K1 = get_warp_size() / N0;
    constexpr index_t K0 = BlockSize / get_warp_size();
    using Encoding = tile_distribution_encoding<sequence<1>,
                                                tuple<sequence<N0, N1>, sequence<K0, K1, K2>>,
                                                tuple<sequence<2>, sequence<2, 1>>,
                                                tuple<sequence<0>, sequence<1, 0>>,
                                                sequence<1, 2>,
                                                sequence<1, 2>>;
    return make_static_tile_distribution(Encoding{});
}
template <typename Problem>
@@ -257,113 +444,163 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
typename Problem::SaccDataType,
Problem::kNumGemm0Warps * get_warp_size(),
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>,
Problem::BlockFmhaShape::kK1,
Problem::BlockFmhaShape::kQKHeaddim>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
constexpr auto warp_gemm = []() {
if constexpr(get_warp_size() == 64 &&
std::is_same_v<typename Problem::QDataType, fp8_t> &&
std::is_same_v<typename Problem::KDataType, fp8_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
auto warp_gemm = [&]() {
if constexpr((std::is_same_v<typename Problem::QDataType, half_t> ||
std::is_same_v<typename Problem::QDataType, bf16_t>)&&std::
is_same_v<typename Problem::SaccDataType, float>)
{
static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32);
static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}) == 32);
static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}) == 32);
constexpr index_t WarpGemmM =
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
constexpr index_t WarpGemmK =
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{});
#ifdef __gfx950__
static_assert((WarpGemmM == 16 && WarpGemmK == 32) ||
(WarpGemmM == 32 && WarpGemmK == 16),
"Not supported WarpGemm sizes!");
#else
static_assert((WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32)) ||
(WarpGemmM == 32 && (WarpGemmK == 8 || WarpGemmK == 16)),
"Not supported WarpGemm sizes!");
#endif
// TODO: hard coded here. Otherwise, it produces incorrect results
constexpr index_t swizzle_factor = 4;
return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<
swizzle_factor>{};
}
else
{
constexpr bool SwizzleA =
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32;
return WarpGemmDispatcher<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
true, // TransposeC
SwizzleA>{};
true,
false,
false,
WGAttrNumAccessEnum::Single>{};
}
else
{
static_assert(false, "Not supported data types!");
}
}();
using WarpGemm = remove_cvref_t<decltype(warp_gemm)>;
using BlockGemmPolicy =
BlockGemmARegBSmemCRegV2CustomPolicy<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
decltype(warp_gemm)>;
WarpGemm>;
if constexpr(1 < Problem::kNumGemm0Warps)
return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
return BlockGemmARegBSmemCRegV2PrefetchK<GemmProblem, BlockGemmPolicy>{};
else
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
}
// leave some exclusive space so that the second v_lds buffer will never overlap with the first
// k_lds buffer
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetExclusiveKLdsBytes()
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
{
constexpr index_t single_k_lds_buffer_size =
GetSmemSizeK<Problem>() / GetNumKLdsBuffers<Problem>();
constexpr index_t single_v_lds_buffer_size =
GetSmemSizeV<Problem>() / GetNumVLdsBuffers<Problem>();
using GemmProblem =
BlockGemmProblem<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::kNumGemm1Warps * get_warp_size(),
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN1,
Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
if constexpr(single_k_lds_buffer_size <= single_v_lds_buffer_size)
return 0;
auto warp_gemm = [&]() {
if constexpr((std::is_same_v<typename Problem::VDataType, half_t> ||
std::is_same_v<typename Problem::VDataType, bf16_t>)&&std::
is_same_v<typename Problem::OaccDataType, float>)
{
constexpr index_t WarpGemmM =
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
constexpr index_t WarpGemmK =
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{});
#ifdef __gfx950__
static_assert((WarpGemmM == 16 && WarpGemmK == 32) ||
(WarpGemmM == 32 && WarpGemmK == 16),
"Not supported WarpGemm sizes!");
#else
static_assert((WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32)) ||
(WarpGemmM == 32 && (WarpGemmK == 8 || WarpGemmK == 16)),
"Not supported WarpGemm sizes!");
#endif
if constexpr((WarpGemmM == 16 && WarpGemmK == 32) ||
(WarpGemmM == 32 && WarpGemmK == 16))
return WarpGemmDispatcher<
typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
true,
false,
false,
WGAttrNumAccessEnum::Double>{};
else
return WarpGemmDispatcher<
typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
true,
false,
false,
WGAttrNumAccessEnum::Single>{};
}
else
{
static_assert(false, "Not supported data types!");
}
}();
using WarpGemm = remove_cvref_t<decltype(warp_gemm)>;
using BlockGemmPolicy =
BlockGemmARegBSmemCRegV2CustomPolicy<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::OaccDataType,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
WarpGemm>;
if constexpr(1 < Problem::kNumGemm1Warps)
return BlockGemmARegBSmemCRegV2PrefetchN<GemmProblem, BlockGemmPolicy>{};
else
return integer_least_multiple(single_k_lds_buffer_size - single_v_lds_buffer_size, 64);
};
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t IsFirstKLdsBufferOverlapLastVLdsBuffer()
{
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
constexpr index_t k1_loops = BlockFmhaShape::kN0 / BlockFmhaShape::kK1;
constexpr index_t num_k_lds_buffers = GetNumKLdsBuffers<Problem>();
constexpr index_t num_v_lds_buffers = GetNumVLdsBuffers<Problem>();
constexpr index_t last_v_lds_buffer_offset =
MakeVLdsBlockDescriptor<Problem>().get_element_space_size() / num_v_lds_buffers *
((k1_loops - 1) % num_v_lds_buffers) * sizeof(typename Problem::VDataType);
constexpr index_t first_k_lds_buffer_size =
MakeKLdsBlockDescriptor<Problem>().get_element_space_size() / num_k_lds_buffers *
sizeof(typename Problem::KDataType);
return GetExclusiveKLdsBytes<Problem>() + last_v_lds_buffer_offset <
first_k_lds_buffer_size;
};
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK()
{
return MakeKLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::KDataType);
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV()
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV()
{
return MakeVLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::VDataType);
}
constexpr index_t num_kv_lds_buffers = GetNumKVLdsBuffers<Problem>();
return num_kv_lds_buffers * GetSingleSmemElementSpaceSize<Problem>() *
max(sizeof(typename Problem::KDataType), sizeof(typename Problem::VDataType));
};
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeDropout()
{
    // This pipeline reserves no LDS for dropout.
    return 0;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
// assume V can reuse the other shared memory by K except the first
// assume Dropout can reuse the shared memory by V
return GetExclusiveKLdsBytes<Problem>() +
max(GetSmemSizeK<Problem>() - GetExclusiveKLdsBytes<Problem>(),
max(GetSmemSizeV<Problem>(), GetSmemSizeDropout<Problem>(0)));
return GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>();
}
};

View File

@@ -0,0 +1,299 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
namespace ck_tile {
// Block-level GEMM where:
//   A is a block distributed tensor (held in registers),
//   B is a block window on shared memory (LDS),
//   C is a block distributed tensor (held in registers).
// "PrefetchK" variant: within each N iteration, the B warp tile of K
// iteration k+1 is loaded from LDS before the warp GEMMs of K iteration k
// are issued, so the LDS reads can overlap the warp-GEMM work.
template <typename Problem_, typename Policy_ = BlockGemmARegBSmemCRegV2DefaultPolicy>
struct BlockGemmARegBSmemCRegV2PrefetchK
{
    using Problem        = remove_cvref_t<Problem_>;
    using Policy         = remove_cvref_t<Policy_>;
    using ADataType      = remove_cvref_t<typename Problem::ADataType>;
    using BDataType      = remove_cvref_t<typename Problem::BDataType>;
    using CDataType      = remove_cvref_t<typename Problem::CDataType>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    static constexpr index_t kBlockSize = Problem::kBlockSize;

    // C += A * B
    // NOTE(review): on the first K iteration the C = A * B warp-gemm overload
    // is used, so any prior contents of c_block_tensor are overwritten rather
    // than accumulated into. That is consistent with the two-argument
    // operator() below (which passes in a freshly created C tile), but
    // contradicts the "C +=" comment — confirm which contract callers rely on.
    template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
    CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
                                   const ABlockTensorTmp& a_block_tensor_tmp,
                                   const BBlockWindowTmp& b_block_window_tmp) const
    {
        // The supplied tensors/windows must carry the Problem's data types.
        static_assert(
            std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
                std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
                std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
            "wrong!");

        constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
        constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
        constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];

        static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
                          KPerBlock == BlockGemmShape::kK,
                      "wrong!");

        // Warp-gemm type and the M/N warp partition chosen by the policy.
        constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();

        using WG = remove_cvref_t<decltype(config.template at<0>())>;

        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();

        // Warp-gemm iteration counts per dimension for one warp.
        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
        constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
        constexpr index_t KIterPerWarp = KPerBlock / WG::kK;

        // Step sizes (in elements) for moving the B warp window per iteration.
        constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
        constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;

        const index_t iNWarp = get_warp_id<false>() % NWarp;

        // This variant is only written for a single warp along N.
        static_assert(NWarp == 1, "Check failed!");

        constexpr auto c_block_outer_dstr_encoding =
            tile_distribution_encoding<sequence<NWarp>,
                                       tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp>>,
                                       tuple<sequence<1, 0>>,
                                       tuple<sequence<1, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};

        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});

        // construct A-block-tensor from A-block-tensor-tmp
        // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
        // distribution
        auto a_block_tensor = make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(
            MakeABlockTileDistribution());

        a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();

        // construct B-warp-window: a WG::kN x WG::kK view into the B block
        // window, positioned at this warp's N offset.
        auto b_warp_window_tmp = make_tile_window(
            b_block_window_tmp.get_bottom_tensor_view(),
            make_tuple(number<WG::kN>{}, number<WG::kK>{}),
            b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
            make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));

#if 0 // FIXME: using array will cause register spill
        array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
            {b_warp_window_tmp}};

        for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
        {
            for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
            {
                move_tile_window(b_warp_windows(nIter)(kIter),
                                 {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
            }
        }
#else
        // Per-(n,k) warp windows; positioned lazily inside the hot loop
        // (the #if 0 branch above notes the eager form spills registers).
        statically_indexed_array<
            statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
            NIterPerWarp>
            b_warp_windows;
#endif

        // check C-block-distribution
        static_assert(
            std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
                           remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
                                                       .get_static_tile_distribution_encoding())>>,
            "wrong!");

        using AWarpDstr = typename WG::AWarpDstr;
        using CWarpDstr = typename WG::CWarpDstr;

        using AWarpTensor = typename WG::AWarpTensor;
        using CWarpTensor = typename WG::CWarpTensor;

        constexpr auto a_warp_y_lengths =
            to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto c_warp_y_lengths =
            to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());

        constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
        constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

        constexpr auto I0 = number<0>{};
        constexpr auto I1 = number<1>{};

        // hot loop:
        static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
            using b_warp_tensor_type = decltype(load_tile(b_warp_windows(I0)(I0)));

            statically_indexed_array<b_warp_tensor_type, KIterPerWarp> b_warp_tensors;

            // Preload the B tiles of the first two K iterations.
            // NOTE(review): k-iteration 1 is loaded unconditionally below,
            // which assumes KIterPerWarp >= 2 — confirm for all instantiations.
            b_warp_windows(nIter)(I0) = b_warp_window_tmp;
            move_tile_window(b_warp_windows(nIter)(I0),
                             {nIter * NPerBlockPerIter, 0 * KPerBlockPerIter});
            b_warp_tensors[I0] = load_tile(b_warp_windows(nIter)(I0));

            // Keep the scheduler from reordering across the load group.
            __builtin_amdgcn_sched_barrier(0);

            b_warp_windows(nIter)(I1) = b_warp_window_tmp;
            move_tile_window(b_warp_windows(nIter)(I1),
                             {nIter * NPerBlockPerIter, 1 * KPerBlockPerIter});
            b_warp_tensors[I1] = load_tile(b_warp_windows(nIter)(I1));

            __builtin_amdgcn_sched_barrier(0);

            // First K iteration: compute C tiles from A and the preloaded B
            // (overwrite form — see NOTE on this operator()).
            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                // read A warp tensor from A block tensor
                AWarpTensor a_warp_tensor;

                a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
                    merge_sequences(sequence<mIter, 0>{}, a_warp_y_index_zeros),
                    merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));

                // warp GEMM (C = A * B overload)
                auto c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensors[I0]);
                // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);

                // write C warp tensor into C block tensor
                c_block_tensor.set_y_sliced_thread_data(
                    merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                    merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                    c_warp_tensor.get_thread_buffer());
            });

            // Remaining K iterations: prefetch B of iteration k+1 (if any),
            // then accumulate iteration k into C.
            static_for<1, KIterPerWarp, 1>{}([&](auto kIter) {
                // read B warp tensor from B Block window
                if constexpr(kIter < KIterPerWarp - 1)
                {
                    b_warp_windows(nIter)(number<kIter + 1>{}) = b_warp_window_tmp;
                    move_tile_window(b_warp_windows(nIter)(number<kIter + 1>{}),
                                     {nIter * NPerBlockPerIter, (kIter + 1) * KPerBlockPerIter});

                    b_warp_tensors[number<kIter + 1>{}] =
                        load_tile(b_warp_windows(nIter)(number<kIter + 1>{}));
                };

                __builtin_amdgcn_sched_barrier(0);

                static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                    // read A warp tensor from A block tensor
                    AWarpTensor a_warp_tensor;

                    a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));

                    // read C warp tensor from C block tensor
                    CWarpTensor c_warp_tensor;

                    c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

                    // warp GEMM (accumulating overload)
                    WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[kIter]);
                    // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);

                    // write C warp tensor into C block tensor
                    c_block_tensor.set_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                        c_warp_tensor.get_thread_buffer());
                });
            });
        });
    }

    // Distribution encoding of the A block tile (M x K): the warp-gemm A
    // encoding embedded into the block-level M/K partition.
    template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
    CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
    {
        constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();

        using WG = remove_cvref_t<decltype(config.template at<0>())>;

        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();

        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
        constexpr index_t KIterPerWarp = KPerBlock / WG::kK;

        constexpr auto a_block_outer_dstr_encoding =
            tile_distribution_encoding<sequence<NWarp>,
                                       tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
                                       tuple<sequence<1, 0>>,
                                       tuple<sequence<1, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};

        constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});

        return a_block_dstr_encode;
    }

    // Static tile distribution built from the A encoding above.
    template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
    CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution()
    {
        constexpr auto a_block_dstr_encode = MakeABlockDistributionEncode<MPerBlock, KPerBlock>();

        return make_static_tile_distribution(a_block_dstr_encode);
    }

    // Distribution encoding of the C block tile (M x N); like operator(),
    // only NWarp == 1 is supported.
    template <index_t MPerBlock = BlockGemmShape::kM, index_t NPerBlock = BlockGemmShape::kN>
    CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode()
    {
        constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();

        using WG = remove_cvref_t<decltype(config.template at<0>())>;

        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();

        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
        constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);

        static_assert(NWarp == 1, "Check failed!");

        constexpr auto c_block_outer_dstr_encoding =
            tile_distribution_encoding<sequence<NWarp>,
                                       tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp>>,
                                       tuple<sequence<1, 0>>,
                                       tuple<sequence<1, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};

        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});

        return c_block_dstr_encode;
    }

    // Create a C block tile with the distribution above. The tile's thread
    // buffer is not visibly initialized here.
    template <index_t MPerBlock = BlockGemmShape::kM, index_t NPerBlock = BlockGemmShape::kN>
    CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
    {
        constexpr auto c_block_dstr_encode = MakeCBlockDistributionEncode<MPerBlock, NPerBlock>();

        constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);

        auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);

        return c_block_tensor;
    }

    // C = A * B: create a fresh C tile and run the three-argument overload
    // (whose first K iteration overwrites C, so no zero-fill is required).
    template <typename ABlockTensorTmp, typename BBlockWindowTmp>
    CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
                                   const BBlockWindowTmp& b_block_window_tmp) const
    {
        auto c_block_tensor = MakeCBlockTile();

        operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);

        return c_block_tensor;
    }
};
} // namespace ck_tile

View File

@@ -0,0 +1,242 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
namespace ck_tile {
// Block-level GEMM where:
//   A is a block distributed tensor (held in registers),
//   B is a block window on shared memory (LDS),
//   C is a block distributed tensor (held in registers).
// "PrefetchN" variant: all K warp tiles of N iteration n+1 are loaded from
// LDS before the warp GEMMs of N iteration n are issued, so the LDS reads
// can overlap the warp-GEMM work.
template <typename Problem_, typename Policy_ = BlockGemmARegBSmemCRegV2DefaultPolicy>
struct BlockGemmARegBSmemCRegV2PrefetchN
{
    using Problem        = remove_cvref_t<Problem_>;
    using Policy         = remove_cvref_t<Policy_>;
    using ADataType      = remove_cvref_t<typename Problem::ADataType>;
    using BDataType      = remove_cvref_t<typename Problem::BDataType>;
    using CDataType      = remove_cvref_t<typename Problem::CDataType>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    static constexpr index_t kBlockSize = Problem::kBlockSize;

    // C += A * B (always accumulates: every warp GEMM below reads the C tile
    // first and uses the accumulating overload).
    template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
    CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
                                   const ABlockTensorTmp& a_block_tensor_tmp,
                                   const BBlockWindowTmp& b_block_window_tmp) const
    {
        // The supplied tensors/windows must carry the Problem's data types.
        static_assert(
            std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
                std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
                std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
            "wrong!");

        constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
        constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
        constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];

        static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
                          KPerBlock == BlockGemmShape::kK,
                      "wrong!");

        // Warp-gemm type and the M/N warp partition chosen by the policy.
        constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();

        using WG = remove_cvref_t<decltype(config.template at<0>())>;

        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();

        // Warp-gemm iteration counts per dimension for one warp.
        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
        constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
        constexpr index_t KIterPerWarp = KPerBlock / WG::kK;

        // Step sizes (in elements) for moving the B warp window per iteration.
        constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
        constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;

        const index_t iNWarp = get_warp_id<false>() % NWarp;

        constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
            sequence<>,
            tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
            tuple<sequence<1, 2>>,
            tuple<sequence<1, 1>>,
            sequence<1, 2>,
            sequence<0, 0>>{};

        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});

        // construct A-block-tensor from A-block-tensor-tmp
        // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
        // distribution
        auto a_block_tensor = make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(
            MakeABlockTileDistribution());

        a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();

        // construct B-warp-window: a WG::kN x WG::kK view into the B block
        // window, positioned at this warp's N offset.
        auto b_warp_window_tmp = make_tile_window(
            b_block_window_tmp.get_bottom_tensor_view(),
            make_tuple(number<WG::kN>{}, number<WG::kK>{}),
            b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
            make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));

        // Per-(n,k) warp windows; positioned as each n-iteration is
        // prefetched below.
        statically_indexed_array<
            statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
            NIterPerWarp>
            b_warp_windows;

        // check C-block-distribution
        static_assert(
            std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
                           remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
                                                       .get_static_tile_distribution_encoding())>>,
            "wrong!");

        using AWarpDstr = typename WG::AWarpDstr;
        using CWarpDstr = typename WG::CWarpDstr;

        using AWarpTensor = typename WG::AWarpTensor;
        using CWarpTensor = typename WG::CWarpTensor;

        constexpr auto a_warp_y_lengths =
            to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto c_warp_y_lengths =
            to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());

        constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
        constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

        constexpr auto I0 = number<0>{};

        using b_warp_tensor_type = decltype(load_tile(b_warp_windows(I0)(I0)));

        statically_indexed_array<statically_indexed_array<b_warp_tensor_type, KIterPerWarp>,
                                 NIterPerWarp>
            b_warp_tensors;

        // Preload every K slice of the first N iteration.
        static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
            b_warp_windows(I0)(kIter) = b_warp_window_tmp;
            move_tile_window(b_warp_windows(I0)(kIter),
                             {0 * NPerBlockPerIter, kIter * KPerBlockPerIter});

            b_warp_tensors(I0)(kIter) = load_tile(b_warp_windows(I0)(kIter));
        });

        // Keep the scheduler from reordering across the load group.
        __builtin_amdgcn_sched_barrier(0);

        static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
            // Prefetch all K slices of the next N iteration (if any) before
            // computing the current one.
            if constexpr(nIter < NIterPerWarp - 1)
            {
                static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                    b_warp_windows(number<nIter + 1>{})(kIter) = b_warp_window_tmp;
                    move_tile_window(b_warp_windows(number<nIter + 1>{})(kIter),
                                     {(nIter + 1) * NPerBlockPerIter, kIter * KPerBlockPerIter});

                    b_warp_tensors(number<nIter + 1>{})(kIter) =
                        load_tile(b_warp_windows(number<nIter + 1>{})(kIter));
                });
            };

            __builtin_amdgcn_sched_barrier(0);

            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                // read C warp tensor from C block tensor
                CWarpTensor c_warp_tensor;

                c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
                    merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                    merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

                // Accumulate the whole K extent into the local C warp tile
                // before writing it back once.
                static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                    // read A warp tensor from A block tensor
                    AWarpTensor a_warp_tensor;

                    a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));

                    // warp GEMM (accumulating overload)
                    WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter][kIter]);
                });

                // write C warp tensor into C block tensor
                c_block_tensor.set_y_sliced_thread_data(
                    merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                    merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                    c_warp_tensor.get_thread_buffer());
            });
        });
    }

    // Static tile distribution of the A block tile (M x K): the warp-gemm A
    // encoding embedded into the block-level M/K partition.
    template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
    CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution()
    {
        constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();

        using WG = remove_cvref_t<decltype(config.template at<0>())>;

        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();

        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
        constexpr index_t KIterPerWarp = KPerBlock / WG::kK;

        constexpr auto a_block_outer_dstr_encoding =
            tile_distribution_encoding<sequence<NWarp>,
                                       tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
                                       tuple<sequence<1, 0>>,
                                       tuple<sequence<1, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};

        constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});

        return make_static_tile_distribution(a_block_dstr_encode);
    }

    // Create a C block tile (M x N) with the C distribution encoding.
    // NOTE(review): the tile's thread buffer is not visibly initialized here,
    // yet the three-argument operator() above always accumulates into C —
    // confirm make_static_distributed_tensor zero-fills, or that callers of
    // the two-argument operator() tolerate/clear the initial contents.
    CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
    {
        constexpr index_t MPerBlock = BlockGemmShape::kM;
        constexpr index_t NPerBlock = BlockGemmShape::kN;

        constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();

        using WG = remove_cvref_t<decltype(config.template at<0>())>;

        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();

        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
        constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
        // constexpr index_t KIterPerWarp = KPerBlock / WG::kK;

        constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
            sequence<>,
            tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
            tuple<sequence<1, 2>>,
            tuple<sequence<1, 1>>,
            sequence<1, 2>,
            sequence<0, 0>>{};

        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});

        constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);

        auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);

        return c_block_tensor;
    }

    // C = A * B: create a fresh C tile and run the three-argument overload
    // (see the NOTE on MakeCBlockTile about C's initial contents).
    template <typename ABlockTensorTmp, typename BBlockWindowTmp>
    CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
                                   const BBlockWindowTmp& b_block_window_tmp) const
    {
        auto c_block_tensor = MakeCBlockTile();

        operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);

        return c_block_tensor;
    }
};
} // namespace ck_tile