mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
temp save
This commit is contained in:
@@ -287,7 +287,7 @@ template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant<rounding> = {})
|
||||
{
|
||||
#if defined (__gfx950__)
|
||||
#if defined(__gfx950__)
|
||||
return static_cast<bfloat16_t>(f);
|
||||
#else
|
||||
return bit_cast<bfloat16_t>(float_to_bf16_raw(f, constant<rounding>{}));
|
||||
|
||||
@@ -21,7 +21,7 @@ namespace ck_tile {
|
||||
using fp32_t = float;
|
||||
using fp32x2_t = float __attribute__((ext_vector_type(2)));
|
||||
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
|
||||
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
|
||||
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr uint8_t float_to_e2m1(float);
|
||||
|
||||
|
||||
@@ -757,7 +757,9 @@ struct FmhaFwdDecodeKernel
|
||||
|
||||
const auto make_k_dram = [&](const KDataType* data, index_t height) {
|
||||
// We don't expect K data reuse among different blocks in decode case.
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global, memory_operation_enum::set, amd_buffer_coherence_enum::SYSTEM_NT1>(
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global,
|
||||
memory_operation_enum::set,
|
||||
amd_buffer_coherence_enum::SYSTEM_NT1>(
|
||||
data, // will update this pointer if using paged-kvcache
|
||||
make_tuple(height, kargs.hdim_q),
|
||||
make_tuple(kargs.stride_k, 1),
|
||||
@@ -784,12 +786,15 @@ struct FmhaFwdDecodeKernel
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
// We don't expect V data reuse among different blocks in decode case.
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global, memory_operation_enum::set, amd_buffer_coherence_enum::SYSTEM_NT1>(
|
||||
data, // will update this pointer if using paged-kvcache
|
||||
make_tuple(length, kargs.hdim_v),
|
||||
make_tuple(kargs.stride_v, 1),
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
const auto v_dram_naive =
|
||||
make_naive_tensor_view<address_space_enum::global,
|
||||
memory_operation_enum::set,
|
||||
amd_buffer_coherence_enum::SYSTEM_NT1>(
|
||||
data, // will update this pointer if using paged-kvcache
|
||||
make_tuple(length, kargs.hdim_v),
|
||||
make_tuple(kargs.stride_v, 1),
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
|
||||
const auto v_dram_transposed =
|
||||
transform_tensor_view(v_dram_naive,
|
||||
@@ -1079,15 +1084,15 @@ struct FmhaFwdDecodeKernel
|
||||
v_page_block_navigator, // Remove it
|
||||
bias_dram_window,
|
||||
lse_acc_dram_window,
|
||||
kargs.num_splits, // Remove it
|
||||
i_split_, // Remove it
|
||||
kargs.num_splits, // Remove it
|
||||
i_split_, // Remove it
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
variant, // Remove it
|
||||
variant_params, // Remove it
|
||||
block_indices, // Remove it
|
||||
kv_l2p_offset, // Remove it
|
||||
variant, // Remove it
|
||||
variant_params, // Remove it
|
||||
block_indices, // Remove it
|
||||
kv_l2p_offset, // Remove it
|
||||
smem_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -1071,20 +1071,20 @@ struct FmhaFwdSplitKVKernel
|
||||
{
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window_lengths,
|
||||
// k_page_block_navigator,
|
||||
// k_page_block_navigator,
|
||||
v_dram_window_lengths,
|
||||
// v_page_block_navigator,
|
||||
// v_page_block_navigator,
|
||||
bias_dram_window,
|
||||
lse_acc_dram_window,
|
||||
// kargs.num_splits,
|
||||
// i_split_,
|
||||
// kargs.num_splits,
|
||||
// i_split_,
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
// kv_l2p_offset,
|
||||
// kv_l2p_offset,
|
||||
smem_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -11,8 +11,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
// This pipeline is qkv all located in LDS
|
||||
template <typename Problem_,
|
||||
typename Policy_ = BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy>
|
||||
template <typename Problem_, typename Policy_ = BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy>
|
||||
struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
@@ -52,11 +51,13 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
// static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true &&
|
||||
// Problem::kPadHeadDimV == true);
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; // support multiple of vector(like 8x)
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; // support multiple of vector(like 8x)
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ =
|
||||
Problem::kPadHeadDimQ; // support multiple of vector(like 8x)
|
||||
static constexpr bool kPadHeadDimV =
|
||||
Problem::kPadHeadDimV; // support multiple of vector(like 8x)
|
||||
|
||||
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
@@ -130,33 +131,17 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
typename VPageBlockNavigator,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEaccDramBlockWindowTmp,
|
||||
typename QElementFunction,
|
||||
typename KElementFunction,
|
||||
typename VElementFunction,
|
||||
typename BiasElementFunction,
|
||||
typename LSEaccElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction,
|
||||
typename PositionEncoding,
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const QElementFunction& q_element_func,
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
|
||||
const KPageBlockNavigator& k_page_block_navigator,
|
||||
const KElementFunction& k_element_func,
|
||||
const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
|
||||
const VPageBlockNavigator& v_page_block_navigator,
|
||||
const VElementFunction& v_element_func,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile
|
||||
const LSEaccElementFunction& lse_acc_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile
|
||||
index_t num_splits,
|
||||
index_t i_split,
|
||||
FmhaMask mask,
|
||||
@@ -184,56 +169,11 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
// Q tile in LDS
|
||||
QDataType* q_lds_ptr =
|
||||
static_cast<QDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
|
||||
// K tile in LDS
|
||||
KDataType* k_lds_ptr =
|
||||
static_cast<KDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
auto k_lds_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
|
||||
|
||||
// V tile in LDS
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<VDataType*>(static_cast<char*>(smem_ptr) +
|
||||
max(Policy::template GetSmemSizeQ<Problem>(),
|
||||
Policy::template GetSmemSizeK<Problem>())),
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem>());
|
||||
auto v_lds_window = make_tile_window(
|
||||
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
// S tile in LDS
|
||||
auto s_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<SaccDataType*>(reinterpret_cast<char*>(smem_ptr) +
|
||||
max(Policy::template GetSmemSizeQ<Problem>(),
|
||||
Policy::template GetSmemSizeK<Problem>())),
|
||||
Policy::template MakeSLdsBlockDescriptor<Problem>());
|
||||
auto s_write_lds_window = make_tile_window(
|
||||
s_lds, Policy::template MakeSLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
auto s_read_lds_window =
|
||||
make_tile_window(s_lds,
|
||||
Policy::template MakeSLdsBlockDescriptor<Problem>().get_lengths(),
|
||||
{0, 0},
|
||||
Policy::template MakeSRegTileDistribution<Problem>());
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
|
||||
auto q_dram_window =
|
||||
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
q_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeQDramTileDistribution<Problem>());
|
||||
|
||||
// load Q here, will store Q into LDS to maximize throughput
|
||||
auto origin_q = load_tile(q_dram_window);
|
||||
|
||||
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
auto s_acc = SaccBlockTileType{};
|
||||
|
||||
@@ -259,7 +199,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto q_origin = q_dram_block_window_tmp.get_window_origin();
|
||||
const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
||||
|
||||
@@ -279,8 +219,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
|
||||
if(get_thread_local_1d_id() < kM0)
|
||||
{
|
||||
store_tile(lse_acc_dram_window_tmp,
|
||||
tile_elementwise_in(lse_acc_element_func, lse_acc));
|
||||
store_tile(lse_acc_dram_window_tmp, lse_acc);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -290,6 +229,25 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
}
|
||||
}
|
||||
|
||||
// Q tile in LDS
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram_block_window_tmp, Policy::template MakeQDramTileDistribution<Problem>());
|
||||
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
static_cast<QDataType*>(smem_ptr), Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
|
||||
auto q_lds_store_window = make_tile_window(
|
||||
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
auto q_lds_read_window =
|
||||
make_tile_window(q_lds,
|
||||
Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
|
||||
{0, 0},
|
||||
Policy::template MakeQRegTileDistribution<Problem>());
|
||||
|
||||
async_load_tile(q_lds_store_window, q_dram_window);
|
||||
|
||||
// K tile in LDS
|
||||
const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
|
||||
const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset;
|
||||
// make sure the first tile is completely located in page-block (page-block size should be
|
||||
@@ -307,12 +265,59 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
return physical_seqlen_k_start_;
|
||||
}
|
||||
}();
|
||||
const index_t num_total_loop =
|
||||
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
|
||||
|
||||
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
|
||||
k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram_block_window, Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
static_cast<KDataType*>(smem_ptr), Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
auto k_lds_write_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
|
||||
auto k_lds_read_window =
|
||||
make_tile_window(k_lds,
|
||||
make_tuple(number<kN0>{}, number<kK0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeKRegTileDistribution<Problem>());
|
||||
|
||||
// S tile in LDS
|
||||
auto s_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<SaccDataType*>(reinterpret_cast<char*>(smem_ptr) +
|
||||
max(Policy::template GetSmemSizeQ<Problem>(),
|
||||
Policy::template GetSmemSizeK<Problem>())),
|
||||
Policy::template MakeSLdsBlockDescriptor<Problem>());
|
||||
auto s_write_lds_window = make_tile_window(
|
||||
s_lds, Policy::template MakeSLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
auto s_read_lds_window =
|
||||
make_tile_window(s_lds,
|
||||
Policy::template MakeSLdsBlockDescriptor<Problem>().get_lengths(),
|
||||
{0, 0},
|
||||
Policy::template MakeSRegTileDistribution<Problem>());
|
||||
|
||||
// V tile in LDS
|
||||
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
|
||||
v_dram_block_window_lengths,
|
||||
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<VDataType*>(static_cast<char*>(smem_ptr) +
|
||||
max(Policy::template GetSmemSizeQ<Problem>(),
|
||||
Policy::template GetSmemSizeK<Problem>()) +
|
||||
Policy::template GetSmemSizeS<Problem>()),
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem>());
|
||||
auto v_lds_write_window = make_tile_window(
|
||||
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
auto v_lds_read_window =
|
||||
make_tile_window(v_lds,
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(),
|
||||
{0, 0},
|
||||
Policy::template MakeVRegTileDistribution<Problem>());
|
||||
|
||||
// Bias tile in LDS
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
@@ -322,51 +327,26 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
aligned_physical_seqlen_k_start)}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
|
||||
v_dram_block_window_lengths,
|
||||
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
block_sync_lds_direct_load<0>();
|
||||
auto q_tile = load_tile(q_lds_read_window);
|
||||
|
||||
// store Q into LDS
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
auto q_lds_window_for_store = make_tile_window(
|
||||
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
const index_t num_total_loop =
|
||||
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
|
||||
|
||||
store_tile(q_lds_window_for_store, origin_q);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// load Q from LDS
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
auto q_lds_window_for_load =
|
||||
make_tile_window(q_lds,
|
||||
Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
|
||||
{0, 0},
|
||||
Policy::template MakeQRegTileDistribution<Problem>());
|
||||
block_sync_lds();
|
||||
auto q = load_tile(q_lds_window_for_load);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
auto q_tile = tile_elementwise_in(q_element_func, q);
|
||||
|
||||
// prefetch K tile
|
||||
index_t i_total_loops = 0;
|
||||
constexpr index_t k0_loops = kQKHeaddim / kK0;
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
|
||||
static_assert(1 <= k0_loops);
|
||||
static_assert(1 <= k1_loops);
|
||||
static_assert(1 == k0_loops);
|
||||
static_assert(1 == k1_loops);
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram_block_window,
|
||||
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
|
||||
async_load_tile(k_lds_write_window, k_dram_window);
|
||||
// move K tile windows
|
||||
i_page_block_k =
|
||||
k_page_block_navigator.move_tile_window(i_page_block_k, k_dram_block_window, {kN0, 0});
|
||||
|
||||
// load the first tile of the first iteration and store to LDS
|
||||
auto k_block_tile = load_tile(k_dram_window);
|
||||
// moving k_dram_window is an in-page-block operation, so there is
|
||||
// no need to invoke k_page_block_navigator.move_tile_window() here.
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
// ensure LDS access by Q is done before the over-writting by K
|
||||
block_sync_lds();
|
||||
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
|
||||
k_dram_window = make_tile_window(k_dram_block_window,
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
do
|
||||
{
|
||||
@@ -385,40 +365,24 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
|
||||
if constexpr(k0_loops > 1)
|
||||
{
|
||||
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
|
||||
k_block_tile = load_tile(k_dram_window); // global read i + 1
|
||||
block_sync_lds();
|
||||
gemm_0(s_acc,
|
||||
get_slice_tile(q_tile,
|
||||
sequence<0, i_k0 * kK0>{},
|
||||
sequence<kM0, (i_k0 + 1) * kK0>{}),
|
||||
k_lds_window);
|
||||
block_sync_lds();
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
async_load_tile(v_lds_write_window, v_dram_window); // prefetch load v tile
|
||||
// move V tile windows
|
||||
i_page_block_v =
|
||||
v_page_block_navigator.move_tile_window(i_page_block_v, v_dram_window, {0, kK1});
|
||||
|
||||
store_tile(
|
||||
k_lds_window,
|
||||
tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1
|
||||
});
|
||||
}
|
||||
block_sync_lds_direct_load<v_dram_window.get_num_of_access()>();
|
||||
auto k_tile = load_tile(k_lds_read_window);
|
||||
|
||||
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
|
||||
gemm_0(s_acc,
|
||||
get_slice_tile(q_tile,
|
||||
sequence<0, (k0_loops - 1) * kK0>{},
|
||||
sequence<kM0, k0_loops * kK0>{}),
|
||||
k_lds_window);
|
||||
}
|
||||
gemm_0(
|
||||
s_acc,
|
||||
get_slice_tile(
|
||||
q_tile, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kM0, k0_loops * kK0>{}),
|
||||
get_slice_tile(
|
||||
k_tile, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kN0, k0_loops * kK0>{}));
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, softmax
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
@@ -437,7 +401,6 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
const auto k_origin = k_page_block_navigator.to_global_window_origin(
|
||||
i_page_block_k, k_dram_block_window.get_window_origin());
|
||||
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
@@ -455,7 +418,6 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
}
|
||||
else
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
auto apply_logits_transform =
|
||||
@@ -530,39 +492,30 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
}
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
async_load_tile(k_lds_write_window, k_dram_window);
|
||||
i_page_block_k = k_page_block_navigator.move_tile_window(
|
||||
i_page_block_k, k_dram_block_window, {kN0, 0});
|
||||
|
||||
// load the first tile for next iteration
|
||||
if(i_total_loops < num_total_loop - 1)
|
||||
{
|
||||
// move K tile windows
|
||||
i_page_block_k = k_page_block_navigator.move_tile_window(
|
||||
i_page_block_k, k_dram_block_window, {kN0, 0});
|
||||
k_dram_window = make_tile_window(k_dram_block_window,
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
k_dram_window = make_tile_window(
|
||||
k_dram_block_window,
|
||||
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window
|
||||
|
||||
// laod the first tile of the first iteration and store to LDS
|
||||
k_block_tile = load_tile(k_dram_window);
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// In Nwarp=1 and NXdl=32, GEMM0 output naturally fit the input of GEMM1
|
||||
// Otherwise shuffle through LDS so that the tile layout is consistent with required by Gemm1
|
||||
auto s_new = [&](){
|
||||
if constexpr ( !((kNWarp==1) && (kNXdl == 32)) ){
|
||||
// Otherwise shuffle through LDS so that the tile layout is consistent with required by
|
||||
// Gemm1
|
||||
auto s_new = [&]() {
|
||||
if constexpr(!((kNWarp == 1) && (kNXdl == 32)))
|
||||
{
|
||||
auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
|
||||
|
||||
store_tile(s_write_lds_window, s);
|
||||
block_sync_lds();
|
||||
return load_tile(s_read_lds_window);
|
||||
}
|
||||
else{
|
||||
else
|
||||
{
|
||||
return cast_tile<SMPLComputeDataType>(s_acc); // S{j}
|
||||
}
|
||||
}();
|
||||
}();
|
||||
|
||||
auto m_local = block_tile_reduce<SMPLComputeDataType>(
|
||||
s_new,
|
||||
@@ -630,8 +583,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
|
||||
|
||||
const auto p =
|
||||
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
const auto p = cast_tile<PDataType>(p_compute);
|
||||
|
||||
// l{j}, Oacc{j}
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
@@ -670,79 +622,16 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v_prefetch);
|
||||
store_tile(
|
||||
v_lds_window,
|
||||
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(v_lds_window,
|
||||
tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
|
||||
}
|
||||
i_page_block_v =
|
||||
v_page_block_navigator.move_tile_window(i_page_block_v, v_dram_window, {0, kK1});
|
||||
block_sync_lds_direct_load<k_dram_window.get_num_of_access()>();
|
||||
auto v_tile = load_tile_transpose(v_lds_read_window);
|
||||
|
||||
// STAGE 3, KV gemm
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&,
|
||||
&i_page_block_v_ = i_page_block_v,
|
||||
&v_dram_window_ = v_dram_window](auto i_k1) {
|
||||
const auto v = load_tile(v_dram_window_); // load next v
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(
|
||||
p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, k1_loops * kK1>{}),
|
||||
get_slice_tile(v_tile,
|
||||
sequence<0, (k1_loops - 1) * kK1>{},
|
||||
sequence<kN1, k1_loops * kK1>{}));
|
||||
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(
|
||||
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_lds_window);
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v);
|
||||
store_tile(v_lds_window,
|
||||
tile_elementwise_in(v_element_func,
|
||||
v_shuffle_tmp)); // store the prefetch
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(v_lds_window,
|
||||
tile_elementwise_in(v_element_func, v)); // store next v
|
||||
}
|
||||
i_page_block_v_ = v_page_block_navigator.move_tile_window(
|
||||
i_page_block_v_, v_dram_window_, {0, kK1});
|
||||
});
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(
|
||||
p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, k1_loops * kK1>{}),
|
||||
v_lds_window);
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// load the first tile for next iteration
|
||||
if(i_total_loops < num_total_loop - 1)
|
||||
{
|
||||
// store the first tile for next iteration to LDS
|
||||
// moving k_dram_window is an in-page-block operation, so there is
|
||||
// no need to invoke k_page_block_navigator.move_tile_window() here.
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
|
||||
}
|
||||
} while(++i_total_loops < num_total_loop);
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
@@ -777,8 +666,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
|
||||
if(get_thread_local_1d_id() < kM0)
|
||||
{
|
||||
store_tile(lse_acc_dram_window_tmp,
|
||||
tile_elementwise_in(lse_acc_element_func, lse_acc));
|
||||
store_tile(lse_acc_dram_window_tmp, lse_acc);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -802,66 +690,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
});
|
||||
});
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowLengths,
|
||||
typename KPageBlockNavigator,
|
||||
typename VDramBlockWindowLengths,
|
||||
typename VPageBlockNavigator,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEaccDramBlockWindowTmp,
|
||||
typename PositionEncoding,
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
|
||||
const KPageBlockNavigator& k_page_block_navigator,
|
||||
const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
|
||||
const VPageBlockNavigator& v_page_block_navigator,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile
|
||||
index_t num_splits,
|
||||
index_t i_split,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
const AttentionVariant& variant,
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
|
||||
void* smem_ptr) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
k_dram_block_window_lengths,
|
||||
k_page_block_navigator,
|
||||
identity{},
|
||||
v_dram_block_window_lengths,
|
||||
v_page_block_navigator,
|
||||
identity{},
|
||||
bias_dram_block_window_tmp,
|
||||
identity{},
|
||||
lse_acc_dram_block_window_tmp,
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
num_splits,
|
||||
i_split,
|
||||
mask,
|
||||
position_encoding,
|
||||
scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
kv_l2p_offset,
|
||||
smem_ptr);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -7,6 +7,11 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -116,6 +121,137 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
return q_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
|
||||
{
|
||||
using GemmProblem =
|
||||
BlockGemmProblem<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
Problem::kBlockSize,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK0>,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<
|
||||
typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
|
||||
true,
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>;
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
WarpGemm>;
|
||||
|
||||
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetPVBlockGemm()
|
||||
{
|
||||
using GemmProblem =
|
||||
BlockGemmProblem<typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
Problem::kBlockSize,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN1,
|
||||
Problem::BlockFmhaShape::kK1>,
|
||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
|
||||
|
||||
using WarpGemm =
|
||||
WarpGemmMfmaDispatcher<typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
|
||||
true>;
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||
WarpGemm>;
|
||||
|
||||
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKRegTileDistribution()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
|
||||
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
|
||||
constexpr auto k_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode);
|
||||
|
||||
return k_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeVRegTileDistribution()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetPVBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
|
||||
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
|
||||
constexpr auto v_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode);
|
||||
|
||||
return v_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemNPackS()
|
||||
{
|
||||
|
||||
@@ -183,7 +183,7 @@ using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
|
||||
#if defined(__gfx950__)
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
|
||||
#else
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
|
||||
@@ -36,6 +36,8 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
|
||||
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
|
||||
// template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution; };
|
||||
|
||||
// fp16 2:4 structural sparsity
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
@@ -57,6 +59,8 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float
|
||||
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
|
||||
// template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution; };
|
||||
|
||||
// fp8
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
|
||||
Reference in New Issue
Block a user