mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
[rocm-libraries] ROCm/rocm-libraries#6209 (commit 89c9f3e)
Improve the performance of qr_ks_vs_whole_k_prefetch pipeline (#6209) ## About qr_ks_vs_whole_k_prefetch pipeline This PR updates and enhances the qr_ks_vs_whole_k_prefetch pipeline to improve performance on both MI350 GPUs through better MFMA instruction usage, transposed V-loading support, and N0-loop implementation. The pipeline targets scenarios where the number of workgroups is low, enabling better CU occupancy by using smaller MTile sizes (kM0=64 vs 128) while prefetching entire K tiles. ## Changes: - Adds transposed V-loading support (qr_ks_vs_whole_k_prefetch_trload) to avoid using shuffle instructions on MI350 - Implements N0-loop based Gemm0 to reduce tile window movement overhead and eliminate `clear_tile` calls - Adds full support for hdim96/hdim160 without padding requirements - Updates MFMA instruction selection to ensure optimal choices for MI350 ## Performance results 1. For attention shapes which leads to kM0=64, `qr_ks_vs_async_whole_k_prefetch_trload` shows much better performance than `qr_ks_vs_async_trload` on the same case (execution time `41.02ms` by whole_k_prefetch_trload & `58.50ms` by async_load), and `qr_ks_vs_async_whole_k_prefetch_trload` also shows obviously better performance than the recently tuned `qr_ks_vs_async` on the same case (execution time `41.02ms` by whole_k_prefetch_trload 7 `47.60ms` by qr_ks_vs_async) 2. Also on MI300, for attention shapes which leads to kM0=64, `qr_ks_vs_async_whole_k_prefetch` shows much better performance than the `qr_ks_vs_async` (which is supposed to be very high-efficient) on the same case (execution time `64.50ms` by whole_k_prefetch & `80.20ms` by qr_ks_vs_async) 3. For attention shapes which leads to kM0=128, `qr_ks_vs_async_whole_k_prefetch_trload` show a little bit better performance than `qr_ks_vs_async` on mi350 (execution time `104.50ms` by whole_k_prefetch_trload & `106.50ms` by qr_ks_vs_async). And they shows completely on-par performance on MI300 ## Test/Verify 1. Use the ROCM xformers branch `test_whole_k_prefetch_n0loop` to test/verify qr_ks_vs_whole_k_prefetch pipeline since this pipeline can not be used by ck_tile fmha example so far 2. Use the following command-line for building/testing xformers >```bash > #> git clone -b test_whole_k_prefetch_n0loop https://github.com/ROCm/xformers > #> git submodule update --init --recursive > #> pip install --no-build-isolation -e ./ > #> pytest tests/test_mem_eff_attention.py::test_forward >``` 4. Any scripts which can run on xformers can be used to evaluate qr_ks_vs_whole_k_prefetch pipeline. Using the two environ variable to switch from using different pipelines > ```bash > #> export FMHA_DISABLE_SPECIAL_TREATMENT=1 #> to disable using FAV3 and qr_ks_vs_async_trload pipeline > #> export FMHA_ENABLE_ASYNC_PIPELINE=1 #> to disable using qr_ks_vs_async pipeline for comparing > ``` ## Discussion
This commit is contained in:
committed by
assistant-librarian[bot]
parent
b2ea5fd315
commit
865ab2b8ed
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <iostream>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
|
||||
@@ -56,6 +56,7 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
|
||||
|
||||
@@ -32,6 +32,83 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace detail {
|
||||
|
||||
// A helper struct for detecting n0loop
|
||||
template <typename T, typename = void>
|
||||
struct has_n0loop_flag : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct has_n0loop_flag<
|
||||
T,
|
||||
std::enable_if_t<std::is_convertible_v<decltype(T::kUseN0Loop), bool> && T::kUseN0Loop>>
|
||||
: std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
static inline constexpr bool is_n0loop_pipeline_v = has_n0loop_flag<T>::value;
|
||||
|
||||
// A helper struct for detecting ignore_fast_exp2 flag
|
||||
template <typename T, typename = void>
|
||||
struct has_ignore_fast_exp2_flag : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
// IgnoreFastExp2 is used by some pipeline which explicitly chooses not to use FAST_EXP2;
|
||||
// By detecting the kIgnoreFastExp2 from the pipeline, the kernel's MakeKargsImpl() interface
|
||||
// is able to avoid passing an in-correct scale_s parameter to the kernel layer
|
||||
template <typename T>
|
||||
struct has_ignore_fast_exp2_flag<
|
||||
T,
|
||||
std::enable_if_t<std::is_convertible_v<decltype(T::kIgnoreFastExp2), bool> &&
|
||||
T::kIgnoreFastExp2>> : std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
static inline constexpr bool ignore_fast_exp2_v = has_ignore_fast_exp2_flag<T>::value;
|
||||
|
||||
// A helper struct for detecting naive_hdim_load, naive_hdim_load means load tiles of
|
||||
// hdim96/hdim160/hdim192 without padding the tensor_view/tile_window to hdim128/hdim256
|
||||
// naive_hdim_load is current supported by the qr_ks_vs_whole_k_prefetch_pipeline
|
||||
template <typename T, typename = void>
|
||||
struct has_naive_hdim_load_flag : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct has_naive_hdim_load_flag<
|
||||
T,
|
||||
std::enable_if_t<std::is_convertible_v<decltype(T::kIsNaiveHDimLoad), bool> &&
|
||||
T::kIsNaiveHDimLoad>> : std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
static inline constexpr bool is_naive_hdim_load_v = has_naive_hdim_load_flag<T>::value;
|
||||
|
||||
// A helper struct for detecting kUseTrLoad
|
||||
template <typename T, typename = void>
|
||||
struct has_use_trload_flag : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct has_use_trload_flag<
|
||||
T,
|
||||
std::enable_if_t<std::is_convertible_v<decltype(T::kUseTrLoad), bool> && T::kUseTrLoad>>
|
||||
: std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
static inline constexpr bool is_using_trload_v = has_use_trload_flag<T>::value;
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename FmhaPipeline_, typename EpiloguePipeline_>
|
||||
struct FmhaFwdKernel
|
||||
{
|
||||
@@ -74,13 +151,14 @@ struct FmhaFwdKernel
|
||||
static constexpr bool kHasMask = FmhaMask::IsMasking;
|
||||
|
||||
static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
|
||||
static constexpr bool kUseTrLoad = detail::is_using_trload_v<FmhaPipeline>;
|
||||
|
||||
static constexpr bool kUseTrLoad = FmhaPipeline::Problem::kUseTrLoad;
|
||||
#if defined(__gfx950__)
|
||||
static constexpr bool kIsAvailable = true;
|
||||
#else
|
||||
static constexpr bool kIsAvailable = !kUseTrLoad;
|
||||
#endif
|
||||
|
||||
static constexpr std::string_view kPipelineName = FmhaPipeline::name;
|
||||
|
||||
template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
|
||||
@@ -441,7 +519,9 @@ struct FmhaFwdKernel
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
static_cast<float>(scale_s * ck_tile::log2e_v<>),
|
||||
detail::ignore_fast_exp2_v<FmhaPipeline>
|
||||
? scale_s
|
||||
: static_cast<float>(scale_s * ck_tile::log2e_v<>),
|
||||
#else
|
||||
scale_s,
|
||||
#endif
|
||||
@@ -894,7 +974,9 @@ struct FmhaFwdKernel
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
static_cast<float>(scale_s * ck_tile::log2e_v<>),
|
||||
detail::ignore_fast_exp2_v<FmhaPipeline>
|
||||
? scale_s
|
||||
: static_cast<float>(scale_s * ck_tile::log2e_v<>),
|
||||
#else
|
||||
scale_s,
|
||||
#endif
|
||||
@@ -1036,6 +1118,7 @@ struct FmhaFwdKernel
|
||||
const void* seqlen_k_ptr,
|
||||
const void* block_scale_seqstart_q_ptr,
|
||||
const void* block_scale_seqstart_k_ptr,
|
||||
const void* seqstart_v_scale_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
@@ -1094,6 +1177,7 @@ struct FmhaFwdKernel
|
||||
seqlen_k_ptr,
|
||||
block_scale_seqstart_q_ptr,
|
||||
block_scale_seqstart_k_ptr,
|
||||
seqstart_v_scale_ptr,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
@@ -1155,6 +1239,7 @@ struct FmhaFwdKernel
|
||||
const void* seqlen_k_ptr,
|
||||
const void* block_scale_seqstart_q_ptr,
|
||||
const void* block_scale_seqstart_k_ptr,
|
||||
const void* seqstart_v_scale_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
@@ -1213,6 +1298,7 @@ struct FmhaFwdKernel
|
||||
seqlen_k_ptr,
|
||||
block_scale_seqstart_q_ptr,
|
||||
block_scale_seqstart_k_ptr,
|
||||
seqstart_v_scale_ptr,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
@@ -1599,6 +1685,10 @@ struct FmhaFwdKernel
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
|
||||
batch_offset_o;
|
||||
|
||||
constexpr index_t kQKHeaddimToUse = detail::is_naive_hdim_load_v<FmhaPipeline>
|
||||
? FmhaPipeline::kQKHeaddim
|
||||
: FmhaPipeline::kSubQKHeaddim;
|
||||
|
||||
// Q/K/V DRAM and DRAM window
|
||||
const auto q_dram = [&]() {
|
||||
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
@@ -1609,10 +1699,10 @@ struct FmhaFwdKernel
|
||||
number<1>{});
|
||||
if constexpr(FmhaPipeline::kQLoadOnce)
|
||||
{
|
||||
return pad_tensor_view(q_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{},
|
||||
number<FmhaPipeline::kSubQKHeaddim>{}),
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
|
||||
return pad_tensor_view(
|
||||
q_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<kQKHeaddimToUse>{}),
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1631,10 +1721,21 @@ struct FmhaFwdKernel
|
||||
number<1>{});
|
||||
|
||||
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
|
||||
return pad_tensor_view(
|
||||
k_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
|
||||
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
|
||||
|
||||
if constexpr(detail::is_n0loop_pipeline_v<FmhaPipeline>)
|
||||
{
|
||||
return pad_tensor_view(
|
||||
k_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0Sub>{}, number<kQKHeaddimToUse>{}),
|
||||
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(
|
||||
k_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
|
||||
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
|
||||
}
|
||||
}();
|
||||
const auto v_dram = [&]() {
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
@@ -1646,18 +1747,29 @@ struct FmhaFwdKernel
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
|
||||
const auto v_dram_transposed = transform_tensor_view(
|
||||
v_dram_naive,
|
||||
make_tuple(make_pass_through_transform(kargs.hdim_v),
|
||||
make_pass_through_transform(kargs.seqlen_k)),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
if constexpr(!kUseTrLoad)
|
||||
{
|
||||
const auto v_dram_transposed = transform_tensor_view(
|
||||
v_dram_naive,
|
||||
make_tuple(make_pass_through_transform(kargs.hdim_v),
|
||||
make_pass_through_transform(kargs.seqlen_k)),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
|
||||
return pad_tensor_view(
|
||||
v_dram_transposed,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
|
||||
sequence<kPadHeadDimV, kPadSeqLenK_>{});
|
||||
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
|
||||
|
||||
return pad_tensor_view(
|
||||
v_dram_transposed,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
|
||||
sequence<kPadHeadDimV, kPadSeqLenK_>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(
|
||||
v_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kK1>{}, number<FmhaPipeline::kN1>{}),
|
||||
sequence<false, kPadHeadDimV>{});
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1680,17 +1792,28 @@ struct FmhaFwdKernel
|
||||
q_dram,
|
||||
[&]() {
|
||||
if constexpr(FmhaPipeline::kQLoadOnce)
|
||||
return make_tuple(number<FmhaPipeline::kM0>{},
|
||||
number<FmhaPipeline::kSubQKHeaddim>{});
|
||||
return make_tuple(number<FmhaPipeline::kM0>{}, number<kQKHeaddimToUse>{});
|
||||
else
|
||||
return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
|
||||
}(),
|
||||
{i_m0, 0});
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
|
||||
{0, 0});
|
||||
auto k_dram_window = [&]() {
|
||||
if constexpr(detail::is_n0loop_pipeline_v<FmhaPipeline>)
|
||||
{
|
||||
return make_tile_window(
|
||||
k_dram,
|
||||
make_tuple(number<FmhaPipeline::kN0Sub>{}, number<kQKHeaddimToUse>{}),
|
||||
{0, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(
|
||||
k_dram,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
|
||||
{0, 0});
|
||||
}
|
||||
}();
|
||||
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram,
|
||||
@@ -1840,7 +1963,10 @@ struct FmhaFwdKernel
|
||||
*(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
|
||||
i_batch_ * kargs.alibi_slope_stride + i_nhead_);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
slope *= ck_tile::log2e_v<>;
|
||||
if constexpr(!detail::ignore_fast_exp2_v<FmhaPipeline>)
|
||||
{
|
||||
slope *= ck_tile::log2e_v<>;
|
||||
}
|
||||
#endif
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
@@ -2798,7 +2924,10 @@ struct FmhaFwdKernel
|
||||
*(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
|
||||
i_batch_ * kargs.alibi_slope_stride + i_nhead_);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
slope *= ck_tile::log2e_v<>;
|
||||
if constexpr(!detail::ignore_fast_exp2_v<FmhaPipeline>)
|
||||
{
|
||||
slope *= ck_tile::log2e_v<>;
|
||||
}
|
||||
#endif
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
|
||||
@@ -9,6 +9,52 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename DataType, index_t ElemPerThread>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetMaxVectorSize()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, half_t> || std::is_same_v<DataType, bf16_t>)
|
||||
{
|
||||
// ToDo: need support in ck_tile for using buffer_load_dwordx3
|
||||
// if constexpr(ElemPerThread % 6 == 0)
|
||||
// return 6;
|
||||
if constexpr(ElemPerThread % 8 == 0)
|
||||
return 8;
|
||||
else if constexpr(ElemPerThread % 4 == 0)
|
||||
return 4;
|
||||
else if constexpr(ElemPerThread % 2 == 0)
|
||||
return 2;
|
||||
return 1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
// ToDo: need support in ck_tile for using buffer_load_dwordx3
|
||||
// if constexpr(ElemPerThread % 3 == 0)
|
||||
// return 3;
|
||||
if constexpr(ElemPerThread % 4 == 0)
|
||||
return 4;
|
||||
else if constexpr(ElemPerThread % 2 == 0)
|
||||
return 2;
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
return 1;
|
||||
};
|
||||
|
||||
template <typename DataType,
|
||||
index_t kThreadBlockSize,
|
||||
index_t kHigherDimSize,
|
||||
index_t kLowerDimSize>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetDramTileAccessMaxVectorSize()
|
||||
{
|
||||
constexpr index_t ElemPerThread = (kHigherDimSize * kLowerDimSize) / kThreadBlockSize;
|
||||
|
||||
return GetMaxVectorSize<DataType, ElemPerThread>();
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename QDataType_,
|
||||
typename KDataType_,
|
||||
typename VDataType_,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,861 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
|
||||
struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
|
||||
using CompDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::PDataType>;
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
|
||||
static constexpr bool kQLoadOnce = true;
|
||||
static_assert(kQLoadOnce == Policy::QLoadOnce);
|
||||
static_assert(sizeof(KDataType) == sizeof(VDataType) &&
|
||||
alignof(KDataType) == alignof(VDataType),
|
||||
"K and V share the same LDS region; their element types must have identical "
|
||||
"size and alignment.");
|
||||
|
||||
static constexpr bool kUseN0Loop = true;
|
||||
static constexpr bool kIgnoreFastExp2 = true;
|
||||
static constexpr bool kIsNaiveHDimLoad = true;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kN0Sub =
|
||||
BlockFmhaShape::kK0; // subdivision of kN0 used in N0-loop, same value as kK0
|
||||
static constexpr index_t kN1 = BlockFmhaShape::kN1;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
|
||||
|
||||
static_assert(Problem::kUseTrLoad == true, "Check failed!");
|
||||
|
||||
static constexpr bool kUseTrLoad = true;
|
||||
|
||||
// since this pipeline is only used by the inference path of xformers, the Dropout function is
|
||||
// not well tested with the pipeline, so here we have Dropout disabled
|
||||
static_assert(kHasDropout == false, "Dropout is not supported by this pipeline at present!");
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
Problem::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
if constexpr(Problem::kBlockPerCu != -1)
|
||||
return Problem::kBlockPerCu;
|
||||
else
|
||||
{
|
||||
if constexpr(kQKHeaddim == 32)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim == 64)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim == 96 || kQKHeaddim == 128)
|
||||
{
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
return 1;
|
||||
else
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim == 256)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
};
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr const char* name = "qr_async_whole_k_prefetch_trload";
|
||||
|
||||
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename QElementFunction,
|
||||
typename KElementFunction,
|
||||
typename VElementFunction,
|
||||
typename BiasElementFunction,
|
||||
typename LSEElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction,
|
||||
typename PositionEncoding,
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
|
||||
const KElementFunction& k_element_func,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const VElementFunction& v_element_func,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
|
||||
const LSEElementFunction& lse_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
const AttentionVariant& /* unused */,
|
||||
const AttentionVariantParams& /* unused */,
|
||||
const BlockIndices& /* unused */,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
// xformers path does not require the pipeline to output random values for host
|
||||
// verification, since a separate kernel is used to generate random values
|
||||
ignore = randval_dram_block_window_tmp;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0Sub == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
|
||||
constexpr index_t n0_loops = kN0 / kN0Sub;
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
|
||||
// usually kN0 is 128, kN0Sub/kK1 is 32/16
|
||||
static_assert(n0_loops >= 2, "n0_loops >= 2 required to use this pipeline");
|
||||
static_assert(k1_loops >= 2, "k1_loops >= 2 required to use this pipeline");
|
||||
|
||||
constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers<Problem>();
|
||||
|
||||
constexpr index_t NumPrefetchV = Policy::template GetNumPrefetchV<Problem>();
|
||||
static_assert(n0_loops >= NumPrefetchV, "Check failed!");
|
||||
static_assert(k1_loops >= NumPrefetchV, "Check failed!");
|
||||
|
||||
constexpr bool kPreloadWholeNextIterationK =
|
||||
Policy::template IsPreloadWholeNextIterationK<Problem>();
|
||||
|
||||
// This path prefetches two k_tiles for next iteration, so it has the opportunity to
|
||||
// prefetch two v_tiles during Gemm0
|
||||
if constexpr(!kPreloadWholeNextIterationK)
|
||||
{
|
||||
static_assert(NumPrefetchV >= 2);
|
||||
};
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
|
||||
// SaccBlockTile size is [kM0, kK1]
|
||||
// PcompBlockTile size is [kM0, kN0]
|
||||
using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0Sub>());
|
||||
using CombineSaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0>());
|
||||
using PcompBlockTileType = decltype(cast_tile<CompDataType>(CombineSaccBlockTileType{}));
|
||||
|
||||
SaccBlockTileType sacc_tile;
|
||||
PcompBlockTileType pcomp_tile;
|
||||
|
||||
// reduction function for softmax
|
||||
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
|
||||
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
|
||||
|
||||
using MLBlockTileType = decltype(block_tile_reduce<CompDataType>(
|
||||
PcompBlockTileType{}, sequence<1>{}, f_max, CompDataType{0}));
|
||||
|
||||
auto m = MLBlockTileType{};
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
|
||||
OaccBlockTileType o_acc;
|
||||
|
||||
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
|
||||
q_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeQRegTileDistribution<Problem>());
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto [seqlen_k_start, seqlen_k_end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
if(seqlen_k_end <= seqlen_k_start)
|
||||
{
|
||||
clear_tile(o_acc);
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
return o_acc;
|
||||
};
|
||||
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kN0Sub>{}, number<kQKHeaddim>{}),
|
||||
{seqlen_k_start, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
auto q_tile = load_tile(q_dram_window);
|
||||
|
||||
using k_tile_type = decltype(load_tile(k_dram_window));
|
||||
|
||||
auto k_tiles = [&]() {
|
||||
if constexpr(kPreloadWholeNextIterationK)
|
||||
return statically_indexed_array<k_tile_type, n0_loops>{};
|
||||
else
|
||||
return statically_indexed_array<k_tile_type, 2>{};
|
||||
}();
|
||||
|
||||
k_tiles[I0] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
|
||||
if constexpr(!kPreloadWholeNextIterationK)
|
||||
{
|
||||
k_tiles[I1] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
};
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// provide partition_index for LDS tile window with so that warp_id is in vgpr
|
||||
array<index_t, 2> partition_index{get_warp_id<false>(), get_lane_id()};
|
||||
|
||||
// K tile in LDS
|
||||
KDataType* k_lds_ptr = static_cast<KDataType*>(smem_ptr);
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
auto k_lds_window = make_tile_window(
|
||||
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
using k_lds_window_type = decltype(get_slice_tile(
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kQKHeaddim>{}));
|
||||
|
||||
statically_indexed_array<k_lds_window_type, NumKVLdsBuffers> k_lds_windows;
|
||||
|
||||
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
k_lds_windows[i_buf] = get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kN0Sub, 0>{},
|
||||
sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{});
|
||||
});
|
||||
|
||||
// V tile in LDS
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<VDataType*>(smem_ptr),
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem>());
|
||||
auto v_lds_window = make_tile_window(
|
||||
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
using v_lds_window_type =
|
||||
decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kK1, kN1>{}));
|
||||
|
||||
statically_indexed_array<v_lds_window_type, NumKVLdsBuffers> v_lds_windows;
|
||||
|
||||
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
v_lds_windows[i_buf] = get_slice_tile(
|
||||
v_lds_window, sequence<i_buf * kK1, 0>{}, sequence<(i_buf + 1) * kK1, kN1>{});
|
||||
});
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kK1>{}, number<kN1>{}),
|
||||
{seqlen_k_start, 0},
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
const auto f_exp = [&](CompDataType x) {
|
||||
if constexpr(std::is_same_v<CompDataType, float>)
|
||||
{
|
||||
return __expf(x);
|
||||
}
|
||||
else
|
||||
{
|
||||
return exp(x);
|
||||
}
|
||||
};
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kM0>{}, number<kN0>{}),
|
||||
{bias_origin.at(number<0>{}), seqlen_k_start},
|
||||
Policy::template MakeBiasDramTileDistribution<Problem>());
|
||||
|
||||
// assuming no random values need be saved, this is true when the pipeline is called from
|
||||
// xformers, since we have a separate kernel to generated random values
|
||||
auto null_randval_window = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
// need to pass a null_randval_dram and tile window to the BlockDropout operator to
|
||||
// make it works
|
||||
const auto null_randval_dram = [&]() {
|
||||
const auto null_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<uint8_t*>(nullptr),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(null_dram_naive,
|
||||
make_tuple(number<1>{}, number<1>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(
|
||||
null_randval_dram, make_tuple(number<1>{}, number<1>{}), {0, 0});
|
||||
}
|
||||
else
|
||||
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
|
||||
}();
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<CompDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
q_tile = tile_elementwise_in(q_element_func, q_tile);
|
||||
|
||||
auto seqlen_k_curr = seqlen_k_start;
|
||||
|
||||
using v_tile_type = decltype(load_tile(v_dram_window));
|
||||
|
||||
statically_indexed_array<v_tile_type, k1_loops> v_tiles;
|
||||
|
||||
do
|
||||
{
|
||||
// STAGE 1, Gemm_0 ( S = Q@K )
|
||||
if constexpr(kPreloadWholeNextIterationK) // used when kM0 = 64
|
||||
{
|
||||
if(seqlen_k_curr == seqlen_k_start) // at first iteration
|
||||
{
|
||||
if(seqlen_k_curr < seqlen_k_end - kN0) // not the last iteration
|
||||
{
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
|
||||
partition_index);
|
||||
|
||||
if constexpr(i_n0 < n0_loops - 1)
|
||||
{
|
||||
k_tiles[number<i_n0 + 1>{}] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
};
|
||||
|
||||
if constexpr(i_n0 == n0_loops - 1)
|
||||
{
|
||||
v_tiles[I0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
|
||||
// prefetch all k_tiles for next iteration
|
||||
static_for<0, n0_loops, 1>{}([&](auto ii_n0) {
|
||||
k_tiles[number<ii_n0>{}] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
});
|
||||
};
|
||||
|
||||
block_sync_lds();
|
||||
gemm_0(
|
||||
sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
set_slice_tile(pcomp_tile,
|
||||
tmp_tile,
|
||||
sequence<0, i_n0 * kN0Sub>{},
|
||||
sequence<kM0, (i_n0 + 1) * kN0Sub>{});
|
||||
});
|
||||
}
|
||||
else // the iteration is also the last iteration
|
||||
{
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
|
||||
partition_index);
|
||||
|
||||
if constexpr(i_n0 < n0_loops - 1)
|
||||
{
|
||||
k_tiles[number<i_n0 + 1>{}] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
};
|
||||
|
||||
if constexpr(i_n0 == n0_loops - 1)
|
||||
{
|
||||
v_tiles[I0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
};
|
||||
|
||||
block_sync_lds();
|
||||
gemm_0(
|
||||
sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
set_slice_tile(pcomp_tile,
|
||||
tmp_tile,
|
||||
sequence<0, i_n0 * kN0Sub>{},
|
||||
sequence<kM0, (i_n0 + 1) * kN0Sub>{});
|
||||
});
|
||||
};
|
||||
}
|
||||
else // at intermediate and last iteration
|
||||
{
|
||||
if(seqlen_k_curr < seqlen_k_end - kN0) // intermediate iteration
|
||||
{
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
|
||||
partition_index);
|
||||
|
||||
if constexpr(i_n0 == 0)
|
||||
{
|
||||
v_tiles[I0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
};
|
||||
|
||||
// prefetch k_tile for next iteration
|
||||
k_tiles[i_n0] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
|
||||
block_sync_lds();
|
||||
gemm_0(
|
||||
sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
set_slice_tile(pcomp_tile,
|
||||
tmp_tile,
|
||||
sequence<0, i_n0 * kN0Sub>{},
|
||||
sequence<kM0, (i_n0 + 1) * kN0Sub>{});
|
||||
});
|
||||
}
|
||||
else // last iteration
|
||||
{
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
|
||||
partition_index);
|
||||
|
||||
if constexpr(i_n0 == 0)
|
||||
{
|
||||
v_tiles[I0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
};
|
||||
|
||||
block_sync_lds();
|
||||
gemm_0(
|
||||
sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
set_slice_tile(pcomp_tile,
|
||||
tmp_tile,
|
||||
sequence<0, i_n0 * kN0Sub>{},
|
||||
sequence<kM0, (i_n0 + 1) * kN0Sub>{});
|
||||
});
|
||||
};
|
||||
}
|
||||
}
|
||||
else // only preload one unroll of K for next iteration, used when kM0=128
|
||||
{
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_n0 % 2>{}]),
|
||||
partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
if constexpr(i_n0 < n0_loops - 2)
|
||||
{
|
||||
k_tiles[number<i_n0 % 2>{}] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
};
|
||||
|
||||
if constexpr(i_n0 >= n0_loops - 2)
|
||||
{
|
||||
v_tiles[number<i_n0 - (n0_loops - 2)>{}] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
};
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
set_slice_tile(pcomp_tile,
|
||||
tmp_tile,
|
||||
sequence<0, i_n0 * kN0Sub>{},
|
||||
sequence<kM0, (i_n0 + 1) * kN0Sub>{});
|
||||
});
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x000000001);
|
||||
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, softmax
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile);
|
||||
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto y) {
|
||||
x += type_convert<CompDataType>(bias_element_func(y));
|
||||
},
|
||||
pcomp_tile,
|
||||
bias_tile);
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
constexpr auto pcomp_spans = decltype(pcomp_tile)::get_distributed_spans();
|
||||
sweep_tile_span(pcomp_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(pcomp_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
pcomp_tile(i_j_idx) *= scale_s;
|
||||
position_encoding.update(pcomp_tile(i_j_idx), row, col);
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile);
|
||||
}
|
||||
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
bool need_perpixel_check = mask.IsEdgeTile(
|
||||
q_origin.at(number<0>{}), seqlen_k_curr, number<kM0>{}, number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(pcomp_tile, -numeric<CompDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
auto m_local = block_tile_reduce<CompDataType>(
|
||||
pcomp_tile, sequence<1>{}, f_max, -numeric<CompDataType>::infinity());
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
|
||||
const auto m_old = m;
|
||||
|
||||
tile_elementwise_inout(
|
||||
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// check whether first V-LdsBufer overlap with last K-LdsBuffer,
|
||||
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
|
||||
if constexpr((n0_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_tiles[I0]),
|
||||
partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
if constexpr(kPreloadWholeNextIterationK)
|
||||
{
|
||||
static_for<1, NumPrefetchV, 1>{}([&](auto i_k1) {
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<2, NumPrefetchV, 1>{}([&](auto i_k1) {
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
});
|
||||
};
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
if(m[i_idx] == -numeric<CompDataType>::infinity())
|
||||
{
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
pcomp_tile(i_j_idx) = type_convert<CompDataType>(0.0f);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
pcomp_tile(i_j_idx) = f_exp(pcomp_tile[i_j_idx] - m[i_idx]);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
auto rowsum_p =
|
||||
block_tile_reduce<CompDataType>(pcomp_tile, sequence<1>{}, f_sum, CompDataType{0});
|
||||
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
|
||||
|
||||
// adjust o_acc[] according to the update between m and m_old
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
if(m[i_idx] == -numeric<CompDataType>::infinity())
|
||||
{
|
||||
l(i_idx) = rowsum_p[i_idx];
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto tmp = f_exp(m_old[i_idx] - m[i_idx]);
|
||||
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
auto randval_lds_ptr =
|
||||
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
|
||||
|
||||
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
|
||||
randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);
|
||||
}
|
||||
|
||||
seqlen_k_curr += kN0;
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// STAGE 3, Gemm_1 ( O = P@V )
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
if constexpr(i_k1 < k1_loops - NumPrefetchV)
|
||||
{
|
||||
v_tiles[number<i_k1 % NumPrefetchV>{}] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
};
|
||||
|
||||
if constexpr(i_k1 == k1_loops - NumPrefetchV)
|
||||
{
|
||||
if constexpr(!kPreloadWholeNextIterationK)
|
||||
{
|
||||
if(seqlen_k_curr < seqlen_k_end)
|
||||
{
|
||||
k_tiles[I0] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
if constexpr(i_k1 == k1_loops - NumPrefetchV + 1)
|
||||
{
|
||||
if constexpr(!kPreloadWholeNextIterationK)
|
||||
{
|
||||
if(seqlen_k_curr < seqlen_k_end)
|
||||
{
|
||||
k_tiles[I1] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
block_sync_lds();
|
||||
gemm_1(
|
||||
o_acc,
|
||||
get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]);
|
||||
|
||||
if constexpr(i_k1 < k1_loops - 1)
|
||||
{
|
||||
store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func,
|
||||
v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]),
|
||||
partition_index);
|
||||
};
|
||||
});
|
||||
|
||||
// check whether last V-LdsBuffer overlap with first K-LdsBuffer,
|
||||
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
|
||||
if constexpr((k1_loops - 1 + 2) % NumKVLdsBuffers == 0)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
} while(seqlen_k_curr < seqlen_k_end);
|
||||
|
||||
// store lse
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
auto lse = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
|
||||
sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
|
||||
});
|
||||
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
|
||||
}
|
||||
|
||||
// finally, O
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
if(m[i_idx] == -numeric<CompDataType>::infinity())
|
||||
o_acc(i_j_idx) = 0.0f;
|
||||
else
|
||||
o_acc(i_j_idx) *= 1.0f / l[i_idx];
|
||||
});
|
||||
});
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename PositionEncoding,
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
const AttentionVariant& variant,
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout,
|
||||
const float sink_v) const
|
||||
{
|
||||
ignore = sink_v;
|
||||
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
k_dram_block_window_tmp,
|
||||
identity{},
|
||||
v_dram_block_window_tmp,
|
||||
identity{},
|
||||
bias_dram_block_window_tmp,
|
||||
identity{},
|
||||
randval_dram_block_window_tmp,
|
||||
lse_dram_block_window_tmp,
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
mask,
|
||||
position_encoding,
|
||||
scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -692,8 +692,11 @@ struct BlockFmhaPipelineQSKSVS
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
DropoutType& dropout,
|
||||
const float sink_v) const
|
||||
{
|
||||
ignore = sink_v;
|
||||
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
k_dram_block_window_tmp,
|
||||
|
||||
@@ -57,7 +57,7 @@ struct TileFmhaShape
|
||||
static constexpr index_t kQKHeaddim =
|
||||
BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at
|
||||
// once (or repeately load Q as a whole tile)
|
||||
static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0");
|
||||
static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim must be divisible by kK0!");
|
||||
|
||||
static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length<kQKHeaddim>();
|
||||
|
||||
|
||||
@@ -0,0 +1,268 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A is block distributed tensor
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename Policy_ = BlockGemmARegBSmemCRegV2DefaultPolicy>
|
||||
struct BlockGemmARegBSmemCRegV2PrefetchK
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iNWarp = get_warp_id<false>() % NWarp;
|
||||
|
||||
static_assert(NWarp == 1, "Check failed!");
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
// constrcut from A-block-tensor from A-Block-tensor-tmp
|
||||
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
|
||||
// distribution
|
||||
auto a_block_tensor = make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(
|
||||
MakeABlockTileDistribution());
|
||||
|
||||
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
|
||||
|
||||
// construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
|
||||
b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
|
||||
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
// check C-block-distribution
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"wrong!");
|
||||
|
||||
using AWarpDstr = typename WG::AWarpDstr;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WG::AWarpTensor;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
using b_warp_tensor_type = decltype(load_tile(b_warp_windows(I0)(I0)));
|
||||
|
||||
statically_indexed_array<b_warp_tensor_type, KIterPerWarp> b_warp_tensors;
|
||||
|
||||
// read B warp tensor from B Block window
|
||||
b_warp_windows(nIter)(I0) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(nIter)(I0),
|
||||
{nIter * NPerBlockPerIter, 0 * KPerBlockPerIter});
|
||||
b_warp_tensors[I0] = load_tile(b_warp_windows(nIter)(I0));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
if constexpr(kIter < KIterPerWarp - 1)
|
||||
{
|
||||
// read B warp tensor from B Block window
|
||||
b_warp_windows(nIter)(number<kIter + 1>{}) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(nIter)(number<kIter + 1>{}),
|
||||
{nIter * NPerBlockPerIter, (kIter + 1) * KPerBlockPerIter});
|
||||
b_warp_tensors[number<kIter + 1>{}] =
|
||||
load_tile(b_warp_windows(nIter)(number<kIter + 1>{}));
|
||||
};
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
if constexpr(kIter == 0)
|
||||
{
|
||||
// warp GEMM
|
||||
c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensors[kIter]);
|
||||
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);
|
||||
}
|
||||
else
|
||||
{
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[kIter]);
|
||||
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);
|
||||
};
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
|
||||
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
|
||||
{
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
return a_block_dstr_encode;
|
||||
}
|
||||
|
||||
template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
|
||||
CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution()
|
||||
{
|
||||
constexpr auto a_block_dstr_encode = MakeABlockDistributionEncode<MPerBlock, KPerBlock>();
|
||||
|
||||
return make_static_tile_distribution(a_block_dstr_encode);
|
||||
}
|
||||
|
||||
template <index_t MPerBlock = BlockGemmShape::kM, index_t NPerBlock = BlockGemmShape::kN>
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode()
|
||||
{
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
|
||||
static_assert(NWarp == 1, "Check failed!");
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
return c_block_dstr_encode;
|
||||
}
|
||||
|
||||
template <index_t MPerBlock = BlockGemmShape::kM, index_t NPerBlock = BlockGemmShape::kN>
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
constexpr auto c_block_dstr_encode = MakeCBlockDistributionEncode<MPerBlock, NPerBlock>();
|
||||
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
auto c_block_tensor = MakeCBlockTile();
|
||||
operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);
|
||||
return c_block_tensor;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,239 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A is block distributed tensor
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename Policy_ = BlockGemmARegBSmemCRegV2DefaultPolicy>
|
||||
struct BlockGemmARegBSmemCRegV2PrefetchN
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iNWarp = get_warp_id<false>() % NWarp;
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
// constrcut from A-block-tensor from A-Block-tensor-tmp
|
||||
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
|
||||
// distribution
|
||||
auto a_block_tensor = make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(
|
||||
MakeABlockTileDistribution());
|
||||
|
||||
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
|
||||
|
||||
// construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
|
||||
b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
|
||||
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
// check C-block-distribution
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"wrong!");
|
||||
|
||||
using AWarpDstr = typename WG::AWarpDstr;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WG::AWarpTensor;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
using b_warp_tensor_type = decltype(load_tile(b_warp_windows(I0)(I0)));
|
||||
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
statically_indexed_array<b_warp_tensor_type, NIterPerWarp> b_warp_tensors;
|
||||
|
||||
// read B warp tensor from B Block window
|
||||
b_warp_windows(I0)(kIter) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(I0)(kIter),
|
||||
{0 * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
b_warp_tensors(I0) = load_tile(b_warp_windows(I0)(kIter));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
if constexpr(nIter < NIterPerWarp - 1)
|
||||
{
|
||||
// read B warp tensor from B Block window
|
||||
b_warp_windows(number<nIter + 1>{})(kIter) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(number<nIter + 1>{})(kIter),
|
||||
{(nIter + 1) * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
b_warp_tensors(number<nIter + 1>{}) =
|
||||
load_tile(b_warp_windows(number<nIter + 1>{})(kIter));
|
||||
};
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter]);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
|
||||
CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution()
|
||||
{
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
return make_static_tile_distribution(a_block_dstr_encode);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
auto c_block_tensor = MakeCBlockTile();
|
||||
operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);
|
||||
return c_block_tensor;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,243 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A is block distributed tensor
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename Policy_ = BlockGemmARegBSmemCRegV2DefaultPolicy>
|
||||
struct BlockGemmARegBSmemTrLoadCRegV2PrefetchN
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<1>{}];
|
||||
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iNWarp = get_warp_id<false>() % NWarp;
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
// construct from A-block-tensor from A-Block-tensor-tmp
|
||||
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
|
||||
// distribution
|
||||
auto a_block_tensor = make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(
|
||||
MakeABlockTileDistribution());
|
||||
|
||||
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
|
||||
|
||||
constexpr auto b_warp_dstr_encode =
|
||||
typename InputTileDistributionTraits<typename WG::BWarpDstrEncoding,
|
||||
BDataType>::TransposedDstrEncode{};
|
||||
|
||||
// construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WG::kK>{}, number<WG::kN>{}),
|
||||
b_block_window_tmp.get_window_origin() + multi_index<2>{0, iNWarp * WG::kN},
|
||||
make_static_tile_distribution(b_warp_dstr_encode));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
// check C-block-distribution
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"wrong!");
|
||||
|
||||
using AWarpDstr = typename WG::AWarpDstr;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WG::AWarpTensor;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
using b_warp_tensor_type = decltype(load_tile_transpose(b_warp_windows(I0)(I0)));
|
||||
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
statically_indexed_array<b_warp_tensor_type, NIterPerWarp> b_warp_tensors;
|
||||
|
||||
// read B warp tensor from B Block window
|
||||
b_warp_windows(I0)(kIter) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(I0)(kIter),
|
||||
{kIter * KPerBlockPerIter, 0 * NPerBlockPerIter});
|
||||
b_warp_tensors(I0) = load_tile_transpose(b_warp_windows(I0)(kIter));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
if constexpr(nIter < NIterPerWarp - 1)
|
||||
{
|
||||
// read B warp tensor from B Block window
|
||||
b_warp_windows(number<nIter + 1>{})(kIter) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(number<nIter + 1>{})(kIter),
|
||||
{kIter * KPerBlockPerIter, (nIter + 1) * NPerBlockPerIter});
|
||||
b_warp_tensors(number<nIter + 1>{}) =
|
||||
load_tile_transpose(b_warp_windows(number<nIter + 1>{})(kIter));
|
||||
};
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter]);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
|
||||
CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution()
|
||||
{
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
return make_static_tile_distribution(a_block_dstr_encode);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
auto c_block_tensor = MakeCBlockTile();
|
||||
operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);
|
||||
return c_block_tensor;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user