* wmma_op + unit test

* add arch limitation to wmma test

* change arch limitation

* Refactor + add all-type unit tests (int4 compile failed)

* Add f32_16x16x16_bf16 unit test

* tempsave

* tempsave

* tempsave

* runtime bug, cannot find symbol

* workaround for incorrect HIP warpSize return value

* debugging

* tempsave

* Correctness OK, waiting for optimization

* Tidy up + format

* temp save

* temp save, reproduce the v_bfi_b32 issue

* add inline asm for wmma op test

* tidy up

* clean some debug purpose code

* discard some codes

* clang format

* clang format

* compiler issue fixed + increase tile size

* navi3x_multipleD+example

* temp save

* workable

* batchedgemm[OK], groupconv[debug]

* groupconv: Sanity check[OK], Performance[Bad]

* navi3x_groupconv_need_optimization

* create necessary files

* save progress

* Add Inter-Row thread transfer

* save progress

* save debugging progress

* sanity check pass

* fix a host tensor bug and clean up flash-attn code

* format

* cancel unnecessary change

* cancel unnecessary change

* cancel unnecessary change

* temp save, add asm backend flag to amd_wmma

* Mat-A LDS Bypass sanity pass

* temp save

* gemm sanity fix

* Porting new blockwise gemm to flash attention

* Example branch provided to compiler team

* tempsave

* Fix a bug

* batched gemm ported

* conv A-skip lds ported

* Skip B-Lds real gemm

* Skip B Lds Gemm + MulD

* batched gemm, conv, skip b lds

* format

* Attn, skip b lds

* Change GridwiseOp name

* fix a bug caused by a typo

* Skip A_Lds sanity pass, Skip B_Lds scratch occurred

* Bug found, caused by intra-row permute being off

* bug found

* a fix

* disable buffer load due to incorrect 3rd dword

* update fmha config, no scratch generated

* update 3rd dword

* fmha config update

* FMHA, add support to gfx1101/gfx1102

* Merge origin dev (#2)

* [Navi3x] Fix Gridwise_multiple_d operation (#649)

* Add CMake Option "USE_OPT_NAVI3X"

* fix bug

* standardize docs (#655)

* Separate bibtex requirement from rocm-docs-core (#656)

* separate bibtex requirement from rocm-docs-core

* point requirements to source rocm-docs-core repo

* Add CMake Option "USE_OPT_NAVI3X" (#647)

* Add CMake Option "USE_OPT_NAVI3X"

* remove navi3x opt compile option from cmake script

* Conv + quantization + tanh (#645)

* Rename file. Prepare to support another activation

* Add comment for quantization

* Extract out_elementop

* Add tanh example

* Add conv + bias + tanh quantization instance

* Add missing parameter

* Refine cmake

* Add external api and client example

* Extract variable in example

* Fix the comment

---------

Co-authored-by: zjing14 <zhangjing14@gmail.com>

* Add a denorm test fix (#603)

* Add type_convert implementations for bf16

* Add the fix for conv_fwd

* Add the fix for conv_bwd_data

* Add the fix for conv_bwd_weight

* Format

* Format

* Another format

* Add a macro to use workaround on MI200 only

* Format

---------

Co-authored-by: Rosty Geyyer <rosty.geyyer@amd.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>

* simplify karg in device/grid of split-k op (#644)

* simplify karg in device/grid split-k op

* fix mk_kn_mn instances

* add more instances

* use name from tensor layout

* fix 3rd dword of buffer source descriptor (#659)

* add fp64 instances (#658)

Co-authored-by: root <root@ctr-ubbsmc15.amd.com>

* Issue #666: Revert "simplify karg in device/grid of split-k op (#644)" (#665)

This reverts commit bb5530af91.

* Groupnorm + swish external api (#668)

* Rename to proper naming

* Add example of groupnorm + swish

* Extract duplicate code in example

* Add groupnorm + swish instances

* Refactor instance generation, split into multiple cpp files

* Add external api and client example

* Refine profiler message

* Use ck math version of exp

* Refine problem size in example

* Add host version of exp

* add a macro to turn on/off denorm fix (off by default) (#673)

* add a macro to turn off denorm fix by default

* expose the macro

---------

Co-authored-by: root <root@ctr-ubbsmc15.amd.com>

* fixed quant example (#672)

Co-authored-by: root <root@ctr-ubbsmc15.amd.com>

* Add dependabot config and pin rocm-docs-core (#663)

* [gtest] suppress unsafe buffer warn (#670)

ref: https://github.com/ROCmSoftwarePlatform/MIOpen/pull/1912

* Add memory index guard in wmma device ops (#667)

* Add more macros to turn on/off denorm fix (#678)

Co-authored-by: Rosty Geyyer <rosty.geyyer@amd.com>

* Fix a typo (#676)

* Add (#677)

* Allow using ROCm release candidate compilers. (#679)

* enable use of rocm5.5 release candidate 4

* upgrade to ROCM5.5 RC5

* try to fix the PUB_KEY error, remove the cmake-data package

* upgrade to latest cmake version

* use private dockerhub repo for rocm5.5 rc5

* add missing bracket

* add vector load check

* solve conflicts

---------

Co-authored-by: Sam Wu <sjwu@ualberta.ca>
Co-authored-by: Sam Wu <sam.wu2@amd.com>
Co-authored-by: rocking5566 <ChunYu.Lai@amd.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com>
Co-authored-by: Rosty Geyyer <rosty.geyyer@amd.com>
Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: root <root@ctr-ubbsmc15.amd.com>
Co-authored-by: Jun Liu <Liu.Jun@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>

* Disable SkipLDS & Align AIT api (#3)

* fix layernorm, reduction Ops (#4)

* [Navi3x] Fix Gridwise_multiple_d operation (#649)

* Add CMake Option "USE_OPT_NAVI3X"

* fix bug

* standardize docs (#655)

* Separate bibtex requirement from rocm-docs-core (#656)

* separate bibtex requirement from rocm-docs-core

* point requirements to source rocm-docs-core repo

* Add CMake Option "USE_OPT_NAVI3X" (#647)

* Add CMake Option "USE_OPT_NAVI3X"

* remove navi3x opt compile option from cmake script

* Conv + quantization + tanh (#645)

* Rename file. Prepare to support another activation

* Add comment for quantization

* Extract out_elementop

* Add tanh example

* Add conv + bias + tanh quantization instance

* Add missing parameter

* Refine cmake

* Add external api and client example

* Extract variable in example

* Fix the comment

---------

Co-authored-by: zjing14 <zhangjing14@gmail.com>

* Add a denorm test fix (#603)

* Add type_convert implementations for bf16

* Add the fix for conv_fwd

* Add the fix for conv_bwd_data

* Add the fix for conv_bwd_weight

* Format

* Format

* Another format

* Add a macro to use workaround on MI200 only

* Format

---------

Co-authored-by: Rosty Geyyer <rosty.geyyer@amd.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>

* simplify karg in device/grid of split-k op (#644)

* simplify karg in device/grid split-k op

* fix mk_kn_mn instances

* add more instances

* use name from tensor layout

* fix 3rd dword of buffer source descriptor (#659)

* add fp64 instances (#658)

Co-authored-by: root <root@ctr-ubbsmc15.amd.com>

* Issue #666: Revert "simplify karg in device/grid of split-k op (#644)" (#665)

This reverts commit bb5530af91.

* Groupnorm + swish external api (#668)

* Rename to proper naming

* Add example of groupnorm + swish

* Extract duplicate code in example

* Add groupnorm + swish instances

* Refactor instance generation, split into multiple cpp files

* Add external api and client example

* Refine profiler message

* Use ck math version of exp

* Refine problem size in example

* Add host version of exp

* add a macro to turn on/off denorm fix (off by default) (#673)

* add a macro to turn off denorm fix by default

* expose the macro

---------

Co-authored-by: root <root@ctr-ubbsmc15.amd.com>

* fixed quant example (#672)

Co-authored-by: root <root@ctr-ubbsmc15.amd.com>

* Add dependabot config and pin rocm-docs-core (#663)

* [gtest] suppress unsafe buffer warn (#670)

ref: https://github.com/ROCmSoftwarePlatform/MIOpen/pull/1912

* Add memory index guard in wmma device ops (#667)

* Add more macros to turn on/off denorm fix (#678)

Co-authored-by: Rosty Geyyer <rosty.geyyer@amd.com>

* Fix a typo (#676)

* Add (#677)

* Allow using ROCm release candidate compilers. (#679)

* enable use of rocm5.5 release candidate 4

* upgrade to ROCM5.5 RC5

* try to fix the PUB_KEY error, remove the cmake-data package

* upgrade to latest cmake version

* use private dockerhub repo for rocm5.5 rc5

* add missing bracket

* Disable SkipLDS & Align AIT api

* Update dependabot config (#682)

Co-authored-by: samjwu <samjwu@users.noreply.github.com>

* update attn api

* solve type_convert bug + enable

---------

Co-authored-by: Sam Wu <sjwu@ualberta.ca>
Co-authored-by: Sam Wu <sam.wu2@amd.com>
Co-authored-by: rocking5566 <ChunYu.Lai@amd.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com>
Co-authored-by: Rosty Geyyer <rosty.geyyer@amd.com>
Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: root <root@ctr-ubbsmc15.amd.com>
Co-authored-by: Jun Liu <Liu.Jun@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: samjwu <samjwu@users.noreply.github.com>
Co-authored-by: haocwang <Haocong.WANG@amd.com>

* fix typo

* Fix attention with causal mask

* multiple fixes, try AIT compile

* Add A/B not use LDS pipeline

* Clang format, Add gfx1101, gfx1102 support of FMHA example

* cancel change of format script

* 1. Enable 2-stage global prefetch (may cause VGPR spilling)
2. Enable FP16 accumulator blockwise_gemm

* clang-format

* 1. change blockwise gemm loop-over direction from kmn to mnk (~1% improvement)
2. change kernel timing mode to 50 warmup + 50 timed repeats

* Update low-level abstraction of blockwise gemm wmma

* (2/5) bilinear gemm pass, perf bug: skipping A-LDS has lower performance than skipping B-LDS

* (3/5) batched gemm pass, perf bug: skipping A-LDS has lower performance than skipping B-LDS

* (4/5) grouped conv pass

* (5/5) attention pass, todo: debug lds perf bug

* AIT Attention API refactor (#8)

* sanity pass

* sanity pass 2

* confirm significant performance regression.

* turn on all instances

* turn off instance format

* Fix bug & tuning & format

* DML meta, self_attn+cross_attn

* sanity pass

* remove useless flag

* update tile and problem size used in AIT attention

* bug fix in grouped conv supporting check

* deprecate inline asm wmma

* Bug fix: double lds skip

* clang-format

* Fix errors in
1. example, fmha
2. gridwise pipeline
3. deviceop, fmha, change some containers from vector to array

* part 2 of previous commit

* clang format

* API fix of gridwisegemmpipeline

* separate array-based and vector-based attention tensor transformations

* fix gemm

* clang format

* add gemm fp16 instances

* Temp save

* fpAintB kernel compile pass

* Sanity pass.

* Temp save

* debug code enabled

* Fp16AInt8B_GEMM sanity

* MQA implementation

* GQA-4 example

* tempsave

* Compile pass

* New implementation of fp16Aint8B Gemm, achieves math throughput similar to native fp16 Gemm

* format

* Todo: fix gemm_bilinear_wmma instances compilation bug

* Solve a bug when K1=16

* remove unnecessary changes

* Remove tensor layout limitation to LDS usage in tensor contraction

* update self-attention and cross-attention

* fix a typo of name

* Add arch limiter for fp8 gemm

* enable fp8 gemm_xdl for all gfx9 targets

* temporarily disable gemm_xdl_fp16_fp8 on MI100/200

* fix the cmake logic for gemm_xdl_fp16_fp8

* re-enable the gemm_xdl_fp16_fp8 on MI100/200

---------

Co-authored-by: aska-0096 <haocwang@amd.com>
Co-authored-by: Sam Wu <sjwu@ualberta.ca>
Co-authored-by: Sam Wu <sam.wu2@amd.com>
Co-authored-by: rocking5566 <ChunYu.Lai@amd.com>
Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com>
Co-authored-by: Rosty Geyyer <rosty.geyyer@amd.com>
Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: root <root@ctr-ubbsmc15.amd.com>
Co-authored-by: Jun Liu <Liu.Jun@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: samjwu <samjwu@users.noreply.github.com>
Co-authored-by: haocwang <Haocong.WANG@amd.com>
Co-authored-by: illsilin <Illia.Silin@amd.com>
Commit 1837040a9c (parent 363feb482d), authored by zjing14 on 2024-03-08 19:11:51 -06:00, committed by GitHub.
73 changed files with 17542 additions and 2020 deletions.


@@ -89,6 +89,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
     static constexpr index_t src_a_data_size = 2;
     static constexpr index_t src_b_data_size = 2;
     static constexpr index_t acc_data_size = 4;
+    static constexpr index_t acc_pack_number = 1;
     // * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
     static constexpr index_t num_thread_per_subgroups = n_per_wmma;
@@ -100,7 +101,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
     // * num_acc_vgprs_per_wave alone M direction
     // * num_subgroups alone M direction
     static constexpr index_t num_acc_vgprs_per_wave =
-        m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
+        m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
     static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

     template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
@@ -129,6 +130,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
     static constexpr index_t src_a_data_size = 2;
     static constexpr index_t src_b_data_size = 2;
     static constexpr index_t acc_data_size = 4;
+    static constexpr index_t acc_pack_number = 1;
     static constexpr index_t num_thread_per_subgroups = n_per_wmma;

     // Wave mode dependent propety
@@ -136,7 +138,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
     static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
     static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
     static constexpr index_t num_acc_vgprs_per_wave =
-        m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
+        m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
     static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

     template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
@@ -153,7 +155,6 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
     }
 };

-#ifdef CK_UNPACKED_ACC_DESC_LOGIC
 template <index_t WaveSize>
 struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
                  WaveSize,
@@ -166,6 +167,7 @@ struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
     static constexpr index_t src_a_data_size = 2;
     static constexpr index_t src_b_data_size = 2;
     static constexpr index_t acc_data_size = 2;
+    static constexpr index_t acc_pack_number = 2;
     static constexpr index_t num_thread_per_subgroups = n_per_wmma;

     // Wave mode dependent propety
@@ -173,28 +175,22 @@ struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
     static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
     static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
     static constexpr index_t num_acc_vgprs_per_wave =
-        m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
+        m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
     static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

-    template <index_t MPerWmma,
-              index_t NPerWmma,
-              index_t Opsel,
-              class FloatA,
-              class FloatB,
-              class FloatC>
+    template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
         if constexpr(wave_size == 32)
         {
-            intrin_wmma_f16_16x16x16_f16_w32<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
+            intrin_wmma_f16_16x16x16_f16_w32<MPerWmma, NPerWmma, false>::Run(a, b, reg_c);
         }
         else if constexpr(wave_size == 64)
         {
-            intrin_wmma_f16_16x16x16_f16_w64<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
+            intrin_wmma_f16_16x16x16_f16_w64<MPerWmma, NPerWmma, false>::Run(a, b, reg_c);
         }
     }
 };

 template <index_t WaveSize>
 struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
                  WaveSize,
@@ -207,6 +203,7 @@ struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
     static constexpr index_t src_a_data_size = 2;
     static constexpr index_t src_b_data_size = 2;
     static constexpr index_t acc_data_size = 2;
+    static constexpr index_t acc_pack_number = 2;
     static constexpr index_t num_thread_per_subgroups = n_per_wmma;

     // Wave mode dependent propety
@@ -214,7 +211,7 @@ struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
     static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
     static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
     static constexpr index_t num_acc_vgprs_per_wave =
-        m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
+        m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
     static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

     template <index_t MPerWmma,
@@ -227,17 +224,15 @@ struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
     {
         if constexpr(wave_size == 32)
         {
-            intrin_wmma_bf16_16x16x16_bf16_w32<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
+            intrin_wmma_bf16_16x16x16_bf16_w32<MPerWmma, NPerWmma, false>::Run(a, b, reg_c);
         }
         else if constexpr(wave_size == 64)
         {
-            intrin_wmma_bf16_16x16x16_bf16_w64<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
+            intrin_wmma_bf16_16x16x16_bf16_w64<MPerWmma, NPerWmma, false>::Run(a, b, reg_c);
         }
     }
 };
-#endif

 template <index_t WaveSize>
 struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
                  WaveSize,
@@ -250,6 +245,7 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
     static constexpr index_t src_a_data_size = 2;
     static constexpr index_t src_b_data_size = 2;
     static constexpr index_t acc_data_size = 4;
+    static constexpr index_t acc_pack_number = 1;
     static constexpr index_t num_thread_per_subgroups = n_per_wmma;

     // Wave mode dependent propety
@@ -257,7 +253,7 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
     static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
     static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
     static constexpr index_t num_acc_vgprs_per_wave =
-        m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
+        m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
     static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

     template <index_t MPerWmma,
@@ -346,7 +342,7 @@ struct WmmaSelector
         static_assert(selected_wmma.k_per_wmma == 16, "WRONG! WMMA_M must equal to 16");

         static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave *
-                              selected_wmma.acc_data_size ==
+                              selected_wmma.acc_data_size * selected_wmma.acc_pack_number ==
                           selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
                       "WRONG! Invalid Number of Accumulator Register");
     }
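The acc_pack_number factor introduced above distinguishes packed 16-bit accumulators (f16/bf16, two values per 32-bit VGPR) from unpacked 32-bit ones (f32/i32). A quick standalone sanity check of the updated register accounting, as a sketch (the helper function is illustrative only, not part of the library; the constants mirror the 16x16x16 cases above):

    // Re-derive num_acc_vgprs_per_wave for a 16x16x16 WMMA tile.
    // elem_bytes: accumulator element size in bytes;
    // pack: accumulator elements sharing one 32-bit VGPR.
    constexpr int acc_vgprs_per_wave(int m, int n, int elem_bytes, int pack, int wave_size)
    {
        return m * n * elem_bytes * pack / wave_size / 4;
    }

    static_assert(acc_vgprs_per_wave(16, 16, 4, 1, 32) == 8, "f32/i32 acc, wave32");
    static_assert(acc_vgprs_per_wave(16, 16, 2, 2, 32) == 8, "f16/bf16 acc, wave32, packed");
    static_assert(acc_vgprs_per_wave(16, 16, 4, 1, 64) == 4, "f32/i32 acc, wave64");

    // The WmmaSelector assert above reduces to the same identity, e.g. wave32 f16:
    static_assert(32 * 8 * 2 * 2 == 16 * 16 * 4, "accumulator accounting");

    int main() { return 0; } // all checks are compile-time

Without the pack factor, the f16/bf16 accumulator cases would under-count their VGPRs by half, which is exactly what the amended static_assert in WmmaSelector now guards against.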
@@ -358,7 +354,8 @@ template <typename src_type_a,
           index_t MPerWmma,
           index_t NPerWmma,
           index_t KPack,
-          bool TransposeC = false>
+          bool TransposeC = false,
+          bool AssemblyBackend = false>
 struct WmmaGemm
 {
     static constexpr auto I0 = Number<0>{};
@@ -369,14 +366,14 @@ struct WmmaGemm
     static constexpr auto I5 = Number<5>{};

     using CIndex   = MultiIndex<2>;
-    using CIndex4D = MultiIndex<4>;
+    using CIndex3D = MultiIndex<3>;

     __host__ __device__ constexpr WmmaGemm()
     {
         static_assert(NPerWmma == 16 && MPerWmma == 16,
                       "Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma");

-        static_assert(KPack == wmma_instr.k_per_wmma, "KPack should be k_per_wmma");
+        static_assert(KPack % wmma_instr.k_per_wmma == 0, "KPack should be multiple of k_per_wmma");
     }

     // WMMA output supporting C = A * B
@@ -421,9 +418,49 @@
                        Sequence<5>{}));
     }

+    // Transposed WMMA Output C' = B' * A'
+    template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
+    __host__ __device__ static constexpr auto
+    MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
+        const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
+            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
+    {
+        const auto MBlockxRepeat =
+            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
+        const auto NBlockxRepeat =
+            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
+        const auto MWave =
+            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
+        const auto NWave =
+            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);
+
+        return transform_tensor_descriptor(
+            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
+            make_tuple(
+                make_pass_through_transform(MBlockxRepeat),
+                make_pass_through_transform(MWave),
+                make_pass_through_transform(Number<wmma_instr.num_thread_per_subgroups>{}),
+                make_pass_through_transform(NBlockxRepeat),
+                make_pass_through_transform(NWave),
+                make_unmerge_transform(make_tuple(Number<wmma_instr.num_subgroups>{},
+                                                  Number<wmma_instr.num_acc_vgprs_per_wave>{}))),
+            make_tuple(Sequence<0>{},
+                       Sequence<1>{},
+                       Sequence<2>{},
+                       Sequence<3>{},
+                       Sequence<4>{},
+                       Sequence<5>{}),
+            make_tuple(Sequence<0>{},
+                       Sequence<1>{},
+                       Sequence<2>{},
+                       Sequence<3>{},
+                       Sequence<4>{},
+                       Sequence<5, 6>{}));
+    }
+
     __device__ static constexpr index_t GetRegSizePerWmma()
     {
-        return wmma_instr.num_acc_vgprs_per_wave;
+        return wmma_instr.num_acc_vgprs_per_wave * wmma_instr.acc_pack_number;
     }

     __device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; }
@@ -449,14 +486,16 @@
             ,
             "base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
             "(int8, int32) or (int4, int32)!");

-        if constexpr(!TransposeC)
-        {
-            wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave, p_b_wave, p_c_thread);
-        }
-        else
-        {
-            wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave, p_a_wave, p_c_thread);
-        }
+        static_for<0, KPack / wmma_instr.k_per_wmma, 1>{}([&](auto k) {
+            if constexpr(!TransposeC)
+            {
+                wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave[k], p_b_wave[k], p_c_thread);
+            }
+            else
+            {
+                wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave[k], p_a_wave[k], p_c_thread);
+            }
+        });
     }

     __device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; }
@@ -477,12 +516,12 @@
     __host__ __device__ static auto CalculateAThreadOriginDataIndex()
     {
-        return GetSwizzledLaneIdLow();
+        return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow();
     }

     __host__ __device__ static auto CalculateBThreadOriginDataIndex()
     {
-        return GetLaneIdUnderSubGroup();
+        return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup();
     }

     __device__ static CIndex GetBeginOfThreadBlk()
@@ -493,6 +532,14 @@
         return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
     }

+    __device__ static CIndex3D GetBeginOfThreadBlk3D()
+    {
+        index_t n_offset = GetLaneIdUnderSubGroup();
+        index_t m_offset = GetSubGroupId();
+
+        return TransposeC ? CIndex3D{n_offset, m_offset, I0} : CIndex3D{m_offset, n_offset, I0};
+    }
+
     static constexpr auto wmma =
         WmmaSelector<src_type_a, src_type_b, dst_type, MPerWmma, NPerWmma>{};
     static constexpr auto wmma_instr = wmma.selected_wmma;
@@ -500,7 +547,10 @@
     __host__ __device__ static constexpr auto
     GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
     {
-        return make_tuple(I1, I1, Number<wmma_instr.num_acc_vgprs_per_wave>{});
+        return make_tuple(I1,
+                          I1,
+                          Number<wmma_instr.num_acc_vgprs_per_wave>{},
+                          Number<wmma_instr.acc_pack_number>{});
     }
 };
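Two of the WmmaGemm changes above work together: the constructor's static_assert is relaxed from KPack == k_per_wmma to KPack % k_per_wmma == 0, and Run now wraps the instruction issue in a static_for over KPack / k_per_wmma, indexing per-k slices of the A/B register buffers. A minimal sketch of that unrolled dispatch (issue_wmma and the constants below are stand-ins for the library's wmma_instr machinery, chosen for illustration only):

    #include <cstdio>

    constexpr int KPack      = 32; // hypothetical tile config: two k-slices per Run call
    constexpr int k_per_wmma = 16; // fixed by the 16x16x16 WMMA instructions

    void issue_wmma(int k) // placeholder for wmma_instr.template run<...>
    {
        std::printf("WMMA issue on A[%d], B[%d]\n", k, k);
    }

    int main()
    {
        // Mirrors the new static_for in WmmaGemm::Run: KPack / k_per_wmma
        // back-to-back WMMA ops, each on its own k-slice (p_a_wave[k],
        // p_b_wave[k] in the diff); ck::static_for unrolls this at compile time.
        for(int k = 0; k < KPack / k_per_wmma; ++k)
            issue_wmma(k);
        return 0;
    }

With KPack == k_per_wmma the loop degenerates to the old single issue, so existing configurations behave as before; larger KPack values simply amortize more of the K dimension per blockwise-gemm invocation.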