mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Navi3 rel (#1176)
* wmma_op + unit test * add arch limitation to wmma test * change arch limitation * Refactor + Add all type unit test(int4 compile failed) * Add f32_16x16x16_bf16 unit test * tempsave * tempsave * tempsave * runtime bug, cannot find symbol * workaround for incorrect HIP warpSize return value * debugging * tempsave * Correctness OK, waiting for optimization * Tidy up + format * temp save * temp save, reproduce the v_bfi_b32 issue * add inline asm for wmmaop test * tidy up * clean some debug purpose code * discard some codes * clang format * clang format * compiler issue fixed + increase tile size * navi3x_multipleD+example * temp save * workable * batchedgemm[OK], groupconv[debug] * groupconv: Sanity check[OK], Performance[Bad] * navi3x_groupconv_need_optimization * create necessary files * save progress * Add Inter-Row thread transfer * save progress * save debugging progress * sanity check pass * fix a host tensor bug and clean up flash-attn code * format * cancel unnecessary change * cancel unnecessary change * cancel unnecessary change * temp save, add asm backend flag to amd_wmma * Mat-A LDS Bypass sanity pass * temp save * gemm sanity fix * Porting new blockwise gemm to flash attention * Example branch provide to compiler team * tempsave * Fix a bug * batched gemm ported * conv A-skip lds ported * Skip B-Lds real gemm * Skip B Lds Gemm + MulD * batched gemm, conv, skip b lds * format * Attn, skip b lds * Change GridwiseOp nam * fix a typo caused bug * Skip A_Lds sanity pass, Skip B_Lds scratch occured * Bug found, intra-row permute off caused * bug found * a fix * disable buffer load due to incorrect 3rd dword * update fmha config, no scratch generated * update 3rd dword * fmha config update * FMHA, add support to gfx1101/gfx1102 * Merge origin dev (#2) * [Navi3x] Fix Gridwise_multiple_d operation (#649) * Add CMake Option "USE_OPT_NAVI3X" * fix bug * standardize docs (#655) * Separate bibtex requirement from rocm-docs-core (#656) * separate bibtex requirement from rocm-docs-core * point requirements to source rocm-docs-core repo * Add CMake Option "USE_OPT_NAVI3X" (#647) * Add CMake Option "USE_OPT_NAVI3X" * remove navi3x opt compile option from cmake script * Conv + quantization + tanh (#645) * Rename file. Prepare to support another activation * Add comment for quantization * Extract out_elementop * Add tanh example * Add conv + bias + tanh quantization instance * Add missing parameter * Refine cmake * Add external api and client example * Extract variable in example * Fix the comment --------- Co-authored-by: zjing14 <zhangjing14@gmail.com> * Add a denorm test fix (#603) * Add type_convert implementations for bf16 * Add the fix for conv_fwd * Add the fix for conv_bwd_data * Add the fix for conv_bwd_weight * Format * Format * Another format * Add a macro to use workaround on MI200 only * Format --------- Co-authored-by: Rosty Geyyer <rosty.geyyer@amd.com> Co-authored-by: zjing14 <zhangjing14@gmail.com> * simplify karg in device/grid of split-k op (#644) * simplify karg in device/grid split-k op * fix mk_kn_mn instances * add more instances * use name from tensor layout * fix 3rd dword of buffer source descriptor (#659) * add fp64 instances (#658) Co-authored-by: root <root@ctr-ubbsmc15.amd.com> * Issue #666: Revert "simplify karg in device/grid of split-k op (#644)" (#665) This reverts commitbb5530af91. * Groupnorm + swish external api (#668) * Rename to proper naming * Add example of groupnorm + swish * Extract duplicate code in example * Add groupnorm + swish instances * Ractor instance generation, split into multiple cpp file * Add external api and client example * Refine profiler message * Use ck math version of exp * Refine problem size in example * Add host version of exp * add a marco to turn on/off denorm fix (off by default) (#673) * add a marco to turn off denorm fix by default * expose the marco --------- Co-authored-by: root <root@ctr-ubbsmc15.amd.com> * fixed quant example (#672) Co-authored-by: root <root@ctr-ubbsmc15.amd.com> * Add dependabot config and pin rocm-docs-core (#663) * [gtest] suppress unsafe buffer warn (#670) ref: https://github.com/ROCmSoftwarePlatform/MIOpen/pull/1912 * Add memory index guard in wmma device ops (#667) * Add more macros to turn on/off denorm fix (#678) Co-authored-by: Rosty Geyyer <rosty.geyyer@amd.com> * Fix a typo (#676) * Add (#677) * Allow using ROCm release candidate compilers. (#679) * enable use of rocm5.5 release candidate 4 * upgrade to ROCM5.5 RC5 * try fix the PUB_KEY error, remove the cmake-data package * upgrade to latest cmake version * use private dockerhub repo for rocm5.5 rc5 * add missing bracket * add vector load check * solve conflicts --------- Co-authored-by: Sam Wu <sjwu@ualberta.ca> Co-authored-by: Sam Wu <sam.wu2@amd.com> Co-authored-by: rocking5566 <ChunYu.Lai@amd.com> Co-authored-by: zjing14 <zhangjing14@gmail.com> Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Co-authored-by: Rosty Geyyer <rosty.geyyer@amd.com> Co-authored-by: carlushuang <carlus.huang@amd.com> Co-authored-by: root <root@ctr-ubbsmc15.amd.com> Co-authored-by: Jun Liu <Liu.Jun@amd.com> Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> * Disable SkipLDS & Align AIT api (#3) * fix layernorm, reduction Ops (#4) * [Navi3x] Fix Gridwise_multiple_d operation (#649) * Add CMake Option "USE_OPT_NAVI3X" * fix bug * standardize docs (#655) * Separate bibtex requirement from rocm-docs-core (#656) * separate bibtex requirement from rocm-docs-core * point requirements to source rocm-docs-core repo * Add CMake Option "USE_OPT_NAVI3X" (#647) * Add CMake Option "USE_OPT_NAVI3X" * remove navi3x opt compile option from cmake script * Conv + quantization + tanh (#645) * Rename file. Prepare to support another activation * Add comment for quantization * Extract out_elementop * Add tanh example * Add conv + bias + tanh quantization instance * Add missing parameter * Refine cmake * Add external api and client example * Extract variable in example * Fix the comment --------- Co-authored-by: zjing14 <zhangjing14@gmail.com> * Add a denorm test fix (#603) * Add type_convert implementations for bf16 * Add the fix for conv_fwd * Add the fix for conv_bwd_data * Add the fix for conv_bwd_weight * Format * Format * Another format * Add a macro to use workaround on MI200 only * Format --------- Co-authored-by: Rosty Geyyer <rosty.geyyer@amd.com> Co-authored-by: zjing14 <zhangjing14@gmail.com> * simplify karg in device/grid of split-k op (#644) * simplify karg in device/grid split-k op * fix mk_kn_mn instances * add more instances * use name from tensor layout * fix 3rd dword of buffer source descriptor (#659) * add fp64 instances (#658) Co-authored-by: root <root@ctr-ubbsmc15.amd.com> * Issue #666: Revert "simplify karg in device/grid of split-k op (#644)" (#665) This reverts commitbb5530af91. * Groupnorm + swish external api (#668) * Rename to proper naming * Add example of groupnorm + swish * Extract duplicate code in example * Add groupnorm + swish instances * Ractor instance generation, split into multiple cpp file * Add external api and client example * Refine profiler message * Use ck math version of exp * Refine problem size in example * Add host version of exp * add a marco to turn on/off denorm fix (off by default) (#673) * add a marco to turn off denorm fix by default * expose the marco --------- Co-authored-by: root <root@ctr-ubbsmc15.amd.com> * fixed quant example (#672) Co-authored-by: root <root@ctr-ubbsmc15.amd.com> * Add dependabot config and pin rocm-docs-core (#663) * [gtest] suppress unsafe buffer warn (#670) ref: https://github.com/ROCmSoftwarePlatform/MIOpen/pull/1912 * Add memory index guard in wmma device ops (#667) * Add more macros to turn on/off denorm fix (#678) Co-authored-by: Rosty Geyyer <rosty.geyyer@amd.com> * Fix a typo (#676) * Add (#677) * Allow using ROCm release candidate compilers. (#679) * enable use of rocm5.5 release candidate 4 * upgrade to ROCM5.5 RC5 * try fix the PUB_KEY error, remove the cmake-data package * upgrade to latest cmake version * use private dockerhub repo for rocm5.5 rc5 * add missing bracket * Disable SkipLDS & Align AIT api * Update dependabot config (#682) Co-authored-by: samjwu <samjwu@users.noreply.github.com> * update attn api * solve type_convert bug + enable --------- Co-authored-by: Sam Wu <sjwu@ualberta.ca> Co-authored-by: Sam Wu <sam.wu2@amd.com> Co-authored-by: rocking5566 <ChunYu.Lai@amd.com> Co-authored-by: zjing14 <zhangjing14@gmail.com> Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Co-authored-by: Rosty Geyyer <rosty.geyyer@amd.com> Co-authored-by: carlushuang <carlus.huang@amd.com> Co-authored-by: root <root@ctr-ubbsmc15.amd.com> Co-authored-by: Jun Liu <Liu.Jun@amd.com> Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: samjwu <samjwu@users.noreply.github.com> Co-authored-by: haocwang <Haocong.WANG@amd.com> * fix typo * Fix attention with causal mask * multiple fix, try ait compile * Add A/B not use LDS pipeline * Clang format, Add gfx1101, gfx1102 support of FMHA example * cancel change of format script * 1. Enable 2-stage global Prefetch ( May cause VGPR spilling) 2. Enable FP16 accumulator blockwise_gemm * clang-format * 1. change blockwise gemm loopover direction from kmn to mnk ( ~1% improvement) 2. change kernel timing mode to 50 warmup + 50 timed repeat * Update low level abstration of blockwise gemm wmma * (2/5) bilinear gemm pass, perf bug: skip a lds has lower performance than skip b lds * (3/5) batched gemm pass, perf bug: skip a lds has lower performance than skip b lds * (4/5) grouped conv pass * (5/5) attention pass, todo: debug lds perf bug * AIT Attention API refactor (#8) * sanity pass * sanity pass 2 * confirm significant performance regression. * turn on all instances * turn off instance format * Fix bug & tunning & format * DML meta, self_attn+cross_attn * sanity pass * remove useless flag * update tile and problem size used in AIT attention * bug fix in grouped conv supporting check * deprecate inline asm wmma * Bug fix: double lds skip * clang-format * Fix errors in 1. example, fmha 2. gridwise pipeline 3. deviceop, fmha, change some containers from vector to array * part2 of previous commit * clang format * API fix of gridwisegemmpipeline * separate array base and vector base attention tensor transformation * fix gemm * clang format * add gemm fp16 instances * Temp save * fpAintB kernel compile pass * Sanity pass. * Temp save * debug code enabled * Fp16AInt8B_GEMM sanity * MQA implementation * GQA-4 example * tempsave * Compile pass * New implementation of fp16Aint8B Gemm, Acheieve similar math throughput with native fp16 Gemm * format * Todo: fix gemm_bilinear_wmma instances compilation bug * Solve a bug when K1=16 * remove unnecessary changes * Remove tensor layout limitation to LDS usage in tesnor contraction * update self-attention and cross-attention * fix a typo of name * Add arch limiter for fp8 gemm * enable fp8 gemm_xdl for all gfx9 targets * temporarily disable gemm_xdl_fp16_fp8 on MI100/200 * fix the cmake logic for gemm_xdl_fp16_fp8 * re-enable the gemm_xdl_fp16_fp8 on MI100/200 --------- Co-authored-by: aska-0096 <haocwang@amd.com> Co-authored-by: Sam Wu <sjwu@ualberta.ca> Co-authored-by: Sam Wu <sam.wu2@amd.com> Co-authored-by: rocking5566 <ChunYu.Lai@amd.com> Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Co-authored-by: Rosty Geyyer <rosty.geyyer@amd.com> Co-authored-by: carlushuang <carlus.huang@amd.com> Co-authored-by: root <root@ctr-ubbsmc15.amd.com> Co-authored-by: Jun Liu <Liu.Jun@amd.com> Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: samjwu <samjwu@users.noreply.github.com> Co-authored-by: haocwang <Haocong.WANG@amd.com> Co-authored-by: illsilin <Illia.Silin@amd.com>
This commit is contained in:
@@ -89,6 +89,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
|
||||
static constexpr index_t src_a_data_size = 2;
|
||||
static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
static constexpr index_t acc_pack_number = 1;
|
||||
// * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
|
||||
@@ -100,7 +101,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
|
||||
// * num_acc_vgprs_per_wave alone M direction
|
||||
// * num_subgroups alone M direction
|
||||
static constexpr index_t num_acc_vgprs_per_wave =
|
||||
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
|
||||
m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
|
||||
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
|
||||
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
@@ -129,6 +130,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
|
||||
static constexpr index_t src_a_data_size = 2;
|
||||
static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
static constexpr index_t acc_pack_number = 1;
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
|
||||
// Wave mode dependent propety
|
||||
@@ -136,7 +138,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
|
||||
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
|
||||
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
|
||||
static constexpr index_t num_acc_vgprs_per_wave =
|
||||
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
|
||||
m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
|
||||
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
|
||||
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
@@ -153,7 +155,6 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
|
||||
}
|
||||
};
|
||||
|
||||
#ifdef CK_UNPACKED_ACC_DESC_LOGIC
|
||||
template <index_t WaveSize>
|
||||
struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
|
||||
WaveSize,
|
||||
@@ -166,6 +167,7 @@ struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
|
||||
static constexpr index_t src_a_data_size = 2;
|
||||
static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 2;
|
||||
static constexpr index_t acc_pack_number = 2;
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
|
||||
// Wave mode dependent propety
|
||||
@@ -173,28 +175,22 @@ struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
|
||||
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
|
||||
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
|
||||
static constexpr index_t num_acc_vgprs_per_wave =
|
||||
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
|
||||
m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
|
||||
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
|
||||
|
||||
template <index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
index_t Opsel,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
if constexpr(wave_size == 32)
|
||||
{
|
||||
intrin_wmma_f16_16x16x16_f16_w32<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
|
||||
intrin_wmma_f16_16x16x16_f16_w32<MPerWmma, NPerWmma, false>::Run(a, b, reg_c);
|
||||
}
|
||||
else if constexpr(wave_size == 64)
|
||||
{
|
||||
intrin_wmma_f16_16x16x16_f16_w64<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
|
||||
intrin_wmma_f16_16x16x16_f16_w64<MPerWmma, NPerWmma, false>::Run(a, b, reg_c);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t WaveSize>
|
||||
struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
|
||||
WaveSize,
|
||||
@@ -207,6 +203,7 @@ struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
|
||||
static constexpr index_t src_a_data_size = 2;
|
||||
static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 2;
|
||||
static constexpr index_t acc_pack_number = 2;
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
|
||||
// Wave mode dependent propety
|
||||
@@ -214,7 +211,7 @@ struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
|
||||
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
|
||||
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
|
||||
static constexpr index_t num_acc_vgprs_per_wave =
|
||||
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
|
||||
m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
|
||||
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
|
||||
|
||||
template <index_t MPerWmma,
|
||||
@@ -227,17 +224,15 @@ struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
|
||||
{
|
||||
if constexpr(wave_size == 32)
|
||||
{
|
||||
intrin_wmma_bf16_16x16x16_bf16_w32<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
|
||||
intrin_wmma_bf16_16x16x16_bf16_w32<MPerWmma, NPerWmma, false>::Run(a, b, reg_c);
|
||||
}
|
||||
else if constexpr(wave_size == 64)
|
||||
{
|
||||
intrin_wmma_bf16_16x16x16_bf16_w64<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
|
||||
intrin_wmma_bf16_16x16x16_bf16_w64<MPerWmma, NPerWmma, false>::Run(a, b, reg_c);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
template <index_t WaveSize>
|
||||
struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
|
||||
WaveSize,
|
||||
@@ -250,6 +245,7 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
|
||||
static constexpr index_t src_a_data_size = 2;
|
||||
static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
static constexpr index_t acc_pack_number = 1;
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
|
||||
// Wave mode dependent propety
|
||||
@@ -257,7 +253,7 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
|
||||
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
|
||||
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
|
||||
static constexpr index_t num_acc_vgprs_per_wave =
|
||||
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
|
||||
m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
|
||||
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
|
||||
|
||||
template <index_t MPerWmma,
|
||||
@@ -346,7 +342,7 @@ struct WmmaSelector
|
||||
static_assert(selected_wmma.k_per_wmma == 16, "WRONG! WMMA_M must equal to 16");
|
||||
|
||||
static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave *
|
||||
selected_wmma.acc_data_size ==
|
||||
selected_wmma.acc_data_size * selected_wmma.acc_pack_number ==
|
||||
selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
|
||||
"WRONG! Invalid Number of Accumulator Register");
|
||||
}
|
||||
@@ -358,7 +354,8 @@ template <typename src_type_a,
|
||||
index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
index_t KPack,
|
||||
bool TransposeC = false>
|
||||
bool TransposeC = false,
|
||||
bool AssemblyBackend = false>
|
||||
struct WmmaGemm
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -369,14 +366,14 @@ struct WmmaGemm
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
|
||||
using CIndex = MultiIndex<2>;
|
||||
using CIndex4D = MultiIndex<4>;
|
||||
using CIndex3D = MultiIndex<3>;
|
||||
|
||||
__host__ __device__ constexpr WmmaGemm()
|
||||
{
|
||||
static_assert(NPerWmma == 16 && MPerWmma == 16,
|
||||
"Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma");
|
||||
|
||||
static_assert(KPack == wmma_instr.k_per_wmma, "KPack should be k_per_wmma");
|
||||
static_assert(KPack % wmma_instr.k_per_wmma == 0, "KPack should be multiple of k_per_wmma");
|
||||
}
|
||||
|
||||
// WMMA output supporting C = A * B
|
||||
@@ -421,9 +418,49 @@ struct WmmaGemm
|
||||
Sequence<5>{}));
|
||||
}
|
||||
|
||||
// Transposed WMMA Output C' = B' * A'
|
||||
template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
|
||||
const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
|
||||
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
|
||||
{
|
||||
const auto MBlockxRepeat =
|
||||
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
|
||||
const auto NBlockxRepeat =
|
||||
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
|
||||
const auto MWave =
|
||||
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
|
||||
const auto NWave =
|
||||
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
|
||||
make_tuple(
|
||||
make_pass_through_transform(MBlockxRepeat),
|
||||
make_pass_through_transform(MWave),
|
||||
make_pass_through_transform(Number<wmma_instr.num_thread_per_subgroups>{}),
|
||||
make_pass_through_transform(NBlockxRepeat),
|
||||
make_pass_through_transform(NWave),
|
||||
make_unmerge_transform(make_tuple(Number<wmma_instr.num_subgroups>{},
|
||||
Number<wmma_instr.num_acc_vgprs_per_wave>{}))),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5, 6>{}));
|
||||
}
|
||||
|
||||
__device__ static constexpr index_t GetRegSizePerWmma()
|
||||
{
|
||||
return wmma_instr.num_acc_vgprs_per_wave;
|
||||
return wmma_instr.num_acc_vgprs_per_wave * wmma_instr.acc_pack_number;
|
||||
}
|
||||
|
||||
__device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; }
|
||||
@@ -449,14 +486,16 @@ struct WmmaGemm
|
||||
,
|
||||
"base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
|
||||
"(int8, int32) or (int4, int32)!");
|
||||
if constexpr(!TransposeC)
|
||||
{
|
||||
wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave, p_b_wave, p_c_thread);
|
||||
}
|
||||
else
|
||||
{
|
||||
wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave, p_a_wave, p_c_thread);
|
||||
}
|
||||
static_for<0, KPack / wmma_instr.k_per_wmma, 1>{}([&](auto k) {
|
||||
if constexpr(!TransposeC)
|
||||
{
|
||||
wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave[k], p_b_wave[k], p_c_thread);
|
||||
}
|
||||
else
|
||||
{
|
||||
wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave[k], p_a_wave[k], p_c_thread);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
__device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; }
|
||||
@@ -477,12 +516,12 @@ struct WmmaGemm
|
||||
|
||||
__host__ __device__ static auto CalculateAThreadOriginDataIndex()
|
||||
{
|
||||
return GetSwizzledLaneIdLow();
|
||||
return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow();
|
||||
}
|
||||
|
||||
__host__ __device__ static auto CalculateBThreadOriginDataIndex()
|
||||
{
|
||||
return GetLaneIdUnderSubGroup();
|
||||
return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup();
|
||||
}
|
||||
|
||||
__device__ static CIndex GetBeginOfThreadBlk()
|
||||
@@ -493,6 +532,14 @@ struct WmmaGemm
|
||||
return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
|
||||
}
|
||||
|
||||
__device__ static CIndex3D GetBeginOfThreadBlk3D()
|
||||
{
|
||||
index_t n_offset = GetLaneIdUnderSubGroup();
|
||||
index_t m_offset = GetSubGroupId();
|
||||
|
||||
return TransposeC ? CIndex3D{n_offset, m_offset, I0} : CIndex3D{m_offset, n_offset, I0};
|
||||
}
|
||||
|
||||
static constexpr auto wmma =
|
||||
WmmaSelector<src_type_a, src_type_b, dst_type, MPerWmma, NPerWmma>{};
|
||||
static constexpr auto wmma_instr = wmma.selected_wmma;
|
||||
@@ -500,7 +547,10 @@ struct WmmaGemm
|
||||
__host__ __device__ static constexpr auto
|
||||
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
|
||||
{
|
||||
return make_tuple(I1, I1, Number<wmma_instr.num_acc_vgprs_per_wave>{});
|
||||
return make_tuple(I1,
|
||||
I1,
|
||||
Number<wmma_instr.num_acc_vgprs_per_wave>{},
|
||||
Number<wmma_instr.acc_pack_number>{});
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user