mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
[rocm-libraries] ROCm/rocm-libraries#6051 (commit f0838b2)
[CK] Add FP8 per-tensor quantization support for FMHA V3 pipeline (#6051) ## Motivation The existing FMHA V3 pipeline only supports fp16/bf16 data types. This PR extends V3 to handle FP8 inputs with per-tensor descaling on gfx950, enabling higher throughput for FP8 inference workloads using the assembly-optimized V3 code path. ## Technical Details **Warp GEMM:** - Add FP8 32x32x32 warp gemm with C-transposed distribution (`WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed`) and dispatcher entries **V3 Kernel (`fmha_fwd_v3_kernel.hpp`):** - Add per-tensor descale support for Q, K, V tensors, passing descale pointers through to pipeline kargs **V3 Pipeline (`block_fmha_fwd_v3_pipeline.hpp`):** - Add FP8 data path with dtype-aware type selection - Add asm volatile P matrix conversion from f32 to fp8 - Add FP8-aware instruction scheduling in `CoreLoopScheduler` **V3 Pipeline Policy (`block_fmha_fwd_v3_pipeline_default_policy.hpp`):** - Add FP8 QK warp gemm selection (SwizzleB variant for V tile distribution compatibility) **Codegen (`fmha_fwd.py`):** - Add gfx950 FP8BF16 V3 tile size (256x64x128x128x64x128) - Add FP8BF16 V3 pipeline variants (mask: no/causal, qscale: no/pertensor) - Extend `can_dispatch_v3` condition for fp8bf16 + pertensor **Misc:** - Add LLVM scheduler `TRANS` mask to `LLVMSchedGroupMask` enum (`arch.hpp`) - Fix `mask_info` default initialization for `no_mask` case (`mask.hpp`) V3 dispatch for FP8 is disabled by default (`F_is_v3_enabled=false`) pending further validation. ## Performance: fmha_fwd V3 FP8 (avg runs 2-6, stock ROCm 7.1.1, gfx950) | Problem | Regular (TFlops) | Varlen (TFlops) | |---|---:|---:| | batch=1 heads=6/1 seqlen=1024 causal | 48.9 | 47.6 | | batch=1 heads=6/1 seqlen=2048 causal | 119.8 | 117.4 | | batch=1 heads=6/1 seqlen=4096 causal | 263.7 | 259.2 | | batch=1 heads=6/1 seqlen=8192 causal | 548.9 | 543.6 | | batch=1 heads=6/1 seqlen=16384 causal | 1043.0 | 1063.7 | | batch=1 heads=6/1 seqlen=32768 causal | 1237.2 | 1279.6 | | batch=1 heads=6/1 seqlen=65536 causal | 1315.4 | 1382.7 | | batch=1 heads=6/1 seqlen=131072 causal | 1326.3 | 1402.2 | | batch=1 heads=16/1 seqlen=65536 causal | 1298.7 | 1388.4 | | batch=1 heads=40/40 seqlen=37200 non-causal | 1248.9 | 1326.1 | ## Test Plan Tested with aiter's `test_mha_fp8.py` test suite (176 cases) covering batch sizes (1-2), sequence lengths (113-4096), head counts (5/8/32/40), GQA ratios (1:1, 1:8), and causal/non-causal modes. Verified all cases dispatch to the V3 pipeline by enabling `F_is_v3_enabled` and confirming kernel names contain `qr_async_trload_v3`. ## Test Result 176/176 tests passed with V3 enabled. All cases correctly dispatched to V3 pipeline with `pertensor` quantization. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
020b6f435e
commit
c2ac7aa7b0
@@ -27,6 +27,7 @@ struct FmhaFwdV3Kernel
|
||||
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
|
||||
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
|
||||
using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
|
||||
using PDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::PDataType>;
|
||||
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
|
||||
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
|
||||
using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
|
||||
@@ -38,6 +39,7 @@ struct FmhaFwdV3Kernel
|
||||
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
|
||||
static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
|
||||
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
|
||||
static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum;
|
||||
|
||||
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||
@@ -118,11 +120,21 @@ struct FmhaFwdV3Kernel
|
||||
float logits_soft_cap_rcp;
|
||||
};
|
||||
|
||||
struct FmhaFwdCommonQScaleKargs
|
||||
{
|
||||
const void* q_descale_ptr = nullptr;
|
||||
const void* k_descale_ptr = nullptr;
|
||||
const void* v_descale_ptr = nullptr;
|
||||
};
|
||||
|
||||
struct FmhaFwdBatchModeKargs
|
||||
: FmhaFwdCommonKargs,
|
||||
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<2>>
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
|
||||
FmhaFwdCommonQScaleKargs,
|
||||
FmhaFwdEmptyKargs<2>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<3>>
|
||||
{
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
@@ -139,7 +151,10 @@ struct FmhaFwdV3Kernel
|
||||
: FmhaFwdCommonKargs,
|
||||
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<2>>
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
|
||||
FmhaFwdCommonQScaleKargs,
|
||||
FmhaFwdEmptyKargs<2>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<3>>
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqstart_k_ptr;
|
||||
@@ -166,6 +181,9 @@ struct FmhaFwdV3Kernel
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* q_descale_ptr,
|
||||
const void* k_descale_ptr,
|
||||
const void* v_descale_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
@@ -218,6 +236,7 @@ struct FmhaFwdV3Kernel
|
||||
nhead_stride_o}, // args for common karg
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for lse
|
||||
{}, // placeholder for qscale
|
||||
{}, // placeholder for logits_soft_cap
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
@@ -237,6 +256,12 @@ struct FmhaFwdV3Kernel
|
||||
kargs.nhead_stride_lse = nhead_stride_lse;
|
||||
kargs.batch_stride_lse = batch_stride_lse;
|
||||
}
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
|
||||
{
|
||||
kargs.q_descale_ptr = q_descale_ptr;
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
}
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
kargs.init_logits_soft_cap(logits_soft_cap);
|
||||
@@ -252,6 +277,9 @@ struct FmhaFwdV3Kernel
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* q_descale_ptr,
|
||||
const void* k_descale_ptr,
|
||||
const void* v_descale_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
@@ -301,6 +329,7 @@ struct FmhaFwdV3Kernel
|
||||
nhead_stride_o}, // args for common karg
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for lse
|
||||
{}, // placeholder for qscale
|
||||
{}, // placeholder for logits_soft_cap
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
@@ -319,6 +348,12 @@ struct FmhaFwdV3Kernel
|
||||
kargs.lse_ptr = lse_ptr;
|
||||
kargs.nhead_stride_lse = nhead_stride_lse;
|
||||
}
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
|
||||
{
|
||||
kargs.q_descale_ptr = q_descale_ptr;
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
}
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
kargs.init_logits_soft_cap(logits_soft_cap);
|
||||
@@ -437,8 +472,19 @@ struct FmhaFwdV3Kernel
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
// Notice: When using double buffering, make sure both buffers are in the same array.
|
||||
// This prevents the compiler from using separate VGPRs to store the base address
|
||||
// and enables the use of immediate offsets in load/store instructions.
|
||||
constexpr auto smem_size_kv =
|
||||
FmhaPipeline::Policy::template GetSmemSizeKV<typename FmhaPipeline::Problem>();
|
||||
__shared__ char smem_k[2][smem_size_kv];
|
||||
__shared__ char smem_v[2][smem_size_kv];
|
||||
|
||||
auto* smem_k0 = reinterpret_cast<KDataType*>(smem_k[0]);
|
||||
auto* smem_k1 = reinterpret_cast<KDataType*>(smem_k[1]);
|
||||
auto* smem_v0 = reinterpret_cast<VDataType*>(smem_v[0]);
|
||||
auto* smem_v1 = reinterpret_cast<VDataType*>(smem_v[1]);
|
||||
;
|
||||
|
||||
// divide problem
|
||||
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
|
||||
@@ -640,32 +686,88 @@ struct FmhaFwdV3Kernel
|
||||
return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
|
||||
}();
|
||||
|
||||
const float scale_s = [&] {
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
|
||||
{
|
||||
float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
|
||||
float k_descale = *(reinterpret_cast<const float*>(kargs.k_descale_ptr));
|
||||
return kargs.scale_s * q_descale * k_descale;
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.scale_s;
|
||||
}
|
||||
}();
|
||||
|
||||
AttentionVariant variant;
|
||||
const auto variant_params = [&] {
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
return ck_tile::LogitsSoftCapParams<FmhaMask, CK_TILE_FMHA_FWD_FAST_EXP2>{
|
||||
mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
|
||||
mask, scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
|
||||
return ck_tile::StandardAttentionParams<FmhaMask>{mask, scale_s};
|
||||
}
|
||||
}();
|
||||
|
||||
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
|
||||
|
||||
auto o_acc_tile = [&]() {
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
lse_dram_window,
|
||||
mask,
|
||||
kargs.scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr);
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
|
||||
{
|
||||
float v_descale = *(reinterpret_cast<const float*>(kargs.v_descale_ptr));
|
||||
float scale_p = ck_tile::type_convert<float>(ck_tile::numeric<PDataType>::max());
|
||||
float scale_o = v_descale / scale_p;
|
||||
|
||||
auto o_acc_element_func = [&]() {
|
||||
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
|
||||
return make_composes(
|
||||
ck_tile::saturates<ck_tile::fp8_t>{},
|
||||
ck_tile::scales<remove_cvref_t<decltype(scale_o)>>{scale_o});
|
||||
else
|
||||
return ck_tile::scales<remove_cvref_t<decltype(scale_o)>>{scale_o};
|
||||
}();
|
||||
|
||||
return FmhaPipeline{}(
|
||||
q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
k_dram_window,
|
||||
identity{}, // k_element_func
|
||||
v_dram_window,
|
||||
identity{}, // v_element_func
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
scales<remove_cvref_t<decltype(scale_p)>>{scale_p}, // p_compute_element_func
|
||||
o_acc_element_func,
|
||||
mask,
|
||||
scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_k0,
|
||||
smem_k1,
|
||||
smem_v0,
|
||||
smem_v1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
lse_dram_window,
|
||||
mask,
|
||||
scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_k0,
|
||||
smem_k1,
|
||||
smem_v0,
|
||||
smem_v1);
|
||||
}
|
||||
}();
|
||||
|
||||
// O DRAM and O DRAM window
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -239,10 +239,18 @@ struct BlockFmhaV3PipelineDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
|
||||
constexpr auto warp_gemm = []() {
|
||||
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, half_t> &&
|
||||
constexpr auto warp_gemm = [] {
|
||||
if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
// Use SwizzleB variant to get 8 contiguous K positions per lane,
|
||||
// matching the V tile distribution for PV GEMM
|
||||
return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<>{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
/// NOTICE: in order to use load_tile_transpose() later for V tile, we cannot use
|
||||
/// WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution here
|
||||
@@ -310,9 +318,8 @@ struct BlockFmhaV3PipelineDefaultPolicy
|
||||
static constexpr ck_tile::index_t kKLdsPadInBytes = 4 * 4; // 4 dwords
|
||||
static constexpr ck_tile::index_t kVLdsPadInBytes = 4 * 16; // 16 dwords
|
||||
|
||||
template <typename Problem, ck_tile::index_t IBuf = 0>
|
||||
CK_TILE_DEVICE static constexpr auto
|
||||
MakeKLdsStoreBlockDescriptor(ck_tile::number<IBuf> = ck_tile::number<0>{})
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeKLdsStoreBlockDescriptor()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
@@ -323,7 +330,6 @@ struct BlockFmhaV3PipelineDefaultPolicy
|
||||
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
[[maybe_unused]] constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
|
||||
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
|
||||
constexpr index_t kPad =
|
||||
kKLdsPadInBytes /
|
||||
@@ -339,31 +345,28 @@ struct BlockFmhaV3PipelineDefaultPolicy
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
|
||||
|
||||
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
|
||||
make_tuple(number<NumIssues>{}, // n0
|
||||
number<LaneGroups>{}, // n1
|
||||
number<NumWarps>{}, // n2
|
||||
number<LanesPerK>{}, // k0
|
||||
number<KVector>{}), // k1
|
||||
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
|
||||
number<kKPerBlock>{},
|
||||
number<WarpSize * KVector + kPad>{},
|
||||
number<KVector>{},
|
||||
number<1>{}),
|
||||
number<IBuf * GetSingleSmemElementSpaceSize<Problem>()>{},
|
||||
number<KVector>{},
|
||||
number<1>{});
|
||||
constexpr auto k_lds_block_desc_0 =
|
||||
make_naive_tensor_descriptor(make_tuple(number<NumIssues>{}, // n0
|
||||
number<LaneGroups>{}, // n1
|
||||
number<NumWarps>{}, // n2
|
||||
number<LanesPerK>{}, // k0
|
||||
number<KVector>{}), // k1
|
||||
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
|
||||
number<kKPerBlock>{},
|
||||
number<WarpSize * KVector + kPad>{},
|
||||
number<KVector>{},
|
||||
number<1>{}),
|
||||
number<KVector>{},
|
||||
number<1>{});
|
||||
|
||||
// TODO this layout is hard coded, and will be used in async copy buffer view load
|
||||
// in LDS the real layout is (bufs, N0, N2, N1*K0*K1)
|
||||
// CRITICAL: Must match Load descriptor merge pattern (NumIssues, LaneGroups, NumWarps)
|
||||
constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
|
||||
k_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(number<NumIssues>{}),
|
||||
make_pass_through_transform(number<NumWarps>{}),
|
||||
make_merge_transform(make_tuple(
|
||||
number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
|
||||
make_tuple(make_merge_transform(make_tuple(
|
||||
number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
|
||||
make_merge_transform(make_tuple(number<LanesPerK>{}, number<KVector>{}))),
|
||||
make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return k_lds_block_desc_issues_warps_lanes;
|
||||
}
|
||||
@@ -458,9 +461,8 @@ struct BlockFmhaV3PipelineDefaultPolicy
|
||||
return max(SingleKSize, SingleVSize);
|
||||
}
|
||||
|
||||
template <typename Problem, ck_tile::index_t IBuf = 0>
|
||||
CK_TILE_DEVICE static constexpr auto
|
||||
MakeVLdsStoreBlockDescriptor(ck_tile::number<IBuf> = ck_tile::number<0>{})
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeVLdsStoreBlockDescriptor()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
@@ -471,7 +473,6 @@ struct BlockFmhaV3PipelineDefaultPolicy
|
||||
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
[[maybe_unused]] constexpr index_t KPack = GetSmemVPackK<Problem>(); // this is for lds
|
||||
constexpr index_t KVector = GetAlignmentV<Problem>(); // this is for global load
|
||||
constexpr index_t kPad =
|
||||
kVLdsPadInBytes /
|
||||
@@ -487,31 +488,27 @@ struct BlockFmhaV3PipelineDefaultPolicy
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
|
||||
|
||||
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
|
||||
make_tuple(number<NumIssues>{}, // n0
|
||||
number<LaneGroups>{}, // n1
|
||||
number<NumWarps>{}, // n2
|
||||
number<LanesPerK>{}, // k0
|
||||
number<KVector>{}), // k1
|
||||
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
|
||||
number<kKPerBlock>{},
|
||||
number<WarpSize * KVector + kPad>{},
|
||||
number<KVector>{},
|
||||
number<1>{}),
|
||||
number<(IBuf + 2) * GetSingleSmemElementSpaceSize<Problem>()>{},
|
||||
number<KVector>{},
|
||||
number<1>{});
|
||||
constexpr auto v_lds_block_desc_0 =
|
||||
make_naive_tensor_descriptor(make_tuple(number<NumIssues>{}, // n0
|
||||
number<LaneGroups>{}, // n1
|
||||
number<NumWarps>{}, // n2
|
||||
number<LanesPerK>{}, // k0
|
||||
number<KVector>{}), // k1
|
||||
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
|
||||
number<kKPerBlock>{},
|
||||
number<WarpSize * KVector + kPad>{},
|
||||
number<KVector>{},
|
||||
number<1>{}),
|
||||
number<KVector>{},
|
||||
number<1>{});
|
||||
|
||||
// TODO this layout is hard coded, and will be used in async copy buffer view load
|
||||
// in LDS the real layout is (bufs, N0, N2, N1*K0*K1)
|
||||
constexpr auto v_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
|
||||
v_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(number<NumIssues>{}),
|
||||
make_pass_through_transform(number<NumWarps>{}),
|
||||
make_merge_transform(make_tuple(
|
||||
number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
|
||||
make_tuple(make_merge_transform(make_tuple(
|
||||
number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
|
||||
make_merge_transform(make_tuple(number<LanesPerK>{}, number<KVector>{}))),
|
||||
make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return v_lds_block_desc_issues_warps_lanes;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user