mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
compile pass
This commit is contained in:
@@ -41,10 +41,10 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz
|
||||
{
|
||||
buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
|
||||
int32x4_t r = __builtin_bit_cast(int32x4_t, res);
|
||||
r.x = __builtin_amdgcn_readfirstlane(r.x);
|
||||
r.y = __builtin_amdgcn_readfirstlane(r.y);
|
||||
r.z = __builtin_amdgcn_readfirstlane(r.z);
|
||||
r.w = __builtin_amdgcn_readfirstlane(r.w);
|
||||
// r.x = __builtin_amdgcn_readfirstlane(r.x);
|
||||
// r.y = __builtin_amdgcn_readfirstlane(r.y);
|
||||
// r.z = __builtin_amdgcn_readfirstlane(r.z);
|
||||
// r.w = __builtin_amdgcn_readfirstlane(r.w);
|
||||
return r;
|
||||
}
|
||||
|
||||
|
||||
@@ -32,10 +32,10 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz
|
||||
{
|
||||
buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
|
||||
int32x4_t r = __builtin_bit_cast(int32x4_t, res);
|
||||
r.x = __builtin_amdgcn_readfirstlane(r.x);
|
||||
r.y = __builtin_amdgcn_readfirstlane(r.y);
|
||||
r.z = __builtin_amdgcn_readfirstlane(r.z);
|
||||
r.w = __builtin_amdgcn_readfirstlane(r.w);
|
||||
// r.x = __builtin_amdgcn_readfirstlane(r.x);
|
||||
// r.y = __builtin_amdgcn_readfirstlane(r.y);
|
||||
// r.z = __builtin_amdgcn_readfirstlane(r.z);
|
||||
// r.w = __builtin_amdgcn_readfirstlane(r.w);
|
||||
return r;
|
||||
}
|
||||
|
||||
|
||||
@@ -82,7 +82,8 @@ CK_TILE_DEVICE index_t get_lane_id() { return __lane_id(); }
|
||||
|
||||
CK_TILE_DEVICE index_t get_warp_id()
|
||||
{
|
||||
return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size());
|
||||
// return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size());
|
||||
return threadIdx.x / get_warp_size();
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; }
|
||||
@@ -126,7 +127,7 @@ template <index_t vmcnt>
|
||||
CK_TILE_DEVICE void block_sync_lds_direct_load()
|
||||
{
|
||||
// We don't sync the lds insts here.
|
||||
__builtin_amdgcn_s_waitcnt(CK_TILE_S_CNT_MAX & CK_TILE_VMCNT(vmcnt));
|
||||
__builtin_amdgcn_s_waitcnt(CK_TILE_VMCNT(vmcnt));
|
||||
__builtin_amdgcn_s_barrier();
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "ck_tile/core/container/statically_indexed_array.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/debug.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -116,9 +117,18 @@ struct DefaultTranspose
|
||||
template <typename InDstrEncode, bool ReverseDirection, index_t LaneGroupSize>
|
||||
struct ValidationTraitsImpl
|
||||
{
|
||||
using QuadEncoding = std::conditional_t<ReverseDirection,
|
||||
using QuadEncoding = std::conditional_t<ReverseDirection,
|
||||
QuadOutputEncoding<LaneGroupSize>,
|
||||
QuadInputEncoding<LaneGroupSize>>;
|
||||
/*
|
||||
using OutputEncoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<LaneGroupSize>, sequence<4>>,
|
||||
tuple<sequence<1>>,
|
||||
tuple<sequence<0>>,
|
||||
sequence<2>,
|
||||
sequence<0>>;
|
||||
*/
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto input_hs = InDstrEncode::hs_lengthss_;
|
||||
@@ -130,6 +140,7 @@ struct DefaultTranspose
|
||||
util::is_sequence_suffix_v<decltype(quad_hs[I0]), decltype(input_hs[I0])>;
|
||||
static constexpr bool suffix_valid_dim1 =
|
||||
util::is_sequence_suffix_v<decltype(quad_hs[I1]), decltype(input_hs[I1])>;
|
||||
// using bbb = decltype(CK_PRINT<decltype(quad_hs[I1]), decltype(input_hs[I1])>());
|
||||
|
||||
// 3. PS→RHS mapping constraints
|
||||
static constexpr auto input_ps_major = InDstrEncode::ps_to_rhss_major_;
|
||||
@@ -169,6 +180,8 @@ struct DefaultTranspose
|
||||
static constexpr bool ys_mapping_valid =
|
||||
(input_ys_major.back() == 2) && (input_ys_minor.back() == input_hs[I1].size() - 1);
|
||||
|
||||
// using aaa = decltype(CK_PRINT<dims_valid, suffix_valid_dim0, suffix_valid_dim1,
|
||||
// ps_mapping_valid, ys_mapping_valid>());
|
||||
static constexpr bool value = dims_valid && suffix_valid_dim0 && suffix_valid_dim1 &&
|
||||
ps_mapping_valid && ys_mapping_valid;
|
||||
};
|
||||
|
||||
@@ -588,8 +588,8 @@ struct FmhaFwdDecodeKernel
|
||||
// divide problem
|
||||
const auto [i_tile_m, i_tile_n, i_split, i_nhead, i_batch] = GetTileIndex(kargs);
|
||||
|
||||
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
|
||||
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
|
||||
const index_t i_m0 = i_tile_m * FmhaPipeline::kM0;
|
||||
const index_t i_n1 = i_tile_n * FmhaPipeline::kN1;
|
||||
|
||||
long_index_t batch_offset_q = 0;
|
||||
long_index_t batch_offset_k = 0; // unused for paged-kvcache
|
||||
@@ -783,46 +783,21 @@ struct FmhaFwdDecodeKernel
|
||||
}();
|
||||
|
||||
const auto make_v_dram = [&](const VDataType* data, index_t length) {
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
// We don't expect V data reuse among different blocks in decode case.
|
||||
const auto v_dram_naive =
|
||||
make_naive_tensor_view<address_space_enum::global,
|
||||
memory_operation_enum::set,
|
||||
amd_buffer_coherence_enum::SYSTEM_NT1>(
|
||||
data, // will update this pointer if using paged-kvcache
|
||||
make_tuple(length, kargs.hdim_v),
|
||||
make_tuple(kargs.stride_v, 1),
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global,
|
||||
memory_operation_enum::set,
|
||||
amd_buffer_coherence_enum::SYSTEM_NT1>(
|
||||
data, // will update this pointer if using paged-kvcache
|
||||
make_tuple(length, kargs.hdim_v),
|
||||
make_tuple(kargs.stride_v, 1),
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
|
||||
const auto v_dram_transposed =
|
||||
transform_tensor_view(v_dram_naive,
|
||||
make_tuple(make_pass_through_transform(kargs.hdim_v),
|
||||
make_pass_through_transform(length)),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return pad_tensor_view(
|
||||
v_dram_transposed,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
|
||||
sequence<kPadHeadDimV, kPadSeqLenK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
data, // will update this pointer if using paged-kvcache
|
||||
make_tuple(kargs.hdim_v, length),
|
||||
make_tuple(kargs.stride_v, 1),
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
v_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
|
||||
sequence<false, kPadSeqLenK>{});
|
||||
}
|
||||
return pad_tensor_view(
|
||||
v_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kK1>{}, number<FmhaPipeline::kN1>{}),
|
||||
sequence<kPadSeqLenK, false>{});
|
||||
};
|
||||
|
||||
const auto v_dram = [&]() {
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
|
||||
@@ -172,7 +172,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetPVBlockGemm<Problem>();
|
||||
|
||||
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
auto s_acc = SaccBlockTileType{};
|
||||
@@ -298,8 +298,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
|
||||
// V tile in LDS
|
||||
auto [i_page_block_v, v_dram_block_window] = v_page_block_navigator.make_tile_window(
|
||||
v_dram_block_window_lengths,
|
||||
{0, aligned_physical_seqlen_k_start});
|
||||
v_dram_block_window_lengths, {0, aligned_physical_seqlen_k_start});
|
||||
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram_block_window, Policy::template MakeVDramTileDistribution<Problem>());
|
||||
@@ -350,8 +349,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
k_dram_window = make_tile_window(k_dram_block_window,
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
constexpr index_t k_vmem_insts = k_dram_window.get_num_of_access();
|
||||
constexpr index_t v_vmem_insts = v_dram_window.get_num_of_access();
|
||||
// constexpr index_t k_vmem_insts = k_dram_window.get_num_of_access();
|
||||
// constexpr index_t v_vmem_insts = v_dram_window.get_num_of_access();
|
||||
|
||||
do
|
||||
{
|
||||
@@ -373,10 +372,10 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
async_load_tile(v_lds_write_window, v_dram_window); // prefetch load v tile
|
||||
// move V tile windows
|
||||
i_page_block_v =
|
||||
v_page_block_navigator.move_tile_window(i_page_block_v, v_dram_window, {0, kK1});
|
||||
v_page_block_navigator.move_tile_window(i_page_block_v, v_dram_window, {kK1, 0});
|
||||
|
||||
// CK_PRINT<decltype(v_dram_window.get_num_of_access())>();
|
||||
block_sync_lds_direct_load<v_vmem_insts>();
|
||||
// block_sync_lds_direct_load<v_vmem_insts>();
|
||||
auto k_tile = load_tile(k_lds_read_window);
|
||||
|
||||
gemm_0(
|
||||
@@ -509,7 +508,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
// Otherwise shuffle through LDS so that the tile layout is consistent with required by
|
||||
// Gemm1
|
||||
auto s_new = [&]() {
|
||||
if constexpr(!((kNWarp == 1) && (kNXdl == 32)))
|
||||
if constexpr(kNWarp > 1)
|
||||
{
|
||||
auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
|
||||
|
||||
@@ -589,7 +588,9 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
|
||||
|
||||
const auto p = cast_tile<PDataType>(p_compute);
|
||||
auto p_tile = make_static_distributed_tensor<PDataType>(
|
||||
Policy::template MakePRegTileDistribution<Problem>());
|
||||
p_tile.get_thread_buffer() = cast_tile<PDataType>(p_compute).get_thread_buffer();
|
||||
|
||||
// l{j}, Oacc{j}
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
@@ -628,15 +629,15 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds_direct_load<k_vmem_insts>();
|
||||
// block_sync_lds_direct_load<k_vmem_insts>();
|
||||
auto v_tile = load_tile_transpose(v_lds_read_window);
|
||||
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(
|
||||
p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, k1_loops * kK1>{}),
|
||||
get_slice_tile(v_tile,
|
||||
sequence<0, (k1_loops - 1) * kK1>{},
|
||||
sequence<kN1, k1_loops * kK1>{}));
|
||||
gemm_1(
|
||||
o_acc,
|
||||
get_slice_tile(
|
||||
p_tile, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, k1_loops * kK1>{}),
|
||||
get_slice_tile(
|
||||
v_tile, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kN1, k1_loops * kK1>{}));
|
||||
|
||||
} while(++i_total_loops < num_total_loop);
|
||||
|
||||
|
||||
@@ -14,18 +14,6 @@
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
// Use `CK_PRINT<T1, T2, ...>()` to inspect values of type T1, T2, ...
|
||||
// Use `CK_PRINT<v1, v2, ...>()` to inspect constexpr values of val1, val2, ... of the same type
|
||||
// In a non-evaluated context, you can use `using _dummy = decltype(CK_PRINT<...>());`
|
||||
// Set BUILD_DEV to OFF to avoid enabling Werror
|
||||
template <auto... val>
|
||||
[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT()
|
||||
{
|
||||
}
|
||||
template <typename... type>
|
||||
[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT()
|
||||
{
|
||||
}
|
||||
// This pipeline is qkv all located in LDS
|
||||
struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
|
||||
@@ -61,6 +49,34 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
return static_cast<index_t>(16 / sizeof(OaccDataType));
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK()
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::KDataType);
|
||||
|
||||
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
|
||||
static_assert(0 < ElemPerThread);
|
||||
return min(ElemPerThread, MaxVectorSize);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::VDataType);
|
||||
|
||||
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
|
||||
static_assert(0 < ElemPerThread);
|
||||
return min(ElemPerThread, MaxVectorSize);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
|
||||
{
|
||||
@@ -93,7 +109,33 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution()
|
||||
{
|
||||
return BasePolicy::template MakeQRegTileDistribution<Problem>();
|
||||
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
|
||||
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
|
||||
constexpr auto q_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
q_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode);
|
||||
|
||||
return q_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -107,31 +149,54 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
|
||||
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
|
||||
static_assert(0 < ElemPerThread);
|
||||
constexpr index_t kKPack = min(ElemPerThread, GetSmemKPackQ<Problem>());
|
||||
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
|
||||
|
||||
constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
|
||||
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto q_lds_block_desc = transform_tensor_descriptor(
|
||||
q_lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<kMPerBlock>{}),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
constexpr auto q_lds_block_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
make_tuple(number<kKPerBlock>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
return q_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
|
||||
constexpr auto k_lds_block_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
make_tuple(number<kKPerBlock>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
return k_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
|
||||
constexpr auto v_lds_block_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(number<kKPerBlock>{}, number<kNPerBlock>{}),
|
||||
make_tuple(number<kNPerBlock>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
return v_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
|
||||
{
|
||||
@@ -146,15 +211,14 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<
|
||||
typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
|
||||
true,
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>;
|
||||
using WarpGemm =
|
||||
WarpGemmMfmaDispatcher<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
|
||||
true>;
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::QDataType,
|
||||
@@ -187,7 +251,10 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
|
||||
true>;
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
WGAttrNumAccessEnum::Double>;
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::PDataType,
|
||||
@@ -231,6 +298,67 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
return k_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution()
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::VDataType);
|
||||
|
||||
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
|
||||
static_assert(0 < ElemPerThread);
|
||||
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
|
||||
|
||||
constexpr index_t NPerThread = kMaxVecLoad;
|
||||
constexpr index_t NThreads = kNPerBlock / NPerThread;
|
||||
constexpr index_t KThreadPerWarp = get_warp_size() / NThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t KPerThread = kKPerBlock / (KThreadPerWarp * NumWarps);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<KPerThread, NumWarps, KThreadPerWarp>,
|
||||
sequence<NThreads, NPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakePRegTileDistribution()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetPVBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
|
||||
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
|
||||
constexpr auto p_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto p_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
p_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto p_block_dstr = make_static_tile_distribution(p_block_dstr_encode);
|
||||
|
||||
return p_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeVRegTileDistribution()
|
||||
{
|
||||
@@ -258,7 +386,10 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode);
|
||||
constexpr auto v_block_dstr =
|
||||
make_static_tile_distribution(typename InputTileDistributionTraits<
|
||||
decltype(v_block_dstr_encode),
|
||||
typename Problem::VDataType>::TransposedDstrEncode{});
|
||||
|
||||
return v_block_dstr;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user