From 2664df5e3e5d9cb6af09d88c604f72234516f3a6 Mon Sep 17 00:00:00 2001 From: zjing14 Date: Fri, 8 Mar 2024 19:11:51 -0600 Subject: [PATCH] Navi3 rel (#1176) * wmma_op + unit test * add arch limitation to wmma test * change arch limitation * Refactor + Add all type unit test(int4 compile failed) * Add f32_16x16x16_bf16 unit test * tempsave * tempsave * tempsave * runtime bug, cannot find symbol * workaround for incorrect HIP warpSize return value * debugging * tempsave * Correctness OK, waiting for optimization * Tidy up + format * temp save * temp save, reproduce the v_bfi_b32 issue * add inline asm for wmmaop test * tidy up * clean some debug purpose code * discard some codes * clang format * clang format * compiler issue fixed + increase tile size * navi3x_multipleD+example * temp save * workable * batchedgemm[OK], groupconv[debug] * groupconv: Sanity check[OK], Performance[Bad] * navi3x_groupconv_need_optimization * create necessary files * save progress * Add Inter-Row thread transfer * save progress * save debugging progress * sanity check pass * fix a host tensor bug and clean up flash-attn code * format * cancel unnecessary change * cancel unnecessary change * cancel unnecessary change * temp save, add asm backend flag to amd_wmma * Mat-A LDS Bypass sanity pass * temp save * gemm sanity fix * Porting new blockwise gemm to flash attention * Example branch provide to compiler team * tempsave * Fix a bug * batched gemm ported * conv A-skip lds ported * Skip B-Lds real gemm * Skip B Lds Gemm + MulD * batched gemm, conv, skip b lds * format * Attn, skip b lds * Change GridwiseOp nam * fix a typo caused bug * Skip A_Lds sanity pass, Skip B_Lds scratch occured * Bug found, intra-row permute off caused * bug found * a fix * disable buffer load due to incorrect 3rd dword * update fmha config, no scratch generated * update 3rd dword * fmha config update * FMHA, add support to gfx1101/gfx1102 * Merge origin dev (#2) * [Navi3x] Fix Gridwise_multiple_d operation (#649) * Add CMake Option "USE_OPT_NAVI3X" * fix bug * standardize docs (#655) * Separate bibtex requirement from rocm-docs-core (#656) * separate bibtex requirement from rocm-docs-core * point requirements to source rocm-docs-core repo * Add CMake Option "USE_OPT_NAVI3X" (#647) * Add CMake Option "USE_OPT_NAVI3X" * remove navi3x opt compile option from cmake script * Conv + quantization + tanh (#645) * Rename file. Prepare to support another activation * Add comment for quantization * Extract out_elementop * Add tanh example * Add conv + bias + tanh quantization instance * Add missing parameter * Refine cmake * Add external api and client example * Extract variable in example * Fix the comment --------- Co-authored-by: zjing14 * Add a denorm test fix (#603) * Add type_convert implementations for bf16 * Add the fix for conv_fwd * Add the fix for conv_bwd_data * Add the fix for conv_bwd_weight * Format * Format * Another format * Add a macro to use workaround on MI200 only * Format --------- Co-authored-by: Rosty Geyyer Co-authored-by: zjing14 * simplify karg in device/grid of split-k op (#644) * simplify karg in device/grid split-k op * fix mk_kn_mn instances * add more instances * use name from tensor layout * fix 3rd dword of buffer source descriptor (#659) * add fp64 instances (#658) Co-authored-by: root * Issue #666: Revert "simplify karg in device/grid of split-k op (#644)" (#665) This reverts commit 469cce884ed93ab0e59e793df5b3c00d7657bf7a. 
* Groupnorm + swish external api (#668) * Rename to proper naming * Add example of groupnorm + swish * Extract duplicate code in example * Add groupnorm + swish instances * Ractor instance generation, split into multiple cpp file * Add external api and client example * Refine profiler message * Use ck math version of exp * Refine problem size in example * Add host version of exp * add a marco to turn on/off denorm fix (off by default) (#673) * add a marco to turn off denorm fix by default * expose the marco --------- Co-authored-by: root * fixed quant example (#672) Co-authored-by: root * Add dependabot config and pin rocm-docs-core (#663) * [gtest] suppress unsafe buffer warn (#670) ref: https://github.com/ROCmSoftwarePlatform/MIOpen/pull/1912 * Add memory index guard in wmma device ops (#667) * Add more macros to turn on/off denorm fix (#678) Co-authored-by: Rosty Geyyer * Fix a typo (#676) * Add (#677) * Allow using ROCm release candidate compilers. (#679) * enable use of rocm5.5 release candidate 4 * upgrade to ROCM5.5 RC5 * try fix the PUB_KEY error, remove the cmake-data package * upgrade to latest cmake version * use private dockerhub repo for rocm5.5 rc5 * add missing bracket * add vector load check * solve conflicts --------- Co-authored-by: Sam Wu Co-authored-by: Sam Wu Co-authored-by: rocking5566 Co-authored-by: zjing14 Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Co-authored-by: Rosty Geyyer Co-authored-by: carlushuang Co-authored-by: root Co-authored-by: Jun Liu Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> * Disable SkipLDS & Align AIT api (#3) * fix layernorm, reduction Ops (#4) * [Navi3x] Fix Gridwise_multiple_d operation (#649) * Add CMake Option "USE_OPT_NAVI3X" * fix bug * standardize docs (#655) * Separate bibtex requirement from rocm-docs-core (#656) * separate bibtex requirement from rocm-docs-core * point requirements to source rocm-docs-core repo * Add CMake Option "USE_OPT_NAVI3X" (#647) * Add CMake Option "USE_OPT_NAVI3X" * remove navi3x opt compile option from cmake script * Conv + quantization + tanh (#645) * Rename file. Prepare to support another activation * Add comment for quantization * Extract out_elementop * Add tanh example * Add conv + bias + tanh quantization instance * Add missing parameter * Refine cmake * Add external api and client example * Extract variable in example * Fix the comment --------- Co-authored-by: zjing14 * Add a denorm test fix (#603) * Add type_convert implementations for bf16 * Add the fix for conv_fwd * Add the fix for conv_bwd_data * Add the fix for conv_bwd_weight * Format * Format * Another format * Add a macro to use workaround on MI200 only * Format --------- Co-authored-by: Rosty Geyyer Co-authored-by: zjing14 * simplify karg in device/grid of split-k op (#644) * simplify karg in device/grid split-k op * fix mk_kn_mn instances * add more instances * use name from tensor layout * fix 3rd dword of buffer source descriptor (#659) * add fp64 instances (#658) Co-authored-by: root * Issue #666: Revert "simplify karg in device/grid of split-k op (#644)" (#665) This reverts commit 469cce884ed93ab0e59e793df5b3c00d7657bf7a. 
* Groupnorm + swish external api (#668) * Rename to proper naming * Add example of groupnorm + swish * Extract duplicate code in example * Add groupnorm + swish instances * Ractor instance generation, split into multiple cpp file * Add external api and client example * Refine profiler message * Use ck math version of exp * Refine problem size in example * Add host version of exp * add a marco to turn on/off denorm fix (off by default) (#673) * add a marco to turn off denorm fix by default * expose the marco --------- Co-authored-by: root * fixed quant example (#672) Co-authored-by: root * Add dependabot config and pin rocm-docs-core (#663) * [gtest] suppress unsafe buffer warn (#670) ref: https://github.com/ROCmSoftwarePlatform/MIOpen/pull/1912 * Add memory index guard in wmma device ops (#667) * Add more macros to turn on/off denorm fix (#678) Co-authored-by: Rosty Geyyer * Fix a typo (#676) * Add (#677) * Allow using ROCm release candidate compilers. (#679) * enable use of rocm5.5 release candidate 4 * upgrade to ROCM5.5 RC5 * try fix the PUB_KEY error, remove the cmake-data package * upgrade to latest cmake version * use private dockerhub repo for rocm5.5 rc5 * add missing bracket * Disable SkipLDS & Align AIT api * Update dependabot config (#682) Co-authored-by: samjwu * update attn api * solve type_convert bug + enable --------- Co-authored-by: Sam Wu Co-authored-by: Sam Wu Co-authored-by: rocking5566 Co-authored-by: zjing14 Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Co-authored-by: Rosty Geyyer Co-authored-by: carlushuang Co-authored-by: root Co-authored-by: Jun Liu Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: samjwu Co-authored-by: haocwang * fix typo * Fix attention with causal mask * multiple fix, try ait compile * Add A/B not use LDS pipeline * Clang format, Add gfx1101, gfx1102 support of FMHA example * cancel change of format script * 1. Enable 2-stage global Prefetch ( May cause VGPR spilling) 2. Enable FP16 accumulator blockwise_gemm * clang-format * 1. change blockwise gemm loopover direction from kmn to mnk ( ~1% improvement) 2. change kernel timing mode to 50 warmup + 50 timed repeat * Update low level abstration of blockwise gemm wmma * (2/5) bilinear gemm pass, perf bug: skip a lds has lower performance than skip b lds * (3/5) batched gemm pass, perf bug: skip a lds has lower performance than skip b lds * (4/5) grouped conv pass * (5/5) attention pass, todo: debug lds perf bug * AIT Attention API refactor (#8) * sanity pass * sanity pass 2 * confirm significant performance regression. * turn on all instances * turn off instance format * Fix bug & tunning & format * DML meta, self_attn+cross_attn * sanity pass * remove useless flag * update tile and problem size used in AIT attention * bug fix in grouped conv supporting check * deprecate inline asm wmma * Bug fix: double lds skip * clang-format * Fix errors in 1. example, fmha 2. gridwise pipeline 3. deviceop, fmha, change some containers from vector to array * part2 of previous commit * clang format * API fix of gridwisegemmpipeline * separate array base and vector base attention tensor transformation * fix gemm * clang format * add gemm fp16 instances * Temp save * fpAintB kernel compile pass * Sanity pass. 
* Temp save * debug code enabled * Fp16AInt8B_GEMM sanity * MQA implementation * GQA-4 example * tempsave * Compile pass * New implementation of fp16Aint8B Gemm, Acheieve similar math throughput with native fp16 Gemm * format * Todo: fix gemm_bilinear_wmma instances compilation bug * Solve a bug when K1=16 * remove unnecessary changes * Remove tensor layout limitation to LDS usage in tesnor contraction * update self-attention and cross-attention * fix a typo of name * Add arch limiter for fp8 gemm * enable fp8 gemm_xdl for all gfx9 targets * temporarily disable gemm_xdl_fp16_fp8 on MI100/200 * fix the cmake logic for gemm_xdl_fp16_fp8 * re-enable the gemm_xdl_fp16_fp8 on MI100/200 --------- Co-authored-by: aska-0096 Co-authored-by: Sam Wu Co-authored-by: Sam Wu Co-authored-by: rocking5566 Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Co-authored-by: Rosty Geyyer Co-authored-by: carlushuang Co-authored-by: root Co-authored-by: Jun Liu Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: samjwu Co-authored-by: haocwang Co-authored-by: illsilin [ROCm/composable_kernel commit: 1837040a9c3253ba51506b7b7d236836b6efaa8d] --- example/01_gemm/CMakeLists.txt | 15 +- example/01_gemm/gemm_wmma_fp16.cpp | 47 +- example/01_gemm/run_gemm_example.inc | 16 + .../gemm_bilinear_wmma_fp16.cpp | 87 +- .../gemm_bilinear_wmma_int8.cpp | 87 +- .../CMakeLists.txt | 2 +- .../batched_gemm_bias_e_permute_wmma_fp16.cpp | 87 +- ...ed_conv_fwd_bias_relu_add_wmma_example.inc | 29 +- .../CMakeLists.txt | 10 +- ...e_scale_softmax_gemm_permute_wmma_fp16.cpp | 166 ++ ...m_scale_softmax_gemm_permute_wmma_fp16.cpp | 288 +++ .../cross_attention_forward_wmma_fp16.cpp | 354 ++++ ...uped_query_attention_forward_wmma_fp16.cpp | 302 +++ ...ulti_query_attention_forward_wmma_fp16.cpp | 287 +++ ...d_gemm_scale_softmax_gemm_permute_wmma.inc | 340 ++++ .../run_cross_attention_wmma.inc | 384 ++++ ...n_grouped_query_attention_forward_wmma.inc | 340 ++++ ...run_multi_query_attention_forward_wmma.inc | 339 ++++ .../run_self_attention_wmma.inc | 376 ++++ .../self_attention_forward_wmma_fp16.cpp | 332 ++++ example/64_fpAintB_gemm/CMakeLists.txt | 5 + example/64_fpAintB_gemm/common.hpp | 123 ++ .../64_fpAintB_gemm/fp16int8_gemm_wmma.cpp | 93 + example/64_fpAintB_gemm/run_gemm_example.inc | 172 ++ .../gpu/block/blockwise_gemm_wmma.hpp | 971 ++++----- ...oup_tensor_slice_transfer_v4r1_dequant.hpp | 223 +++ .../gpu/device/device_gemm_dequantB.hpp | 46 + ...d_contraction_multiple_d_wmma_cshuffle.hpp | 321 +-- ...emm_softmax_gemm_permute_wmma_cshuffle.hpp | 1729 +++++++++++++++++ .../device/impl/device_fpAintB_gemm_wmma.hpp | 714 +++++++ .../device_gemm_multiple_d_wmma_cshuffle.hpp | 359 ++-- .../gpu/device/impl/device_gemm_wmma.hpp | 417 ++-- ...conv_bwd_data_multiple_d_wmma_cshuffle.hpp | 6 +- ..._grouped_conv_bwd_weight_wmma_cshuffle.hpp | 10 +- ...uped_conv_fwd_multiple_d_wmma_cshuffle.hpp | 270 ++- ...e_grouped_query_attention_forward_wmma.hpp | 1254 ++++++++++++ ...ice_multi_query_attention_forward_wmma.hpp | 1244 ++++++++++++ .../gpu/device/masking_specialization.hpp | 5 +- .../element/unary_element_wise_operation.hpp | 76 + ...iple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp | 2 +- ...atched_gemm_softmax_gemm_wmma_cshuffle.hpp | 1596 +++++++++++++++ .../gpu/grid/gridwise_fpAintB_gemm_wmma.hpp | 1046 ++++++++++ ...gridwise_gemm_multiple_d_wmma_cshuffle.hpp | 782 ++++++-- .../grid/gridwise_gemm_pipeline_selector.hpp | 11 +- .../gpu/grid/gridwise_gemm_pipeline_v1.hpp | 410 +++- 
...e_gemm_split_k_multiple_d_xdl_cshuffle.hpp | 2 +- .../gpu/grid/gridwise_gemm_wmma.hpp | 734 +++++-- .../threadwise_tensor_slice_transfer.hpp | 135 ++ ...ise_tensor_slice_transfer_v3r1_dequant.hpp | 1066 ++++++++++ .../tensor_operation/gpu/warp/wmma_gemm.hpp | 120 +- ...ransform_contraction_to_gemm_arraybase.hpp | 391 ++++ include/ck/utility/amd_buffer_addressing.hpp | 3 +- include/ck/utility/amd_inline_asm.hpp | 24 +- include/ck/utility/data_type.hpp | 15 + include/ck/utility/type_convert.hpp | 57 + .../cpu/reference_batched_gemm.hpp | 246 +++ .../cpu/reference_fpAintB_gemm.hpp | 177 ++ .../tensor_operation_instance/gpu/gemm.hpp | 24 + .../device_grouped_conv_fwd_wmma_instance.hpp | 100 +- .../gpu/gemm/CMakeLists.txt | 6 + ...emm_wmma_f16_f16_f16_km_kn_mn_instance.cpp | 78 + ...emm_wmma_f16_f16_f16_km_nk_mn_instance.cpp | 78 + ...emm_wmma_f16_f16_f16_mk_kn_mn_instance.cpp | 158 ++ ...emm_wmma_f16_f16_f16_mk_nk_mn_instance.cpp | 78 + ...uffle_i8_i8_i8_i8_km_kn_mn_mn_instance.cpp | 48 +- ...uffle_i8_i8_i8_i8_km_nk_mn_mn_instance.cpp | 48 +- ...uffle_i8_i8_i8_i8_mk_kn_mn_mn_instance.cpp | 48 +- ...uffle_i8_i8_i8_i8_mk_nk_mn_mn_instance.cpp | 80 +- .../grouped_conv2d_bwd_data/CMakeLists.txt | 32 +- .../gpu/grouped_conv2d_fwd/CMakeLists.txt | 34 +- .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 3 +- test/grouped_convnd_bwd_data/CMakeLists.txt | 2 +- test/grouped_convnd_bwd_weight/CMakeLists.txt | 2 +- 73 files changed, 17542 insertions(+), 2020 deletions(-) create mode 100644 example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp create mode 100644 example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp create mode 100644 example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp create mode 100644 example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp create mode 100644 example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp create mode 100644 example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc create mode 100644 example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc create mode 100644 example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc create mode 100644 example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc create mode 100644 example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc create mode 100644 example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp create mode 100644 example/64_fpAintB_gemm/CMakeLists.txt create mode 100644 example/64_fpAintB_gemm/common.hpp create mode 100644 example/64_fpAintB_gemm/fp16int8_gemm_wmma.cpp create mode 100644 example/64_fpAintB_gemm/run_gemm_example.inc create mode 100644 include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp create mode 100644 include/ck/tensor_operation/gpu/device/device_gemm_dequantB.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp create mode 100644 
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp create mode 100644 include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp create mode 100644 include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp create mode 100644 library/include/ck/library/reference_tensor_operation/cpu/reference_fpAintB_gemm.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_km_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_km_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_mk_nk_mn_instance.cpp diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 5b71cd1548..2fa8e77462 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -27,7 +27,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16) add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) -if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") +if(GPU_TARGETS MATCHES "gfx11") add_custom_target(example_gemm_wmma) add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16) @@ -53,12 +53,6 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64) add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp) -add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp) -add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8) - -add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp) -add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8) - list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) @@ -72,5 +66,12 @@ foreach(gpu IN LISTS GPU_TARGETS) endif() endforeach() +add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8) + +add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8) + add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8) + diff --git a/example/01_gemm/gemm_wmma_fp16.cpp b/example/01_gemm/gemm_wmma_fp16.cpp index b11fe76ab2..8c52e4f7d7 100644 --- a/example/01_gemm/gemm_wmma_fp16.cpp +++ b/example/01_gemm/gemm_wmma_fp16.cpp @@ -19,15 +19,50 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; using CElementOp = PassThrough; -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; // clang-format off using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle -// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MRepeat|NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| 
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| -// ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| -// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmMNKPadding, 256, 128, 256, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, 1>; + < ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + CDataType, + AccDataType, + CShuffleDataType, + AElementOp, + BElementOp, + CElementOp, + GemmDefault, + 1, // Prefetch stage + 128, // BlockSize + 64, // MPerBlock + 128, // NPerBlock + 64, // KPerBlock + 8, // K1 + 16, // MPerWmma + 16, // NPerWmma + 2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave + 4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + 1, // C shuffle (M Repeat) Per store + 1, // C shuffle (N Repeat) Per store + S<1, 32, 1, 4>, + 8>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index 2837937ead..b04e4e53a8 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -150,6 +150,22 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k); ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); break; + case 2: + ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistribution{-1.f, 1.f}(b_k_n); + break; + case 3: + ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); + break; + case 4: + ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(b_k_n); + break; + case 5: + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(b_k_n); + break; default: ck::utils::FillUniformDistribution{-0.1f, 0.1f}(a_m_k); ck::utils::FillUniformDistribution{-0.1f, 0.1f}(b_k_n); diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp index 877792d740..d1b820da7b 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp @@ -65,48 +65,49 @@ using CDEElementOp = AlphaBetaAdd; static constexpr auto 
GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; -using DeviceOpInstance = - ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle, - ELayout, - ADataType, - BDataType, - ck::Tuple, - EDataType, - AccDataType, - CShuffleDataType, - AElementOp, - BElementOp, - CDEElementOp, - GemmSpec, - 256, - 128, - 256, - 8, - 8, - 16, - 16, - 4, - 4, - S<4, 64, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, - S<4, 64, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, - 1, - 1, - S<1, 32, 1, 8>, - 8>; +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle< + ALayout, + BLayout, + ck::Tuple, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 2, // Prefetch stage + 128, // BlockSize + 128, // MPerBlock + 64, // NPerBlock + 64, // KPerBlock + 8, // K1 + 16, // MPerWmma + 16, // NPerWmma + 4, // M-Repeat // M-PerWmma / M-Repeat = M-Wave + 2, // N-Repeat // N-PerWmma / N-Repeat = N-Wave + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + 1, // C shuffle (M Repeat) Per store + 1, // C shuffle (N Repeat) Per store + S<1, 32, 1, 4>, + 8>; int main(int argc, char* argv[]) { @@ -264,7 +265,7 @@ int main(int argc, char* argv[]) float gb_per_sec = num_btype / 1.E6 / ave_time; std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::endl; + << device_op.GetTypeString() << std::endl; e_device_buf.FromDevice(e_m_n_device_result.mData.data()); diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp index 9f23ad2652..aca136f801 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp @@ -55,7 +55,7 @@ using DDataType = I8; using EDataType = I8; using ALayout = Row; -using BLayout = Row; +using BLayout = Col; using DLayout = Row; using ELayout = Row; @@ -65,48 +65,49 @@ using CDEElementOp = AlphaBetaAdd; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; -using DeviceOpInstance = - ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle, - ELayout, - ADataType, - BDataType, - ck::Tuple, - EDataType, - AccDataType, - CShuffleDataType, - AElementOp, - BElementOp, - CDEElementOp, - GemmSpec, - 32, - 16, - 16, - 4, - 16, - 16, - 16, - 1, - 1, - S<2, 16, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 16, - 16, - 1, - S<4, 1, 8>, - S<0, 2, 1>, - S<0, 2, 1>, - 1, - 16, - 2, - 1, - 1, - 1, - S<1, 16, 1, 2>, - 8>; +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle< + ALayout, + BLayout, + ck::Tuple, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 2, // Prefetch stage + 128, // BlockSize + 128, // MPerBlock + 64, // NPerBlock + 64, // KPerBlock + 8, // K1 + 16, // MPerWmma + 16, // NPerWmma + 4, // M-Repeat // M-PerWmma / M-Repeat = M-Wave + 2, // N-Repeat // N-PerWmma / N-Repeat = N-Wave + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + 1, // C shuffle (M Repeat) Per store + 1, // C shuffle (N Repeat) Per store + S<1, 32, 1, 4>, + 8>; int main(int argc, char* argv[]) { diff --git 
a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt index 32a87dd200..f343cc1910 100644 --- a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt +++ b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt @@ -1,5 +1,5 @@ add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp) -if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") +if(GPU_TARGETS MATCHES "gfx11") add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp) endif() diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp index 62233e5351..2bbf430c4e 100644 --- a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp +++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp @@ -43,9 +43,10 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Add; -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; -static constexpr auto ABSpec = ck::tensor_operation::device::TensorSpecialization::Packed; +static constexpr auto ASpec = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto BSpec = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default; using DeviceOpInstanceKKNN = @@ -55,43 +56,44 @@ using DeviceOpInstanceKKNN = NumDimK, ADataType, BDataType, - DsDataType, - EDataType, AccDataType, CShuffleDataType, + DsDataType, + EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, - ABSpec, - ABSpec, + ASpec, + BSpec, DESpec, - 256, + 1, 128, - 256, - 8, - 8, + 64, + 64, + 64, + 4, 16, 16, + 1, 4, - 4, - S<4, 64, 1>, + S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, - 8, - 8, + 4, + 4, true, - S<4, 64, 1>, + S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, - 8, - 8, + 4, + 4, true, 1, 1, - S<1, 32, 1, 8>, + S<1, 64, 1, 2>, 8>; using DeviceOpInstance = DeviceOpInstanceKKNN; @@ -251,6 +253,38 @@ int main(int argc, char* argv[]) ck::index_t K0 = 2048; + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 11) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + G0 = std::stoi(argv[4]); + G1 = std::stoi(argv[5]); + M0 = std::stoi(argv[6]); + M1 = std::stoi(argv[7]); + N0 = std::stoi(argv[8]); + N1 = std::stoi(argv[9]); + K0 = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4-10: G0, G1, M0, M1, N0, N1, K0\n"); + exit(0); + } + // A[G0, G1, M0, M1, K0] std::vector a_gs_ms_ks_lengths{G0, G1, M0, M1, K0}; std::vector a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1}; @@ -266,23 +300,6 @@ int main(int argc, char* argv[]) std::vector e_gs_ms_ns_strides{ G1 * M0 * N0 * M1 * N1, 
M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1}; - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - exit(0); - } Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc index 360b2c8947..ca8746bb97 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc @@ -42,41 +42,42 @@ using DeviceConvFwdInstance = OutputLayout, InKernelDataType, WeiKernelDataType, - ck::Tuple, - OutKernelDataType, AccDataType, CShuffleDataType, + ck::Tuple, + OutKernelDataType, InElementOp, WeiElementOp, OutElementOp, ConvSpec, // ConvForwardSpecialization GemmSpec, // GemmSpecialization - 256, // BlockSize - 128, // MPerBlock - 128, // NPerBlock - 4, // K0PerBlock + 1, // Prefetch stage + 128, // BlockSize + 64, // MPerBlock + 64, // NPerBlock + 64, // KPerBlock 8, // K1 16, // MPerWMMA 16, // NPerWMMA 4, // MRepeat - 2, // NRepeat - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + 1, // NRepeat + S<4, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder 2, // ABlockTransferSrcVectorDim 8, // ABlockTransferSrcScalarPerVector 8, // ABlockTransferDstScalarPerVector_AK1 true, // ABlockLdsExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<4, 32, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder 2, // BBlockTransferSrcVectorDim 8, // BBlockTransferSrcScalarPerVector 8, // BBlockTransferDstScalarPerVector_BK1 true, // BBlockLdsExtraN - 4, - 2, - S<1, 32, 1, 8>, + 1, + 1, + S<1, 16, 1, 8>, 8>; template @@ -277,9 +278,9 @@ bool run_grouped_conv_fwd_bias_relu_add_example(int argc, char* argv[]) switch(conv_param.num_dim_spatial_) { - case 1: return run_grouped_conv_fwd_bias_relu_add<1>(config, conv_param); + // case 1: return run_grouped_conv_fwd_bias_relu_add<1>(config, conv_param); case 2: return run_grouped_conv_fwd_bias_relu_add<2>(config, conv_param); - case 3: return run_grouped_conv_fwd_bias_relu_add<3>(config, conv_param); + // case 3: return run_grouped_conv_fwd_bias_relu_add<3>(config, conv_param); } return false; diff --git a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt index 2a24abf094..c6cca7b586 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt +++ b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt @@ -1,3 +1,12 @@ +if(GPU_TARGETS MATCHES "gfx11") + add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp) + add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 
batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp) + add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp) + add_example_executable(example_cross_attention_forward_wmma_fp16 cross_attention_forward_wmma_fp16.cpp) + add_example_executable(example_multi_query_attention_forward_wmma_fp16 multi_query_attention_forward_wmma_fp16.cpp) + add_example_executable(example_grouped_query_attention_forward_wmma_fp16 grouped_query_attention_forward_wmma_fp16.cpp) +endif() + add_custom_target(example_gemm_scale_softmax_gemm) add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp) @@ -20,4 +29,3 @@ add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_sc add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp) add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16) - diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp new file mode 100644 index 0000000000..2c7bacfc4e --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp @@ -0,0 +1,166 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g_k_l) * B1_g_l_n + |-----------------| + Gemm0 + |-------------------------------------| + Gemm1 +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using B0DataType = F16; +using B1DataType = F16; +using Acc0DataType = F32; +using Acc1DataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; +using Acc0BiasDataType = ck::Tuple<>; +using Acc1BiasDataType = ck::Tuple<>; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 1; +static constexpr ck::index_t NumDimN = 1; +static constexpr ck::index_t NumDimK = 1; +static constexpr ck::index_t NumDimO = 1; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; +static constexpr auto MaskingSpec = + 
ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle; + +static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default; + +using DeviceMHAFactory = + std::tuple, // ABlockTransfer MK -> K0 M K1 + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, // B0BlockTransfer LK -> K0 L K1 + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 8, 8>, // B1BlockTransfer NL -> L0 N L1 + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 8, + 1, + false, + 1, // CShuffleMWmmaPerWavePerShuffle + 2, // CShuffleNWmmaPerWavePerShuffle + S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + MaskingSpec> // MaskingSpecialization + >; +// Ref Gemm0: fp16 in, fp32 out +using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +// Ref Softmax: fp32 in, fp16 out +using ReferenceSoftmaxInstance = + ck::tensor_operation::host::ReferenceSoftmax; + +// Ref Gemm1: fp16 in, fp16 out +using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +#include "run_batched_gemm_scale_softmax_gemm_permute_wmma.inc" + +int main(int argc, char* argv[]) { return run(argc, argv); } diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp new file mode 100644 index 0000000000..d9ab645ee9 --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp @@ -0,0 +1,288 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Gemm + Softmax + Gemm fused operation. 
Computes C_g_m_n = Softmax(A_g_m_k * B0_g_k_l) * B1_g_l_n + |-----------------| + Gemm0 + |-------------------------------------| + Gemm1 +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using B0DataType = F16; +using B1DataType = F16; +using Acc0DataType = F32; +using Acc1DataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; +using Acc0BiasDataType = ck::Tuple<>; +using Acc1BiasDataType = ck::Tuple<>; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 1; +static constexpr ck::index_t NumDimN = 1; +static constexpr ck::index_t NumDimK = 1; +static constexpr ck::index_t NumDimO = 1; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; +static constexpr auto MaskingSpec = + ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; + +static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default; + +// clang-format off +// #define CK_MHA_USE_WAVE_1 +// #define CK_MHA_USE_WAVE_2 +// #define CK_MHA_USE_WAVE_4 +#define CK_MHA_USE_WAVE_8 +using DeviceMHAFactory = + std::tuple< +#ifdef CK_MHA_USE_WAVE_1 + // 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 32, + // Gemm 0 + 16, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, + 
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 32, + // Gemm 0 + 16, 64, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_2 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 64, + // Gemm 0 + 32, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 64, + // Gemm 0 + 32, 64, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_4 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 128, + // Gemm 0 + 64, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, + 
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 128, + // Gemm 0 + 64, 64, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_8 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 256, + // Gemm 0 + 128, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 128, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 256, + // Gemm 0 + 128, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 128, 1, 2>, 8, + MaskingSpec> +#endif + >; +// clang-format on +// Ref Gemm0: fp16 in, fp32 out +using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +// Ref Softmax: fp32 in, fp16 out +using ReferenceSoftmaxInstance = + ck::tensor_operation::host::ReferenceSoftmax; + +// Ref Gemm1: fp16 in, fp16 out +using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +#include "run_batched_gemm_scale_softmax_gemm_permute_wmma.inc" + +int main(int argc, char* argv[]) { return run(argc, argv); } diff --git a/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp new file mode 100644 index 0000000000..4c92c5497f --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp @@ -0,0 +1,354 @@ +// SPDX-License-Identifier: MIT +// 
Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g_k_l) * B1_g_l_n + |-----------------| + Gemm0 + |-------------------------------------| + Gemm1 +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using B0DataType = F16; +using B1DataType = F16; +using Acc0DataType = F32; +using Acc1DataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; +using Acc0BiasDataType = ck::Tuple<>; +using Acc1BiasDataType = ck::Tuple<>; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 1; +static constexpr ck::index_t NumDimN = 1; +static constexpr ck::index_t NumDimK = 1; +static constexpr ck::index_t NumDimO = 1; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; +static constexpr auto MaskingSpec = + ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; + +static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default; + +// clang-format off +#define CK_MHA_USE_WAVE_1 +#define CK_MHA_USE_WAVE_2 +#define CK_MHA_USE_WAVE_4 +#define CK_MHA_USE_WAVE_8 +using DeviceMHAFactory = + std::tuple< +#ifdef CK_MHA_USE_WAVE_1 + // 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 32, + // Gemm 0 + 16, 32, 160, 8, 8, + // Gemm 1 + 80, 32, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 2, 5, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // 
CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 32, + // Gemm 0 + 16, 64, 80, 8, 8, + // Gemm 1 + 80, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 5, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 32, + // Gemm 0 + 16, 64, 48, 8, 8, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_2 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 64, + // Gemm 0 + 32, 64, 48, 8, 8, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 64, + // Gemm 0 + 32, 64, 80, 8, 8, + // Gemm 1 + 80, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 5, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, 
S<1, 32, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 64, + // Gemm 0 + 32, 32, 160, 8, 8, + // Gemm 1 + 80, 32, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 2, 5, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_4 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 128, + // Gemm 0 + 64, 128, 80, 8, 8, + // Gemm 1 + 80, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 5, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 128, + // Gemm 0 + 64, 192, 48, 8, 8, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 12, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 128, + // Gemm 0 + 64, 64, 48, 8, 8, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + 
MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_8 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 256, + // Gemm 0 + 128, 192, 48, 8,4, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 12, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 128, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 256, + // Gemm 0 + 128, 64, 48, 8,4, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 128, 1, 2>, 8, + MaskingSpec> +#endif + >; +// clang-format on +// Ref Gemm0: fp16 in, fp32 out +using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +// Ref Softmax: fp32 in, fp16 out +using ReferenceSoftmaxInstance = + ck::tensor_operation::host::ReferenceSoftmax; + +// Ref Gemm1: fp16 in, fp16 out +using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +#include "run_cross_attention_wmma.inc" + +int main(int argc, char* argv[]) { return run(argc, argv); } diff --git a/example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp new file mode 100644 index 0000000000..12dcfcc36d --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp @@ -0,0 +1,302 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Grouped Query Attention, +Ainslie, Joshua, James Lee-Thorp, Michiel de Jong, Yury Zemlyanskiy, Federico Lebrón, and Sumit +Sanghai. “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints.” +arXiv, May 22, 2023. https://doi.org/10.48550/arXiv.2305.13245. 
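+
+In grouped-query attention the query heads are partitioned into a small number of groups,
+and all query heads within a group attend against one shared K/V head (GQA-G denotes G such
+groups). Below, QueryGroupNumber = 4 and the default head count G1 is 16, so every
+16 / 4 = 4 query heads share a single K/V head.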
+ +Example is GQA-4 +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using B0DataType = F16; +using B1DataType = F16; +using Acc0DataType = F32; +using Acc1DataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; +using Acc0BiasDataType = ck::Tuple<>; +using Acc1BiasDataType = ck::Tuple<>; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 1; +static constexpr ck::index_t NumDimN = 1; +static constexpr ck::index_t NumDimK = 1; +static constexpr ck::index_t NumDimO = 1; +static constexpr ck::index_t QueryGroupNumber = 4; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; +static constexpr auto MaskingSpec = + ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; + +static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default; + +// clang-format off +// #define CK_MHA_USE_WAVE_1 +// #define CK_MHA_USE_WAVE_2 +// #define CK_MHA_USE_WAVE_4 +#define CK_MHA_USE_WAVE_8 +using DeviceMHAFactory = + std::tuple< +#ifdef CK_MHA_USE_WAVE_1 + // 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5 + ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + QueryGroupNumber, + 32, + // Gemm 0 + 16, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + 
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + QueryGroupNumber, + 32, + // Gemm 0 + 16, 64, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_2 + ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + QueryGroupNumber, + 64, + // Gemm 0 + 32, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + QueryGroupNumber, + 64, + // Gemm 0 + 32, 64, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_4 + ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + QueryGroupNumber, + 128, + // Gemm 0 + 64, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + 
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + QueryGroupNumber, + 128, + // Gemm 0 + 64, 64, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_8 + ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + QueryGroupNumber, + 256, + // Gemm 0 + 128, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 128, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + QueryGroupNumber, + 256, + // Gemm 0 + 128, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 128, 1, 2>, 8, + MaskingSpec> +#endif + >; +// clang-format on +// Ref Gemm0: fp16 in, fp32 out +using ReferenceGemm0Instance = + ck::tensor_operation::host::ReferenceBatchedGemm_GQA; + +// Ref Softmax: fp32 in, fp16 out +using ReferenceSoftmaxInstance = + ck::tensor_operation::host::ReferenceSoftmax; + +// Ref Gemm1: fp16 in, fp16 out +using ReferenceGemm1Instance = + ck::tensor_operation::host::ReferenceBatchedGemm_GQA; + +#include "run_grouped_query_attention_forward_wmma.inc" + +int main(int argc, char* argv[]) { return run(argc, argv); } diff --git a/example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp new file mode 100644 index 0000000000..694a320a45 --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp @@ -0,0 +1,287 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 
2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Multi-Query Attention +Shazeer, Noam. “Fast Transformer Decoding: One Write-Head Is All You Need.” arXiv.org, November 6, +2019. https://arxiv.org/abs/1911.02150v1. + +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using B0DataType = F16; +using B1DataType = F16; +using Acc0DataType = F32; +using Acc1DataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; +using Acc0BiasDataType = ck::Tuple<>; +using Acc1BiasDataType = ck::Tuple<>; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 1; +static constexpr ck::index_t NumDimN = 1; +static constexpr ck::index_t NumDimK = 1; +static constexpr ck::index_t NumDimO = 1; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; +static constexpr auto MaskingSpec = + ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; + +static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default; + +// clang-format off +// #define CK_MHA_USE_WAVE_1 +// #define CK_MHA_USE_WAVE_2 +// #define CK_MHA_USE_WAVE_4 +#define CK_MHA_USE_WAVE_8 +using DeviceMHAFactory = + std::tuple< +#ifdef CK_MHA_USE_WAVE_1 + // 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5 + ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 32, + // Gemm 0 + 16, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + 
MaskingSpec>, + ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 32, + // Gemm 0 + 16, 64, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_2 + ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 64, + // Gemm 0 + 32, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 64, + // Gemm 0 + 32, 64, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_4 + ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 128, + // Gemm 0 + 64, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, + 
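+        // Same 128-thread block as the instance above, but with a smaller Gemm0 tile along the
+        // L (kv-sequence) dimension (64 instead of 128), so the corresponding repeat drops from 8 to 4.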
ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 128, + // Gemm 0 + 64, 64, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_8 + ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 256, + // Gemm 0 + 128, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 128, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 256, + // Gemm 0 + 128, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 128, 1, 2>, 8, + MaskingSpec> +#endif + >; +// clang-format on +// Ref Gemm0: fp16 in, fp32 out +using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm_MQA; + +// Ref Softmax: fp32 in, fp16 out +using ReferenceSoftmaxInstance = + ck::tensor_operation::host::ReferenceSoftmax; + +// Ref Gemm1: fp16 in, fp16 out +using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm_MQA; + +#include "run_multi_query_attention_forward_wmma.inc" + +int main(int argc, char* argv[]) { return run(argc, argv); } diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc new file mode 100644 index 0000000000..2e77479bcc --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc @@ -0,0 +1,340 @@ +// SPDX-License-Identifier: 
MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +int run(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape for A/B0/B1/C + // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o + ck::index_t M = 120; + ck::index_t N = 1000; + ck::index_t K = 64; + ck::index_t O = 128; + + // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape + // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) + // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3]) + ck::index_t G0 = 7; + ck::index_t G1 = 13; + + float alpha = 1; + + bool input_permute = false; + bool output_permute = true; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + O = std::stoi(argv[7]); + G0 = std::stoi(argv[8]); + G1 = std::stoi(argv[9]); + + alpha = std::stof(argv[10]); + + input_permute = std::stoi(argv[11]); + output_permute = std::stoi(argv[12]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 11: M, N, K, O, G0, G1\n"); + printf("arg10: scale (alpha)\n"); + printf("arg11 to 12: input / output permute\n"); + exit(0); + } + + std::vector a_gs_ms_ks_lengths{G0, G1, M, K}; + std::vector a_gs_ms_ks_strides = + input_permute + ? std::vector{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K] + : std::vector{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::vector b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::vector b0_gs_ns_ks_strides = + input_permute + ? std::vector{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K] + : std::vector{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] + + std::vector b1_gs_os_ns_lengths{G0, G1, O, N}; + std::vector b1_gs_os_ns_strides = + input_permute + ? std::vector{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O] + : std::vector{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] + + std::vector c_gs_ms_os_lengths{G0, G1, M, O}; + std::vector c_gs_ms_os_strides = + output_permute + ? 
std::vector{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] + : std::vector{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; + std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; + std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl; + std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 3: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + break; + case 4: // A, B0, B1 1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: // Rand: b1 b0; unit: a + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 6: // Rand: a b0 ; unit: B1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 7: // Rand: a b1 ; unit: b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 8: // Rand: a ; unit: b0 b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 9: // Rand: b0 ; unit: a b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 10: // Rand: b1 ; unit: a b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * + c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); + 
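+    // Upload the remaining operands (B0 and B1) to device memory as well.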
b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data()); + b1_device_buf.ToDevice(b1_gs_os_ns.mData.data()); + + auto a_element_op = AElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto acc0_element_op = Acc0ElementOp{alpha}; + auto b1_element_op = B1ElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + float best_perf = .0; + float best_time = .0; + int not_pass = 0; + std::string best_kernel = ""; + printf("Verification: %s\n", do_verification ? "ON" : "OFF"); + // TODO ANT: replace array with vector? + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) -> void { + const auto device_mha_instance = std::get(DeviceMHAFactory{}); + + using DeviceMHAInstance = ck::remove_cvref_t; + auto gemm = DeviceMHAInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b0_device_buf.GetDeviceBuffer()), + static_cast(b1_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + O, + G0, + G1, + alpha, + input_permute, + output_permute); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + // return 0; + } + + ck::index_t BatchCount = G0 * G1; + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount; + std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + + sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) * + BatchCount; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + if(tflops > best_perf) + { + best_perf = tflops; + best_time = ave_time * 1000; + best_kernel = gemm.GetTypeString(); + } + if(do_verification) + { + c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); + + Tensor a_g_m_k({BatchCount, M, K}); + Tensor b0_g_k_n({BatchCount, K, N}); + Tensor b1_g_n_o({BatchCount, N, O}); + Tensor acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0 + Tensor a1_g_m_n({BatchCount, M, N}); // scratch object after softmax + Tensor c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1 + + // permute + a_gs_ms_ks.ForEach([&](auto& self, auto idx) { + a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); + }); + b0_gs_ns_ks.ForEach([&](auto& self, auto idx) { + b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); + }); + b1_gs_os_ns.ForEach([&](auto& self, auto idx) { + b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); + }); + + // gemm 0 + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument( + a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + // masking + const auto mask = typename DeviceMHAInstance::C0MatrixMask(N); + acc0_g_m_n.ForEach([&](auto& self, auto idx) { + if(mask.IsMaskedElement(idx[1], idx[2])) + self(idx) = -ck::NumericLimits::Infinity(); + }); + + // softmax + auto ref_softmax = ReferenceSoftmaxInstance{}; + auto ref_softmax_invoker = ref_softmax.MakeInvoker(); + auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2}); + + ref_softmax_invoker.Run(ref_softmax_argument); + + // gemm1 + 
auto ref_gemm1 = ReferenceGemm1Instance{}; + auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n, + b1_g_n_o, + c_g_m_o_host_result, + PassThrough{}, + b1_element_op, + c_element_op); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + + // permute + c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) { + const size_t& g0 = idx[0]; + const size_t& g1 = idx[1]; + + const size_t g = g0 * G1 + g1; + + self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]); + }); + + // default absolute error and relative error is 0.001 + double rtol = 1e-3; + double atol = 1e-3; + + // when BF16 is taken, set absolute error and relative error to 0.01 + if(std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) + { + rtol = 1e-2; + atol = 1e-2; + } + + bool this_run_verification = ck::utils::check_err(c_gs_ms_os_device_result.mData, + c_gs_ms_os_host_result.mData, + "Error: Incorrect results!", + rtol, + atol); + printf("Verification: %s, Pass: %s\n", + do_verification ? "ON" : "OFF", + this_run_verification ? "YES" : "NO"); + + if(!this_run_verification) + { + not_pass = 1; + printf("%d th MHA instance verification Failed \n", i.value); + } + } + }); + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Problem Size: BatchCount: " << G0 << ", HeadNum: " << G1 << ", M: " << M + << ", N: " << N << ", K: " << K << ", O: " << O << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time + << " us" << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + return not_pass; +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc new file mode 100644 index 0000000000..9ff4c56e06 --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc @@ -0,0 +1,384 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +int run(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape for A/B0/B1/C + // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o + ck::index_t q_sequence_length = 256; + ck::index_t kv_sequence_length = 64; + ck::index_t head_dim = 80; + + // Output shape C[batch_size, q_sequence_length, head_num, head_dim]. 
Batch dim, outer dim, + // inner dim must match GEMM shape C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) C_g0_m_g1_o = + // permute(C_g0_g1_m_o, [0, 2, 1, 3]) + ck::index_t batch_size = 2; + ck::index_t head_num = 8; + + float alpha = 1; + bool input_permute = true; + bool output_permute = true; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + q_sequence_length = std::stoi(argv[4]); + kv_sequence_length = std::stoi(argv[5]); + head_dim = std::stoi(argv[6]); + batch_size = std::stoi(argv[7]); + head_num = std::stoi(argv[8]); + + alpha = std::stof(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf( + "arg4 to 8: q_sequence_length, kv_sequence_length, head_dim, batch_size, head_num\n"); + printf("arg9: scale (alpha)\n"); + exit(0); + } + + std::vector a_gs_ms_ks_lengths{batch_size, head_num, q_sequence_length, head_dim}; + std::vector a_gs_ms_ks_strides = + input_permute ? std::vector{q_sequence_length * head_num * head_dim, + head_dim, + head_num * head_dim, + 1} + // A layout [batch_size, q_sequence_length, head_num, head_dim] + : std::vector{ + head_num * q_sequence_length * head_dim, + q_sequence_length * head_dim, + head_dim, + 1}; // A layout [batch_size, head_num, q_sequence_length, head_dim] + + std::vector b0_gs_ns_ks_lengths{ + batch_size, head_num, kv_sequence_length, head_dim}; + std::vector b0_gs_ns_ks_strides = + input_permute ? std::vector{kv_sequence_length * head_num * head_dim, + head_dim, + head_num * head_dim, + 1} + // B0 layout [batch_size, kv_sequence_length, head_num, head_dim] + : std::vector{ + head_num * kv_sequence_length * head_dim, + kv_sequence_length * head_dim, + head_dim, + 1}; // B0 layout [batch_size, head_num, kv_sequence_length, head_dim] + + std::vector b1_gs_os_ns_lengths{ + batch_size, head_num, head_dim, kv_sequence_length}; + std::vector b1_gs_os_ns_strides = + input_permute + ? std::vector{kv_sequence_length * head_num * head_dim, + head_dim, + 1, + head_num * head_dim} + // B1 layout [batch_size, kv_sequence_length, head_num, head_dim] + : std::vector{ + head_num * kv_sequence_length * head_dim, + kv_sequence_length * head_dim, + 1, + head_dim}; // B1 layout [batch_size, head_num, kv_sequence_length, head_dim] + + std::vector c_gs_ms_os_lengths{batch_size, head_num, q_sequence_length, head_dim}; + std::vector c_gs_ms_os_strides = + output_permute ? 
std::vector{q_sequence_length * head_num * head_dim, + head_dim, + head_num * head_dim, + 1} + // C layout [batch_size, q_sequence_length, head_num, head_dim] + : std::vector{ + head_num * q_sequence_length * head_dim, + q_sequence_length * head_dim, + head_dim, + 1}; // C layout [batch_size, head_num, q_sequence_length, head_dim] + + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; + std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; + std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl; + std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 3: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + break; + case 4: // A, B0, B1 1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: // Rand: b1 b0; unit: a + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 6: // Rand: a b0 ; unit: B1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 7: // Rand: a b1 ; unit: b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 8: // Rand: a ; unit: b0 b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 9: // Rand: b0 ; unit: a b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 10: // Rand: b1 ; unit: a b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + } + + std::vector kv_gs_ns_ks_lengths{ + batch_size, head_num, kv_sequence_length, 2, head_dim}; + std::vector kv_gs_ns_ks_strides = std::vector{ + kv_sequence_length * head_num * 2 * head_dim, + 2 * head_dim, + head_num * 2 * 
head_dim, + head_dim, + 1}; // kv layout [batch_size, q_sequence_length, head_num, 2, head_dim] + Tensor kv_gs_ns_ks(kv_gs_ns_ks_lengths, kv_gs_ns_ks_strides); + // merge kv into a packed pointer send to device + b0_gs_ns_ks.ForEach( + [&](auto& self, auto idx) { kv_gs_ns_ks(idx[0], idx[1], idx[2], 0, idx[3]) = self(idx); }); + b1_gs_os_ns.ForEach( + [&](auto& self, auto idx) { kv_gs_ns_ks(idx[0], idx[1], idx[3], 1, idx[2]) = self(idx); }); + DeviceMem q_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem kv_device_buf(sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize() + + sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * + c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()); + q_device_buf.ToDevice(a_gs_ms_ks.mData.data()); + kv_device_buf.ToDevice(kv_gs_ns_ks.mData.data()); + + auto a_element_op = AElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto acc0_element_op = Acc0ElementOp{alpha}; + auto b1_element_op = B1ElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + float best_perf = .0; + float best_time = .0; + int not_pass = 0; + std::string best_kernel = ""; + printf("Verification: %s\n", do_verification ? "ON" : "OFF"); + // TODO ANT: replace array with vector? + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) -> void { + const auto device_mha_instance = std::get(DeviceMHAFactory{}); + + using DeviceMHAInstance = ck::remove_cvref_t; + auto gemm = DeviceMHAInstance{}; + auto invoker = gemm.MakeCrossAttnInvoker(); + auto argument = + gemm.MakeCrossAttnArgument(static_cast(q_device_buf.GetDeviceBuffer()), + static_cast(kv_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + batch_size, + q_sequence_length, + kv_sequence_length, + head_num, + head_dim, + alpha); + + // if(!gemm.IsSupportedArgument(argument)) + // { + // std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + // return 0; + // } + + ck::index_t BatchCount = batch_size * head_num; + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = (size_t(q_sequence_length) * kv_sequence_length * head_dim * 2 + + size_t(q_sequence_length) * kv_sequence_length * head_dim * 2) * + BatchCount; + std::size_t num_btype = (sizeof(ADataType) * q_sequence_length * head_dim + + sizeof(B0DataType) * head_dim * kv_sequence_length + + sizeof(B1DataType) * kv_sequence_length * head_dim + + sizeof(CDataType) * q_sequence_length * head_dim) * + BatchCount; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + if(tflops > best_perf) + { + best_perf = tflops; + best_time = ave_time * 1000; + best_kernel = gemm.GetTypeString(); + } + if(do_verification) + { + c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); + + Tensor a_g_m_k({BatchCount, q_sequence_length, head_dim}); + Tensor b0_g_k_n({BatchCount, head_dim, kv_sequence_length}); + Tensor b1_g_n_o({BatchCount, kv_sequence_length, head_dim}); + Tensor acc0_g_m_n( + {BatchCount, q_sequence_length, kv_sequence_length}); // scratch object after gemm0 + Tensor a1_g_m_n({BatchCount, + q_sequence_length, + kv_sequence_length}); // scratch object after softmax + Tensor c_g_m_o_host_result( + {BatchCount, q_sequence_length, head_dim}); // scratch object after gemm1 + + // 
permute + a_gs_ms_ks.ForEach([&](auto& self, auto idx) { + a_g_m_k(idx[0] * head_num + idx[1], idx[2], idx[3]) = self(idx); + }); + b0_gs_ns_ks.ForEach([&](auto& self, auto idx) { + b0_g_k_n(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx); + }); + b1_gs_os_ns.ForEach([&](auto& self, auto idx) { + b1_g_n_o(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx); + }); + + // gemm 0 + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument( + a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + // masking + const auto mask = typename DeviceMHAInstance::C0MatrixMask(kv_sequence_length); + acc0_g_m_n.ForEach([&](auto& self, auto idx) { + if(mask.IsMaskedElement(idx[1], idx[2])) + self(idx) = -ck::NumericLimits::Infinity(); + }); + + // softmax + auto ref_softmax = ReferenceSoftmaxInstance{}; + auto ref_softmax_invoker = ref_softmax.MakeInvoker(); + auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2}); + + ref_softmax_invoker.Run(ref_softmax_argument); + + // gemm1 + auto ref_gemm1 = ReferenceGemm1Instance{}; + auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n, + b1_g_n_o, + c_g_m_o_host_result, + PassThrough{}, + b1_element_op, + c_element_op); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + + // permute + c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) { + const size_t& g0 = idx[0]; + const size_t& g1 = idx[1]; + + const size_t g = g0 * head_num + g1; + + self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]); + }); + + // default absolute error and relative error is 0.001 + double rtol = 1e-3; + double atol = 1e-3; + + // when BF16 is taken, set absolute error and relative error to 0.01 + if(std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) + { + rtol = 1e-2; + atol = 1e-2; + } + + bool this_run_verification = ck::utils::check_err(c_gs_ms_os_device_result.mData, + c_gs_ms_os_host_result.mData, + "Error: Incorrect results!", + rtol, + atol); + printf("Verification: %s, Pass: %s\n", + do_verification ? "ON" : "OFF", + this_run_verification ? 
"YES" : "NO"); + + if(!this_run_verification) + { + not_pass = 1; + printf("%d th MHA instance verification Failed \n", i.value); + } + } + }); + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Problem Size: BatchCount: " << batch_size << ", HeadNum: " << head_num + << ", q_sequence_length: " << q_sequence_length + << ", kv_sequence_length: " << kv_sequence_length << ", head_dim: " << head_dim + << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time + << " us" << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + return not_pass; +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc new file mode 100644 index 0000000000..609d085299 --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc @@ -0,0 +1,340 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +int run(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape for A/B0/B1/C + // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 64; + ck::index_t O = 64; + + // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape + // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) + // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3]) + ck::index_t G0 = 4; + ck::index_t G1 = 16; + ck::index_t KV_head = QueryGroupNumber; + + float alpha = 1; + + bool input_permute = false; + bool output_permute = true; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + O = std::stoi(argv[7]); + G0 = std::stoi(argv[8]); + G1 = std::stoi(argv[9]); + + alpha = std::stof(argv[10]); + + input_permute = std::stoi(argv[11]); + output_permute = std::stoi(argv[12]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 11: M, N, K, O, G0, G1\n"); + printf("arg10: scale (alpha)\n"); + printf("arg11 to 12: input / output permute\n"); + exit(0); + } + + std::vector a_gs_ms_ks_lengths{G0, G1, M, K}; + std::vector a_gs_ms_ks_strides = + input_permute + ? std::vector{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K] + : std::vector{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::vector b0_gs_ns_ks_lengths{G0, KV_head, N, K}; + std::vector b0_gs_ns_ks_strides = + input_permute + ? 
std::vector{N * KV_head * K, K, KV_head * K, 1} + // B0 layout [G0, N, G1, K] + : std::vector{KV_head * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] + + std::vector b1_gs_os_ns_lengths{G0, KV_head, O, N}; + std::vector b1_gs_os_ns_strides = + input_permute + ? std::vector{N * KV_head * O, O, 1, KV_head * O} + // B1 layout [G0, N, G1, O] + : std::vector{KV_head * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] + + std::vector c_gs_ms_os_lengths{G0, G1, M, O}; + std::vector c_gs_ms_os_strides = + output_permute + ? std::vector{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] + : std::vector{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; + std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; + std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl; + std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 3: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + break; + case 4: // A, B0, B1 1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: // Rand: b1 b0; unit: a + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 6: // Rand: a b0 ; unit: B1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 7: // Rand: a b1 ; unit: b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 8: // Rand: a ; unit: b0 b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 9: // Rand: b0 ; unit: a b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 10: // Rand: b1 ; unit: a b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + 
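+        // The default case falls back to diagonal (identity-like) generators for B0 and B1.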
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * + c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); + b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data()); + b1_device_buf.ToDevice(b1_gs_os_ns.mData.data()); + + auto a_element_op = AElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto acc0_element_op = Acc0ElementOp{alpha}; + auto b1_element_op = B1ElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + float best_perf = .0; + float best_time = .0; + int not_pass = 0; + std::string best_kernel = ""; + printf("Verification: %s\n", do_verification ? "ON" : "OFF"); + // TODO ANT: replace array with vector? + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) -> void { + const auto device_mha_instance = std::get(DeviceMHAFactory{}); + + using DeviceMHAInstance = ck::remove_cvref_t; + auto gemm = DeviceMHAInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b0_device_buf.GetDeviceBuffer()), + static_cast(b1_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + O, + G0, + G1, + alpha, + input_permute, + output_permute); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + // return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * G0 * G1; + std::size_t num_btype = + (sizeof(ADataType) * M * K + sizeof(CDataType) * M * O) * G0 * G1 + + (sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O) * G0 * QueryGroupNumber; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + if(tflops > best_perf) + { + best_perf = tflops; + best_time = ave_time * 1000; + best_kernel = gemm.GetTypeString(); + } + if(do_verification) + { + c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); + + Tensor a_g0_g1_m_k({G0, G1, M, K}); + Tensor b0_g0_gq_k_n({G0, QueryGroupNumber, K, N}); + Tensor b1_g0_gq_n_o({G0, QueryGroupNumber, N, O}); + Tensor acc0_g0_g1_m_n({G0, G1, M, N}); // scratch object after gemm0 + Tensor a1_g0_g1_m_n({G0, G1, M, N}); // scratch object after softmax + Tensor c_g0_g1_m_o_host_result({G0, G1, M, O}); // scratch object after gemm1 + + // permute + a_gs_ms_ks.ForEach([&](auto& self, auto idx) { + a_g0_g1_m_k(idx[0], idx[1], idx[2], idx[3]) = self(idx); + }); + b0_gs_ns_ks.ForEach([&](auto& self, auto idx) { + b0_g0_gq_k_n(idx[0], idx[1], idx[3], idx[2]) = self(idx); + }); + b1_gs_os_ns.ForEach([&](auto& self, auto idx) { + b1_g0_gq_n_o(idx[0], idx[1], idx[3], idx[2]) = self(idx); + }); + + // gemm 0 + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument(a_g0_g1_m_k, + b0_g0_gq_k_n, + acc0_g0_g1_m_n, + a_element_op, + 
b0_element_op, + acc0_element_op); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + // masking + const auto mask = typename DeviceMHAInstance::C0MatrixMask(N); + acc0_g0_g1_m_n.ForEach([&](auto& self, auto idx) { + if(mask.IsMaskedElement(idx[2], idx[3])) + self(idx) = -ck::NumericLimits::Infinity(); + }); + + // softmax + auto ref_softmax = ReferenceSoftmaxInstance{}; + auto ref_softmax_invoker = ref_softmax.MakeInvoker(); + auto ref_softmax_argument = + ref_softmax.MakeArgument(acc0_g0_g1_m_n, a1_g0_g1_m_n, 1, 0, {3}); + + ref_softmax_invoker.Run(ref_softmax_argument); + + // gemm1 + auto ref_gemm1 = ReferenceGemm1Instance{}; + auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g0_g1_m_n, + b1_g0_gq_n_o, + c_g0_g1_m_o_host_result, + PassThrough{}, + b1_element_op, + c_element_op); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + + // permute + c_gs_ms_os_host_result.ForEach( + [&](auto& self, auto idx) { self(idx) = c_g0_g1_m_o_host_result(idx); }); + + // default absolute error and relative error is 0.001 + double rtol = 1e-3; + double atol = 1e-3; + + // when BF16 is taken, set absolute error and relative error to 0.01 + if(std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) + { + rtol = 1e-2; + atol = 1e-2; + } + + bool this_run_verification = ck::utils::check_err(c_gs_ms_os_device_result.mData, + c_gs_ms_os_host_result.mData, + "Error: Incorrect results!", + rtol, + atol); + printf("Verification: %s, Pass: %s\n", + do_verification ? "ON" : "OFF", + this_run_verification ? "YES" : "NO"); + + if(!this_run_verification) + { + not_pass = 1; + printf("%d th MQA instance verification Failed \n", i.value); + } + } + }); + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Problem Size: BatchCount: " << G0 << ", HeadNum: " << G1 << ", M: " << M + << ", N: " << N << ", K: " << K << ", O: " << O << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time + << " us" << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + return not_pass; +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc new file mode 100644 index 0000000000..b05915c07f --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc @@ -0,0 +1,339 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +int run(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape for A/B0/B1/C + // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o + ck::index_t M = 120; + ck::index_t N = 1000; + ck::index_t K = 64; + ck::index_t O = 128; + + // Output shape C[G0, M, G1, O]. 
Batch dim, outer dim, inner dim must match GEMM shape + // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) + // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3]) + ck::index_t G0 = 7; + ck::index_t G1 = 13; + ck::index_t KV_head = 1; + + float alpha = 1; + + bool input_permute = false; + bool output_permute = true; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + O = std::stoi(argv[7]); + G0 = std::stoi(argv[8]); + G1 = std::stoi(argv[9]); + + alpha = std::stof(argv[10]); + + input_permute = std::stoi(argv[11]); + output_permute = std::stoi(argv[12]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 11: M, N, K, O, G0, G1\n"); + printf("arg10: scale (alpha)\n"); + printf("arg11 to 12: input / output permute\n"); + exit(0); + } + + std::vector a_gs_ms_ks_lengths{G0, G1, M, K}; + std::vector a_gs_ms_ks_strides = + input_permute + ? std::vector{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K] + : std::vector{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::vector b0_gs_ns_ks_lengths{G0, KV_head, N, K}; + std::vector b0_gs_ns_ks_strides = + input_permute + ? std::vector{N * KV_head * K, K, KV_head * K, 1} + // B0 layout [G0, N, G1, K] + : std::vector{KV_head * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] + + std::vector b1_gs_os_ns_lengths{G0, KV_head, O, N}; + std::vector b1_gs_os_ns_strides = + input_permute + ? std::vector{N * KV_head * O, O, 1, KV_head * O} + // B1 layout [G0, N, G1, O] + : std::vector{KV_head * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] + + std::vector c_gs_ms_os_lengths{G0, G1, M, O}; + std::vector c_gs_ms_os_strides = + output_permute + ? 
std::vector{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] + : std::vector{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; + std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; + std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl; + std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 3: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + break; + case 4: // A, B0, B1 1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: // Rand: b1 b0; unit: a + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 6: // Rand: a b0 ; unit: B1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 7: // Rand: a b1 ; unit: b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 8: // Rand: a ; unit: b0 b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 9: // Rand: b0 ; unit: a b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 10: // Rand: b1 ; unit: a b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * + c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); + 
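+    // This is the multi-query variant: KV_head is 1, so the b0/b1 host tensors
+    // above were shaped {G0, KV_head, ...} and every one of the G1 query heads
+    // attends against the single K/V copy uploaded below (which is also why
+    // num_btype later counts B0/B1 traffic with a factor of G0 only).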
b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data()); + b1_device_buf.ToDevice(b1_gs_os_ns.mData.data()); + + auto a_element_op = AElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto acc0_element_op = Acc0ElementOp{alpha}; + auto b1_element_op = B1ElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + float best_perf = .0; + float best_time = .0; + int not_pass = 0; + std::string best_kernel = ""; + printf("Verification: %s\n", do_verification ? "ON" : "OFF"); + // TODO ANT: replace array with vector? + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) -> void { + const auto device_mha_instance = std::get(DeviceMHAFactory{}); + + using DeviceMHAInstance = ck::remove_cvref_t; + auto gemm = DeviceMHAInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b0_device_buf.GetDeviceBuffer()), + static_cast(b1_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + O, + G0, + G1, + alpha, + input_permute, + output_permute); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + // return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * G0 * G1; + std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(CDataType) * M * O) * G0 * G1 + + (sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O) * G0; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + if(tflops > best_perf) + { + best_perf = tflops; + best_time = ave_time * 1000; + best_kernel = gemm.GetTypeString(); + } + if(do_verification) + { + c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); + + Tensor a_g0_g1_m_k({G0, G1, M, K}); + Tensor b0_g0_1_k_n({G0, 1, K, N}); + Tensor b1_g0_1_n_o({G0, 1, N, O}); + Tensor acc0_g0_g1_m_n({G0, G1, M, N}); // scratch object after gemm0 + Tensor a1_g0_g1_m_n({G0, G1, M, N}); // scratch object after softmax + Tensor c_g0_g1_m_o_host_result({G0, G1, M, O}); // scratch object after gemm1 + + // permute + a_gs_ms_ks.ForEach([&](auto& self, auto idx) { + a_g0_g1_m_k(idx[0], idx[1], idx[2], idx[3]) = self(idx); + }); + b0_gs_ns_ks.ForEach([&](auto& self, auto idx) { + b0_g0_1_k_n(idx[0], idx[1], idx[3], idx[2]) = self(idx); + }); + b1_gs_os_ns.ForEach([&](auto& self, auto idx) { + b1_g0_1_n_o(idx[0], idx[1], idx[3], idx[2]) = self(idx); + }); + + // gemm 0 + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument(a_g0_g1_m_k, + b0_g0_1_k_n, + acc0_g0_g1_m_n, + a_element_op, + b0_element_op, + acc0_element_op); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + // masking + const auto mask = typename DeviceMHAInstance::C0MatrixMask(N); + acc0_g0_g1_m_n.ForEach([&](auto& self, auto idx) { + if(mask.IsMaskedElement(idx[2], idx[3])) + self(idx) = -ck::NumericLimits::Infinity(); + }); + + // softmax + auto ref_softmax = ReferenceSoftmaxInstance{}; + auto ref_softmax_invoker = ref_softmax.MakeInvoker(); + auto ref_softmax_argument = + ref_softmax.MakeArgument(acc0_g0_g1_m_n, a1_g0_g1_m_n, 1, 0, {3}); + + ref_softmax_invoker.Run(ref_softmax_argument); + + // gemm1 + auto ref_gemm1 = 
ReferenceGemm1Instance{}; + auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g0_g1_m_n, + b1_g0_1_n_o, + c_g0_g1_m_o_host_result, + PassThrough{}, + b1_element_op, + c_element_op); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + + // permute + c_gs_ms_os_host_result.ForEach( + [&](auto& self, auto idx) { self(idx) = c_g0_g1_m_o_host_result(idx); }); + + // default absolute error and relative error is 0.001 + double rtol = 1e-3; + double atol = 1e-3; + + // when BF16 is taken, set absolute error and relative error to 0.01 + if(std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) + { + rtol = 1e-2; + atol = 1e-2; + } + + bool this_run_verification = ck::utils::check_err(c_gs_ms_os_device_result.mData, + c_gs_ms_os_host_result.mData, + "Error: Incorrect results!", + rtol, + atol); + printf("Verification: %s, Pass: %s\n", + do_verification ? "ON" : "OFF", + this_run_verification ? "YES" : "NO"); + + if(!this_run_verification) + { + not_pass = 1; + printf("%d th MQA instance verification Failed \n", i.value); + } + } + }); + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Problem Size: BatchCount: " << G0 << ", HeadNum: " << G1 << ", M: " << M + << ", N: " << N << ", K: " << K << ", O: " << O << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time + << " us" << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + return not_pass; +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc new file mode 100644 index 0000000000..3fdaaebb0f --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc @@ -0,0 +1,376 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +int run(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape for A/B0/B1/C + // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o + ck::index_t sequence_length = 256; + ck::index_t head_dim = 80; + + // Output shape C[batch_size, sequence_length, head_num, head_dim]. 
Batch dim, outer dim, inner + // dim must match GEMM shape C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) C_g0_m_g1_o = + // permute(C_g0_g1_m_o, [0, 2, 1, 3]) + ck::index_t batch_size = 2; + ck::index_t head_num = 8; + + float alpha = 1; + bool input_permute = true; + bool output_permute = true; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 9) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + sequence_length = std::stoi(argv[4]); + head_dim = std::stoi(argv[5]); + batch_size = std::stoi(argv[6]); + head_num = std::stoi(argv[7]); + + alpha = std::stof(argv[8]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 7: sequence_length, head_dim, batch_size, head_num\n"); + printf("arg8: scale (alpha)\n"); + exit(0); + } + + std::vector a_gs_ms_ks_lengths{batch_size, head_num, sequence_length, head_dim}; + std::vector a_gs_ms_ks_strides = + input_permute ? std::vector{sequence_length * head_num * head_dim, + head_dim, + head_num * head_dim, + 1} + // A layout [batch_size, sequence_length, head_num, head_dim] + : std::vector{ + head_num * sequence_length * head_dim, + sequence_length * head_dim, + head_dim, + 1}; // A layout [batch_size, head_num, sequence_length, head_dim] + + std::vector b0_gs_ns_ks_lengths{batch_size, head_num, sequence_length, head_dim}; + std::vector b0_gs_ns_ks_strides = + input_permute ? std::vector{sequence_length * head_num * head_dim, + head_dim, + head_num * head_dim, + 1} + // B0 layout [batch_size, sequence_length, head_num, head_dim] + : std::vector{ + head_num * sequence_length * head_dim, + sequence_length * head_dim, + head_dim, + 1}; // B0 layout [batch_size, head_num, sequence_length, head_dim] + + std::vector b1_gs_os_ns_lengths{batch_size, head_num, head_dim, sequence_length}; + std::vector b1_gs_os_ns_strides = + input_permute + ? std::vector{sequence_length * head_num * head_dim, + head_dim, + 1, + head_num * head_dim} + // B1 layout [batch_size, sequence_length, head_num, head_dim] + : std::vector{ + head_num * sequence_length * head_dim, + sequence_length * head_dim, + 1, + head_dim}; // B1 layout [batch_size, head_num, sequence_length, head_dim] + + std::vector c_gs_ms_os_lengths{batch_size, head_num, sequence_length, head_dim}; + std::vector c_gs_ms_os_strides = + output_permute ? 
std::vector{sequence_length * head_num * head_dim, + head_dim, + head_num * head_dim, + 1} + // C layout [batch_size, sequence_length, head_num, head_dim] + : std::vector{ + head_num * sequence_length * head_dim, + sequence_length * head_dim, + head_dim, + 1}; // C layout [batch_size, head_num, sequence_length, head_dim] + + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; + std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; + std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl; + std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 3: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + break; + case 4: // A, B0, B1 1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: // Rand: b1 b0; unit: a + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 6: // Rand: a b0 ; unit: B1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 7: // Rand: a b1 ; unit: b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 8: // Rand: a ; unit: b0 b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 9: // Rand: b0 ; unit: a b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 10: // Rand: b1 ; unit: a b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + } + + std::vector qkv_gs_ms_ks_lengths{ + batch_size, head_num, sequence_length, 3, head_dim}; + std::vector qkv_gs_ms_ks_strides = std::vector{ + sequence_length * head_num * 3 * head_dim, + 3 * head_dim, + head_num * 3 * head_dim, + 
head_dim, + 1}; // qkv layout [batch_size, sequence_length, head_num, 3, head_dim] + Tensor qkv_gs_ms_ks(qkv_gs_ms_ks_lengths, qkv_gs_ms_ks_strides); + // merge qkv into a packed pointer send to device + a_gs_ms_ks.ForEach( + [&](auto& self, auto idx) { qkv_gs_ms_ks(idx[0], idx[1], idx[2], 0, idx[3]) = self(idx); }); + b0_gs_ns_ks.ForEach( + [&](auto& self, auto idx) { qkv_gs_ms_ks(idx[0], idx[1], idx[2], 1, idx[3]) = self(idx); }); + b1_gs_os_ns.ForEach( + [&](auto& self, auto idx) { qkv_gs_ms_ks(idx[0], idx[1], idx[3], 2, idx[2]) = self(idx); }); + DeviceMem qkv_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize() + + sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize() + + sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * + c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()); + qkv_device_buf.ToDevice(qkv_gs_ms_ks.mData.data()); + + auto a_element_op = AElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto acc0_element_op = Acc0ElementOp{alpha}; + auto b1_element_op = B1ElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + float best_perf = .0; + float best_time = .0; + int not_pass = 0; + std::string best_kernel = ""; + printf("Verification: %s\n", do_verification ? "ON" : "OFF"); + // TODO ANT: replace array with vector? + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) -> void { + const auto device_mha_instance = std::get(DeviceMHAFactory{}); + + using DeviceMHAInstance = ck::remove_cvref_t; + auto gemm = DeviceMHAInstance{}; + auto invoker = gemm.MakeSelfAttnInvoker(); + auto argument = + gemm.MakeSelfAttnArgument(static_cast(qkv_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + batch_size, + sequence_length, + head_num, + head_dim, + alpha); + + // if(!gemm.IsSupportedArgument(argument)) + // { + // std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + // return 0; + // } + + ck::index_t BatchCount = batch_size * head_num; + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = (size_t(sequence_length) * sequence_length * head_dim * 2 + + size_t(sequence_length) * sequence_length * head_dim * 2) * + BatchCount; + std::size_t num_btype = (sizeof(ADataType) * sequence_length * head_dim + + sizeof(B0DataType) * head_dim * sequence_length + + sizeof(B1DataType) * sequence_length * head_dim + + sizeof(CDataType) * sequence_length * head_dim) * + BatchCount; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + if(tflops > best_perf) + { + best_perf = tflops; + best_time = ave_time * 1000; + best_kernel = gemm.GetTypeString(); + } + if(do_verification) + { + c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); + + Tensor a_g_m_k({BatchCount, sequence_length, head_dim}); + Tensor b0_g_k_n({BatchCount, head_dim, sequence_length}); + Tensor b1_g_n_o({BatchCount, sequence_length, head_dim}); + Tensor acc0_g_m_n( + {BatchCount, sequence_length, sequence_length}); // scratch object after gemm0 + Tensor a1_g_m_n( + {BatchCount, sequence_length, sequence_length}); // scratch object after softmax + Tensor c_g_m_o_host_result( + {BatchCount, sequence_length, head_dim}); // scratch object after gemm1 + + // permute + a_gs_ms_ks.ForEach([&](auto& self, auto idx) { + a_g_m_k(idx[0] 
* head_num + idx[1], idx[2], idx[3]) = self(idx); + }); + b0_gs_ns_ks.ForEach([&](auto& self, auto idx) { + b0_g_k_n(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx); + }); + b1_gs_os_ns.ForEach([&](auto& self, auto idx) { + b1_g_n_o(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx); + }); + + // gemm 0 + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument( + a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + // masking + const auto mask = typename DeviceMHAInstance::C0MatrixMask(sequence_length); + acc0_g_m_n.ForEach([&](auto& self, auto idx) { + if(mask.IsMaskedElement(idx[1], idx[2])) + self(idx) = -ck::NumericLimits::Infinity(); + }); + + // softmax + auto ref_softmax = ReferenceSoftmaxInstance{}; + auto ref_softmax_invoker = ref_softmax.MakeInvoker(); + auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2}); + + ref_softmax_invoker.Run(ref_softmax_argument); + + // gemm1 + auto ref_gemm1 = ReferenceGemm1Instance{}; + auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n, + b1_g_n_o, + c_g_m_o_host_result, + PassThrough{}, + b1_element_op, + c_element_op); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + + // permute + c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) { + const size_t& g0 = idx[0]; + const size_t& g1 = idx[1]; + + const size_t g = g0 * head_num + g1; + + self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]); + }); + + // default absolute error and relative error is 0.001 + double rtol = 1e-3; + double atol = 1e-3; + + // when BF16 is taken, set absolute error and relative error to 0.01 + if(std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) + { + rtol = 1e-2; + atol = 1e-2; + } + + bool this_run_verification = ck::utils::check_err(c_gs_ms_os_device_result.mData, + c_gs_ms_os_host_result.mData, + "Error: Incorrect results!", + rtol, + atol); + printf("Verification: %s, Pass: %s\n", + do_verification ? "ON" : "OFF", + this_run_verification ? "YES" : "NO"); + + if(!this_run_verification) + { + not_pass = 1; + printf("%d th MHA instance verification Failed \n", i.value); + } + } + }); + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Problem Size: BatchCount: " << batch_size << ", HeadNum: " << head_num + << ", sequence_length: " << sequence_length << ", head_dim: " << head_dim + << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time + << " us" << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + return not_pass; +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp new file mode 100644 index 0000000000..8e037272b8 --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp @@ -0,0 +1,332 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Gemm + Softmax + Gemm fused operation. 
Computes C_g_m_n = Softmax(A_g_m_k * B0_g_k_l) * B1_g_l_n + |-----------------| + Gemm0 + |-------------------------------------| + Gemm1 +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using B0DataType = F16; +using B1DataType = F16; +using Acc0DataType = F32; +using Acc1DataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; +using Acc0BiasDataType = ck::Tuple<>; +using Acc1BiasDataType = ck::Tuple<>; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 1; +static constexpr ck::index_t NumDimN = 1; +static constexpr ck::index_t NumDimK = 1; +static constexpr ck::index_t NumDimO = 1; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; +static constexpr auto MaskingSpec = + ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; + +static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default; + +// clang-format off +#define CK_MHA_USE_WAVE_1 +#define CK_MHA_USE_WAVE_2 +#define CK_MHA_USE_WAVE_4 +#define CK_MHA_USE_WAVE_8 +using DeviceMHAFactory = + std::tuple< +#ifdef CK_MHA_USE_WAVE_1 + // 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 32, + // Gemm 0 + 16, 32, 160, 8, 8, + // Gemm 1 + 80, 32, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 2, 5, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, + 
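+        // The instances below differ mainly in block size, Gemm0/Gemm1 tile
+        // extents and per-wave repeat counts; the CK_MHA_USE_WAVE_* macros above
+        // gate whole groups of them, so commenting one out shrinks the set that
+        // run_self_attention_wmma.inc times when it reports the best kernel.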
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 32, + // Gemm 0 + 16, 64, 80, 8, 8, + // Gemm 1 + 80, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 5, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 32, + // Gemm 0 + 16, 64, 48, 8, 8, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_2 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 64, + // Gemm 0 + 32, 64, 48, 8, 8, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 64, + // Gemm 0 + 32, 64, 80, 8, 8, + // Gemm 1 + 80, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 5, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, + 
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 64, + // Gemm 0 + 32, 32, 160, 8, 8, + // Gemm 1 + 80, 32, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 2, 5, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_4 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 128, + // Gemm 0 + 64, 128, 80, 8, 8, + // Gemm 1 + 80, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 5, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 128, + // Gemm 0 + 64, 192, 48, 8, 8, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 12, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 128, + // Gemm 0 + 64, 64, 48, 8, 8, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef 
CK_MHA_USE_WAVE_8 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 256, + // Gemm 0 + 128, 192, 48, 8,4, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 12, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 128, 1, 2>, 8, + MaskingSpec> +#endif + >; +// clang-format on +// Ref Gemm0: fp16 in, fp32 out +using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +// Ref Softmax: fp32 in, fp16 out +using ReferenceSoftmaxInstance = + ck::tensor_operation::host::ReferenceSoftmax; + +// Ref Gemm1: fp16 in, fp16 out +using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +#include "run_self_attention_wmma.inc" + +int main(int argc, char* argv[]) { return run(argc, argv); } diff --git a/example/64_fpAintB_gemm/CMakeLists.txt b/example/64_fpAintB_gemm/CMakeLists.txt new file mode 100644 index 0000000000..89cc2d7f62 --- /dev/null +++ b/example/64_fpAintB_gemm/CMakeLists.txt @@ -0,0 +1,5 @@ +if(GPU_TARGETS MATCHES "gfx11") + add_custom_target(example_fpAintB_gemm_wmma) + add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp) + add_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma) +endif() diff --git a/example/64_fpAintB_gemm/common.hpp b/example/64_fpAintB_gemm/common.hpp new file mode 100644 index 0000000000..4fb4c41d05 --- /dev/null +++ b/example/64_fpAintB_gemm/common.hpp @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
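+//
+// Shared pieces for the fpAintB examples: default ProblemSize/ExecutionConfig,
+// command-line parsing, and UnsignedWeightPreprocessor, which re-biases signed
+// int8 weights into uint8 by adding 128 (e.g. an int8 value of -3 maps to the
+// uint8 value 125) before they are handed to the device kernel.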
+ +#pragma once + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_fpAintB_gemm.hpp" + +struct ProblemSize final +{ + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; +}; + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +template +struct UnsignedWeightPreprocessor +{ +}; + +template <> +struct UnsignedWeightPreprocessor +{ + using UnsignedWeight = Tensor; + using SignedWeight = Tensor; + static UnsignedWeight convert(SignedWeight const& Input) + { + + UnsignedWeight Output = Input.template CopyAsType(); + + auto f_kn = [&](auto k, auto n) { + const uint8_t adder = 128; + int8_t v_signed_weight; + uint8_t v_unsigned_weight; + + ck::tensor_operation::element_wise::PassThrough{}(v_signed_weight, Input(k, n)); + v_unsigned_weight = ck::type_convert(v_signed_weight) + adder; + Output(k, n) = v_unsigned_weight; + }; + + make_ParallelTensorFunctor(f_kn, Input.mDesc.GetLengths()[0], Input.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + + return Output; + } + + UnsignedWeight operator()(SignedWeight const& Input) { return convert(Input); } +}; + +inline bool +parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config) +{ + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); + + problem_size.StrideA = std::stoi(argv[7]); + problem_size.StrideB = std::stoi(argv[8]); + problem_size.StrideC = std::stoi(argv[9]); + } + else + { + std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" + << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl; + return false; + } + + return true; +} diff --git a/example/64_fpAintB_gemm/fp16int8_gemm_wmma.cpp b/example/64_fpAintB_gemm/fp16int8_gemm_wmma.cpp new file mode 100644 index 0000000000..9dc97fecd8 --- /dev/null +++ b/example/64_fpAintB_gemm/fp16int8_gemm_wmma.cpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
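+//
+// The example multiplies fp16 activations by int8 weights that are dequantized
+// with a per-N fp16 scale. A typical invocation of the built target (the name
+// comes from the CMakeLists above; adjust the path to your build tree):
+//   ./example_fp16int8_gemm_wmma 1 2 1
+//   ./example_fp16int8_gemm_wmma 1 2 1 3840 4096 4096 4096 4096 4096
+// i.e. verification on, decimal init, kernel timing on, optionally followed by
+// M N K StrideA StrideB StrideC.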
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp" + +// Implementation follows the paper: +// Kim, Young Jin, Rawn Henry, Raffy Fahim, and Hany Hassan Awadalla. “Who Says Elephants Can’t Run: +// Bringing Large Scale MoE Models into Cloud Scale Production.” arXiv, November 17, 2022. +// https://doi.org/10.48550/arXiv.2211.10017. Assume weight (Matrix B) is add preprocess to +// unsigned. + +// The DeviceOp is CDataType = ADataType * Dequant(BDataType) * ScaleDataType +// The HostRef is CDataType = ADataType * Dequant(QuantDataType) * ScaleDataType + +// TODO: Current implementation consume more VGPR than expected. + +using ADataType = ck::half_t; +using QuantDataType = int8_t; +using BDataType = uint8_t; +using ScaleDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = float; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceFpAintBGemm_Wmma_CShuffle + < ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + ScaleDataType, + CDataType, + AccDataType, + CShuffleDataType, + AElementOp, + BElementOp, + CElementOp, + GemmDefault, + 1, // Prefetch stage + 128, // BlockSize + 64, // MPerBlock + 128, // NPerBlock + 64, // KPerBlock + 8, // K1 + 16, // MPerWmma + 16, // NPerWmma + 2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave + 4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + 1, // C shuffle (M Repeat) Per store + 1, // C shuffle (N Repeat) Per store + S<1, 32, 1, 4>, + 8>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferencefpAintBGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/64_fpAintB_gemm/run_gemm_example.inc b/example/64_fpAintB_gemm/run_gemm_example.inc new file mode 100644 index 0000000000..dc2bdc18f0 --- /dev/null +++ b/example/64_fpAintB_gemm/run_gemm_example.inc @@ -0,0 +1,172 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
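+//
+// Rough flow of run_gemm below: fill the A / quantized-B / scale host tensors,
+// shift the int8 weights to uint8 with UnsignedWeightPreprocessor, upload the
+// buffers, launch DeviceGemmInstance, and optionally check the result against
+// ReferenceGemmInstance, which consumes the original signed weights plus the
+// scales on the host side.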
+ +#pragma once + +bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ +#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) + static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); +#endif + + using namespace ck::literals; + + auto& [M, N, K, StrideA, StrideB, StrideC] = problem_size; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor quant_b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + // assume scale tensor is [1, n] + Tensor scale_k_n(f_host_tensor_descriptor(K, N, 0, Row{})); + + switch(config.init_method) + { + case 0: break; + case 1: + ck::utils::FillUniformDistributionIntegerValue{-1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{-1.f, 1.f}(quant_b_k_n); + ck::utils::FillUniformDistributionIntegerValue{-1.f, 1.f}(scale_k_n); + break; + case 2: + ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistribution{-1.f, 1.f}(quant_b_k_n); + ck::utils::FillUniformDistribution{-1.f, 1.f}(scale_k_n); + break; + default: + ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistribution{-1.f, 1.f}(quant_b_k_n); + ck::utils::FillUniformDistribution{-1.f, 1.f}(scale_k_n); + } + + UnsignedWeightPreprocessor preprocessor; + Tensor b_k_n = preprocessor(quant_b_k_n); + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "scale_k_n: " << scale_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + +#ifdef BUILD_INT4_EXAMPLE + DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) * + c_m_n_device_result.mDesc.GetElementSpaceSize()); + + const Tensor a_m_k_converted(a_m_k); + const Tensor b_k_n_converted(b_k_n); + + a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data()); +#else + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem scale_k_n_device_buf(sizeof(ScaleDataType) * scale_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + scale_k_n_device_buf.ToDevice(scale_k_n.mData.data()); +#endif + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument( +#ifdef BUILD_INT4_EXAMPLE + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#else + 
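+        // non-INT4 path: A, the re-biased uint8 B, the per-N scale, then C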
static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(scale_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#endif + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + if(config.do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_m_k, + quant_b_k_n, + scale_k_n, + c_m_n_host_result, + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + +#ifdef BUILD_INT4_EXAMPLE + Tensor c_m_n_device_result_converted(c_m_n_host_result.mDesc); + + c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data()); + + c_m_n_device_result = c_m_n_device_result_converted.CopyAsType(); + + return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result); +#else + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); +#endif + } + + return true; +} + +bool run_gemm_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config); +} diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp index b3d45f3d0c..f8ee283c67 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp @@ -7,6 +7,7 @@ #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/warp/wmma_gemm.hpp" #include "ck/tensor_description/tensor_adaptor.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #define CK_MNK_LOOP @@ -16,25 +17,45 @@ template -/* A: K0PerBlock x MPerBlock x K1 + index_t KPack, + bool AEnableLds = true, + bool BEnableLds = true, + bool TransposeC = false> +/* Option: Read from LDS, big buffer hold all threads required data + * Source + * A: K0PerBlock x MPerBlock x K1 * B: K0PerBlock x NPerBlock x K1 - * C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs + * Destination + * C, non-transpose + * thread level: MRepeat x NRepeat x MAccVgprs + * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs * KPACK == WMMA_K = 16 + * + * Option: Read from VMEM, small buffer hold each thread own required data (Skip LDS) + * Source: + * A(if skip LDS): MRepeat x KPack + * B(if skip LDS): NRepeat x KPack + * Destination + * C, non-transpose + * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs */ -struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle 
+struct BlockwiseGemmWMMA { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; static constexpr auto WmmaK = Number<16>{}; using ThisThreadBlock = ThisThreadBlock; @@ -42,18 +63,16 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle // Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one. static constexpr index_t WaveSize = 32; - static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); - static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); - static constexpr index_t KPerBlock = - BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2); - - static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0); - static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0); - static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); - static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); + // When use LDS, each Row(16 consecutive lanes) read whole data from source buffer + // When not use LDS, each Row read half of whole data from source buffer, exchange the data via + // permutation + static constexpr index_t A_KRow = AEnableLds ? 1 : 2; + static constexpr index_t B_KRow = BEnableLds ? 1 : 2; + static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5); + static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5); static constexpr auto wmma_gemm = - WmmaGemm{}; + WmmaGemm{}; static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); @@ -79,371 +98,39 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); } + // Default, Block buffer in LDS, thread level offset enabled __device__ static auto CalculateAThreadOriginDataIndex() { - const auto wave_idx = GetWaveIdx(); - - const auto waveId_m = wave_idx[I0]; - - const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); - // |KRepeat |MRepeat|MWave |MLane |KPack - return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0); - } - - __device__ static auto CalculateBThreadOriginDataIndex() - { - const auto wave_idx = GetWaveIdx(); - - const auto waveId_n = wave_idx[I1]; - - const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); - // |KRepeat |NRepeat|Nwave |NLane |KPack - return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0); - } - - template - __device__ static auto CalculateCThreadOriginDataIndex(Number, Number) - { - const auto wave_idx = GetWaveIdx(); - - const auto waveId_m = wave_idx[I0]; - const auto waveId_n = wave_idx[I1]; - - const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk(); - - constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1, 2>{})); - - constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1, 2>{})); - - const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex( - make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; - const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex( - make_tuple(n0, 
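// ---------------------------------------------------------------------------
// Editor's note, not part of this patch: a host-side sketch of how the 6D A-thread
// origin above differs between the two paths, assuming the ordering
// |KRepeat|MRepeat|MWave|KRow|MLane|KPack| used by CalculateAThreadOriginDataIndex().
// wave_id_m and lane_idx are stand-ins for GetWaveIdx()[I0] and
// wmma_gemm.CalculateAThreadOriginDataIndex(); they are not CK symbols.
#include <array>

std::array<int, 6> a_thread_origin_sketch(bool a_enable_lds, int wave_id_m, int lane_idx)
{
    if(a_enable_lds)
    {
        // Block tile sits in LDS: every thread offsets its read by wave and lane.
        return {0, 0, wave_id_m, 0, lane_idx, 0};
    }
    // Skip-LDS path: the thread already owns its slice in registers, so the origin is
    // zero and the inter-row data exchange is handled by the permuting thread copy.
    return {0, 0, 0, 0, 0, 0};
}
// ---------------------------------------------------------------------------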
waveId_n, blk_idx[I1]))[I0]; - - return make_tuple(c_thread_m, c_thread_n); - } - - __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle() - { - static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && - BK0NK1BlockDesc::IsKnownAtCompileTime(), - "wrong! Desc should be known at compile-time"); - - static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, - "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); - - static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && - NPerBlock % (NPerWMMA * NRepeat) == 0, - "wrong!"); - } - - // Thread level, register decriptor. Vector-write - __host__ __device__ static constexpr auto - GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() - { - constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = - wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); - - constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0]; - constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1]; - constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; - - return make_naive_tensor_descriptor_packed( - // |MRepeat |MWave |MSubGroup |NRepeat |NWave - // |NThreadPerSubGroup |MAccVgprs - make_tuple(Number{}, - I1, - MSubGroup, - Number{}, - I1, - NThreadPerSubGroup, - MAccVgprs)); - } - - // Provide dimension size - __host__ __device__ static constexpr auto - GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() - { - constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - Number{})); - - return wmma_gemm - .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( - c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); - } - - __host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1() - { - return transform_tensor_descriptor( - AK0MK1BlockDesc{}, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); - } - - __host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1() - { - return transform_tensor_descriptor( - BK0NK1BlockDesc{}, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); - } - - // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma - static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1(); - static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1(); - - template - __device__ void Run(const ABlockBuffer& a_block_buf, - const BBlockBuffer& b_block_buf, - CThreadBuffer& c_thread_buf) const - { - auto a_thread_buf = make_static_buffer( - a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( - b_thread_desc_.GetElementSpaceSize()); - - // basic intrinsic to determine loopover direction - if constexpr(MRepeat < NRepeat) + if constexpr(AEnableLds) { - 
static_for<0, KPerBlock / WmmaK, 1>{}( - [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... - static_for<0, MRepeat, 1>{}([&](auto m0) { - // read A - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, m0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, m0, I0, I0, I0), - a_thread_buf); + const auto wave_idx = GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read B - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, n0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, n0, I0, I0, I0), - b_thread_buf); - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, WmmaK, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(i) = - b_thread_buf[Number{}]; - }); - - using wmma_input_type_a = typename vector_type::type; - using wmma_input_type_b = typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); + // |KRepeat |MRepeat|MWave |KRow |MLane |KPack + return make_tuple(0, 0, waveId_m, 0, WMMA_a_idx, 0); } else { - static_for<0, KPerBlock / WmmaK, 1>{}( - [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read B - b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, n0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, n0, I0, I0, I0), - b_thread_buf); - static_for<0, MRepeat, 1>{}([&](auto m0) { - // read A - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, m0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, m0, I0, I0, I0), - a_thread_buf); - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, WmmaK, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(i) = - b_thread_buf[Number{}]; - }); - - using wmma_input_type_a = typename vector_type::type; - using wmma_input_type_b = typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); + return make_tuple(0, 0, 0, 0, 0, 0); } } - protected: - // A[K0, M0, M1, M2, K1] - static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, I1, I1, Number{})); - - // B[K0, N0, N1, N2, K1] - static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, I1, I1, Number{})); - - // C[M, N, NumRegWMMA] - static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, wmma_gemm.GetRegSizePerWmma())); - - using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3, 4>, - 4, - A_K1, - A_K1>; - - using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3, 4>, - 4, - B_K1, - B_K1>; - - AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; - BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; -}; - -// block wise level 
pipe designed for inline asm -template -/* A: K0PerBlock x MPerBlock x K1 - * B: K0PerBlock x NPerBlock x K1 - * C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs - * KPACK == WMMA_K = 16 - */ -struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO -{ - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto WmmaK = Number<16>{}; - - using ThisThreadBlock = ThisThreadBlock; - - // Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one. - static constexpr index_t WaveSize = 32; - - static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); - static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); - static constexpr index_t KPerBlock = - BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2); - - static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0); - static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0); - static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); - static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); - - static constexpr auto wmma_gemm = - WmmaGemm{}; - - static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); - static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); - - StaticBufferTupleOfVector - c_thread_buf_; - - __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } - - __device__ static auto GetWaveIdx() - { - const index_t thread_id = ThisThreadBlock::GetThreadId(); - - constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); - } - - __device__ static auto CalculateAThreadOriginDataIndex() - { - const auto wave_idx = GetWaveIdx(); - - const auto waveId_m = wave_idx[I0]; - - const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); - // |KRepeat |MRepeat|MWave |MLane |KPack - return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0); - } - __device__ static auto CalculateBThreadOriginDataIndex() { - const auto wave_idx = GetWaveIdx(); + if constexpr(BEnableLds) + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_n = wave_idx[I1]; + const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); - const auto waveId_n = wave_idx[I1]; - - const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); - // |KRepeat |NRepeat|Nwave |NLane |KPack - return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0); + // |KRepeat |NRepeat|Nwave |KRow |NLane |KPack + return make_tuple(0, 0, waveId_n, 0, WMMA_b_idx, 0); + } + else + { + return make_tuple(0, 0, 0, 0, 0, 0); + } } template @@ -474,10 +161,26 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO return make_tuple(c_thread_m, c_thread_n); } - __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO() + template + __device__ static auto CalculateCThreadOriginDataIndex7D(Number, Number) { - static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && - BK0NK1BlockDesc::IsKnownAtCompileTime(), + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = 
wmma_gemm.GetBeginOfThreadBlk3D(); + + return make_tuple( + Number{}, waveId_m, blk_idx[I0], Number{}, waveId_n, blk_idx[I1], blk_idx[I2]); + } + + using Tuple6 = decltype(CalculateAThreadOriginDataIndex()); + __host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(), + Tuple6 b_origin = CalculateBThreadOriginDataIndex()) + : a_thread_copy_(a_origin), b_thread_copy_(b_origin) + { + static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time"); static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, @@ -487,6 +190,22 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO NPerBlock % (NPerWMMA * NRepeat) == 0, "wrong!"); } + + // transposed WMMA output C' = B' * A' + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs() + { + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + + constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + + return make_naive_tensor_descriptor_packed( + // |MRepeat |MWave |MSubGroup |NRepeat |NWave + // |NThreadPerSubGroup |MAccVgprs + make_tuple(Number{}, I1, I1, Number{}, I1, I1, NAccVgprs)); + } + // Thread level, register decriptor. Vector-write __host__ __device__ static constexpr auto GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() @@ -494,20 +213,19 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); - constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0]; - constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1]; - constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; - - return make_naive_tensor_descriptor_packed( + constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3]; + return make_naive_tensor_descriptor( // |MRepeat |MWave |MSubGroup |NRepeat |NWave // |NThreadPerSubGroup |MAccVgprs - make_tuple(Number{}, - I1, - MSubGroup, - Number{}, - I1, - NThreadPerSubGroup, - MAccVgprs)); + make_tuple(Number{}, I1, I1, Number{}, I1, I1, MAccVgprs), + make_tuple(Number{} * MAccVgprs * AccStride, + Number{} * MAccVgprs * AccStride, + Number{} * MAccVgprs * AccStride, + MAccVgprs * AccStride, + MAccVgprs * AccStride, + MAccVgprs * AccStride, + AccStride)); } template @@ -532,6 +250,23 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); } + // transposed WMMA output C' = B' * A' + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs() + { + constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs( + c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); + } + // Provide dimension size __host__ __device__ static 
constexpr auto GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() @@ -549,33 +284,10 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); } - __host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1() - { - return transform_tensor_descriptor( - AK0MK1BlockDesc{}, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); - } - - __host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1() - { - return transform_tensor_descriptor( - BK0NK1BlockDesc{}, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); - } - + // Describe how data allocated in thread copy src buffer // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma - static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1(); - static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1(); + static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1; + static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1; template __device__ void Run(const ABlockBuffer& a_block_buf, @@ -587,268 +299,235 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); - constexpr auto RepeatDiff = MRepeat - NRepeat; - // Read all Mrepeat, Nrepeat - static_for<0, NRepeat, 1>{}([&](auto iN) { - b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, - make_tuple(I0, Number{}, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - b_thread_buf); - }); + // basic intrinsic to determine loopover direction + if constexpr(MRepeat < NRepeat) + { + static_for<0, KPerBlock / KPack, 1>{}( + [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... 
+ static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0, I0), + a_thread_buf); - static_for<0, MRepeat, 1>{}([&](auto iM) { - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(I0, Number{}, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - a_thread_buf); - }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + b_thread_buf); - // Stage 1: Cut to Repeat Retangle to Square, assume MRepeat > NRepeat - static_for<0, RepeatDiff, 1>{}([&](auto iCut) { - static_for<0, NRepeat, 1>{}([&](auto iN) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, WmmaK, 1>{}([&](auto iK) { - a_thread_vec.template AsType()(iK) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(iK) = - b_thread_buf[Number{}]; - }); - using wmma_input_type_a = typename vector_type::type; - using wmma_input_type_b = typename vector_type::type; + static_for<0, KPack, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0)); - // s_nop(); - wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); - // s_nop(); - }); - if constexpr(KPerBlock > WmmaK) - { - // Read Consumed Next inner loop A - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, Number{}, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - a_thread_buf); - } - }); + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; - static_for{}([&](auto iWmmaK) { - // Stage 2: Run FIFO fashion loopover in Square - static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop) { - // Row Repeatation - static_for{}([&](auto iN) { - vector_type a_thread_vec; - vector_type b_thread_vec; + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - static_for<0, WmmaK, 1>{}([&](auto iK) { - a_thread_vec.template AsType()(iK) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(iK) = - b_thread_buf[Number{}]; + wmma_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); }); - using wmma_input_type_a = typename vector_type::type; - using wmma_input_type_b = typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(WmmaInnerloop + RepeatDiff, iN, 0)); - // s_nop(); - wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); - // s_nop(); }); + } + else + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of + // k=0,kpack*1, .. 
+ // read B + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + b_thread_buf); + // read A + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0, I0), + a_thread_buf); - // Read Consumed Next inner loop A - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple( - Number{}, Number{}, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - a_thread_buf); + vector_type a_thread_vec; + vector_type b_thread_vec; - // Col Repeatation - static_for{}([&](auto iM) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KPack, 1>{}([&](auto i) { + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + }); - static_for<0, WmmaK, 1>{}([&](auto iK) { - a_thread_vec.template AsType()(iK) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(iK) = - b_thread_buf[Number{}]; + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); - using wmma_input_type_a = typename vector_type::type; - using wmma_input_type_b = typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0)); - // s_nop(); - wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); - // s_nop(); }); - // Read Consumed Next inner loop B - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, Number{}, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - b_thread_buf); }); - - // Stage 1: Cut to Repeat Retangle to Square, assume MRepeat > NRepeat - static_for<0, RepeatDiff, 1>{}([&](auto iCut) { - static_for<0, NRepeat, 1>{}([&](auto iN) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, WmmaK, 1>{}([&](auto iK) { - a_thread_vec.template AsType()(iK) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(iK) = - b_thread_buf[Number{}]; - }); - using wmma_input_type_a = typename vector_type::type; - using wmma_input_type_b = typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0)); - // s_nop(); - wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); - // s_nop(); - }); - if constexpr(KPerBlock > WmmaK) - { - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number<(iWmmaK + WmmaK) / A_K1>{}, Number{}, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - a_thread_buf); - } - }); - }); - - // Stage 2: Run FIFO fashion loopover in Square - static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop) { - // Row Repeatation - static_for{}([&](auto iN) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, WmmaK, 1>{}([&](auto iK) { - a_thread_vec.template AsType()(iK) = - a_thread_buf[Number{}]; - 
b_thread_vec.template AsType()(iK) = - b_thread_buf[Number{}]; - }); - using wmma_input_type_a = typename vector_type::type; - using wmma_input_type_b = typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(WmmaInnerloop + RepeatDiff, iN, 0)); - // s_nop(); - wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); - // s_nop(); - }); - - // Col Repeatation - static_for{}([&](auto iM) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, WmmaK, 1>{}([&](auto iK) { - a_thread_vec.template AsType()(iK) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(iK) = - b_thread_buf[Number{}]; - }); - using wmma_input_type_a = typename vector_type::type; - using wmma_input_type_b = typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0)); - // s_nop(); - wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); - // s_nop(); - }); - }); + } } protected: - // A[M0, M1, M2, K0 = WmmaK] - static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, I1, I1, Number{})); + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{})); - // B[N0, N1, N2, K0 = WmmaK] - static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, I1, I1, Number{})); + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{})); // C[M, N, NumRegWMMA] static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{}, wmma_gemm.GetRegSizePerWmma())); - using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3, 4>, - 4, - A_K1, - A_K1>; + template + struct AThreadCopySelector; - using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3, 4>, - 4, - B_K1, - B_K1>; + template <> + struct AThreadCopySelector + { + using type = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + }; - AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; - BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; + template <> + struct AThreadCopySelector + { + using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow< + FloatA, + FloatA, + decltype(a_block_desc_k0_m0_m1_m2_k1), + decltype(a_thread_desc_), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + 0x76543210, + 0xfedcba98, + TransposeC ? 
false : true>; + }; + + template + struct BThreadCopySelector; + + template <> + struct BThreadCopySelector + { + using type = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + B_K1, + B_K1>; + }; + + template <> + struct BThreadCopySelector + { + using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow< + FloatB, + FloatB, + decltype(b_block_desc_k0_n0_n1_n2_k1), + decltype(b_thread_desc_), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + B_K1, + 0x76543210, + 0xfedcba98, + TransposeC ? true : false>; + }; + + typename AThreadCopySelector::type a_thread_copy_; + typename BThreadCopySelector::type b_thread_copy_; }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp new file mode 100644 index 0000000000..ab826bb041 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp @@ -0,0 +1,223 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp" + +namespace ck { + +/** + * @brief Blockwise data transfer with dequantization + * + * RunRead would load low-precision data and scale data. + * RunWrite would process dequantization process. + * Assume Scale is identical along K-dimension + * + * This version does following things to avoid scratch memory issue + * 1. Use StaticallyIndexedArray instead of C array for thread buffer + * 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor + * 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate + * + */ +template +struct ThreadGroupTensorSliceTransfer_v4r1_dequant +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{}; + static constexpr auto scale_thread_slice_lengths = + BlockScaleSliceLengths{} / ThreadClusterLengths{}; + + using Index = MultiIndex; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v4r1_dequant( + const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const SrcElementwiseOperation& src_element_op, + const ScaleDesc& scale_desc, + const Index& scale_block_slice_origin, + const ScaleElementwiseOperation& scale_element_op, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const DstElementwiseOperation& dst_element_op) + : threadwise_transfer_(src_desc, + make_zero_multi_index(), + src_element_op, + scale_desc, + make_zero_multi_index(), + scale_element_op, + dst_desc, + make_zero_multi_index(), + dst_element_op) + + { + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{} && + is_same{}, + "wrong! 
threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrcSliceOrigin(src_desc, + src_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetScaleSliceOrigin( + scale_desc, scale_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id); + } + } + + // With the assumption, scale scratch is always one + template + __device__ void RunScaleRead(const ScaleDesc& scale_desc, const ScaleBuffer& scale_buf) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunScaleRead(scale_desc, scale_buf); + } + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, + DstBuffer& dst_buf, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id); + } + } + + // We don't prefer use this API directly + /* + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf, + Number thread_scratch_id) + { + RunRead(src_desc, src_buf, thread_scratch_id); + RunWrite(dst_desc, dst_buf, thread_scratch_id); + } + */ + + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + // With the assumption, scale buffer don't need move slice window method + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v3r1_dequant; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_dequantB.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_dequantB.hpp new file mode 100644 index 0000000000..acb18efabf --- /dev/null +++ 
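// ---------------------------------------------------------------------------
// Editor's note, not part of this patch: a minimal sketch of the read/write split
// implemented by ThreadGroupTensorSliceTransfer_v4r1_dequant above. RunRead() only
// stages the low-precision values and RunScaleRead() the per-column scales; the
// multiply happens once, in RunWrite(), before the result is written to the
// destination buffer. Buffer types and tile sizes are illustrative placeholders.
#include <array>
#include <cstdint>

template <int KPerThread, int NPerThread>
struct DequantThreadTransferSketch
{
    std::array<int8_t, KPerThread * NPerThread> staged_quant{}; // filled by RunRead
    std::array<float, NPerThread> staged_scale{};               // filled by RunScaleRead

    // RunWrite: dequantize while writing into the destination window.
    void run_write(float* dst) const
    {
        for(int k = 0; k < KPerThread; ++k)
            for(int n = 0; n < NPerThread; ++n)
                dst[k * NPerThread + n] =
                    static_cast<float>(staged_quant[k * NPerThread + n]) * staged_scale[n];
    }
};
// ---------------------------------------------------------------------------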
b/include/ck/tensor_operation/gpu/device/device_gemm_dequantB.hpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Dequantization of input tensor could not be decoupled from gridwisegemm pipeline +// As input tensor thread buffer declared inside blockwise-gemm pipeline. + +template +struct DeviceGemm_dequantB : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_scale, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp index b32f3a8daa..d35645c068 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp @@ -62,10 +62,10 @@ template struct DeviceBatchedContractionMultipleD_Wmma_CShuffle @@ -123,15 +123,32 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; // K1 = Max Vector Access Pixels static constexpr auto K1Number = Number{}; + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 32 : 16; + + static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true; + static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true; + + // If true, LDS is used unconditionally + static constexpr auto AEnableLds_manu = false; + static constexpr auto BEnableLds_manu = false; + + static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); + static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1); + static constexpr auto matrix_padder = - MatrixPadder{MPerBlock, NPerBlock, K0PerBlock* K1}; + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; // Assume: A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] - static auto MakeAGridDescriptor_M_K(const std::vector& a_gs_ms_ks_lengths_vec, - const std::vector& a_gs_ms_ks_strides_vec) + static auto MakeAGridDescriptor(const std::vector& a_gs_ms_ks_lengths_vec, + const std::vector& a_gs_ms_ks_strides_vec) { assert(a_gs_ms_ks_lengths_vec.size() == NumDimG + NumDimM + NumDimK && a_gs_ms_ks_strides_vec.size() == NumDimG + NumDimM + NumDimK); @@ -158,36 +175,72 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle // lengths for K0, K1, ... 
const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds); - if constexpr(ASpec == TensorSpecialization::Packed) + const auto a_grid_desc_m_k = [&]() { + if constexpr(ASpec == TensorSpecialization::Packed) + { + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); + const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor( + make_tuple(M, K), + make_tuple(a_ms_ks_strides[Number{}], + a_ms_ks_strides[Number{}])); + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + else + { + // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] + const auto a_grid_desc_ms_ks = + make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides); + + // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...] + const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( + a_grid_desc_ms_ks, + make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)), + make_tuple(mDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + }(); + + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + assert(K % K1 == 0); + + if constexpr(AEnableLds) { - auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); - auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); - const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor( - make_tuple(M, K), - make_tuple(a_ms_ks_strides[Number{}], - a_ms_ks_strides[Number{}])); - return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); } else { - // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] - const auto a_grid_desc_ms_ks = - make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides); + constexpr auto A_KRow = 2; + constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number; + const auto A_KWmma = K / WmmaK; - // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...] - const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( - a_grid_desc_ms_ks, - make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)), - make_tuple(mDimIds, kDimIds), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + const auto M0 = M / MPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple( + A_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(M0 * MRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); } } // Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] 
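// ---------------------------------------------------------------------------
// Editor's note, not part of this patch: a sketch of how the K dimension of A is
// factored for the two descriptor layouts built above (the B descriptor below follows
// the same pattern), assuming WmmaK = 16, a KRow of 2 on the skip-LDS path, and K
// divisible by K1 as asserted in the source.
#include <cassert>

struct AKFactorsSketch
{
    int k0;          // LDS path:      A laid out as [K0, M, K1]
    int k_wmma;      // skip-LDS path: A laid out as [KWmma, ..., K0PerWmma, KRow, ..., K1]
    int k0_per_wmma;
};

AKFactorsSketch factor_k_sketch(int K, int K1, bool enable_lds, int WmmaK = 16, int KRow = 2)
{
    assert(K % K1 == 0);
    if(enable_lds)
        return {K / K1, 0, 0};
    assert(K % WmmaK == 0 && WmmaK % (KRow * K1) == 0);
    return {0, K / WmmaK, WmmaK / (KRow * K1)};
}
// ---------------------------------------------------------------------------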
- static auto MakeBGridDescriptor_N_K(const std::vector& b_gs_ns_ks_lengths_vec, - const std::vector& b_gs_ns_ks_strides_vec) + static auto MakeBGridDescriptor(const std::vector& b_gs_ns_ks_lengths_vec, + const std::vector& b_gs_ns_ks_strides_vec) { assert(b_gs_ns_ks_lengths_vec.size() == NumDimG + NumDimN + NumDimK && b_gs_ns_ks_strides_vec.size() == NumDimG + NumDimN + NumDimK); @@ -214,30 +267,66 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle // lengths for N0, N1, ... const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds); - if constexpr(BSpec == TensorSpecialization::Packed) + const auto b_grid_desc_n_k = [&]() { + if constexpr(BSpec == TensorSpecialization::Packed) + { + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); + const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor( + make_tuple(N, K), + make_tuple(b_ns_ks_strides[Number{}], + b_ns_ks_strides[Number{}])); + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + else + { + // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...] + const auto b_grid_desc_ns_ks = + make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides); + + // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...] + const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor( + b_grid_desc_ns_ks, + make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)), + make_tuple(nDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + }(); + + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + assert(K % K1 == 0); + + if constexpr(BEnableLds) { - auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); - auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); - const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor( - make_tuple(N, K), - make_tuple(b_ns_ks_strides[Number{}], - b_ns_ks_strides[Number{}])); - return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); } else { - // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...] - const auto b_grid_desc_ns_ks = - make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides); + constexpr auto B_KRow = 2; + constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number; + const auto B_KWmma = K / WmmaK; - // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...] 
- const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor( - b_grid_desc_ns_ks, - make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)), - make_tuple(nDimIds, kDimIds), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + const auto N0 = N / NPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple( + B_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(N0 * NRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); } } @@ -393,8 +482,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle } // Gridwise descriptor, mapping to whole given provblem. - using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K({}, {})); - using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K({}, {})); using DsGridDesc_M_N = remove_cvref_t; using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); @@ -449,45 +536,11 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle EGridDesc_G_M_N e_grid_desc_g_m_n_; }; - // A desc for source in blockwise copy - template - __host__ __device__ static constexpr auto - MakeAGridDescriptor_K0_M_K1(const AGridDesc_M_K& a_grid_desc_m_k) - { - const auto M = a_grid_desc_m_k.GetLength(I0); - const auto K = a_grid_desc_m_k.GetLength(I1); - - const auto AK0 = K / K1; - - return transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0, K1)), make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - - // B desc for source in blockwise copy - template - __host__ __device__ static constexpr auto - MakeBGridDescriptor_K0_N_K1(const BGridDesc_N_K& b_grid_desc_n_k) - { - const auto N = b_grid_desc_n_k.GetLength(I0); - const auto K = b_grid_desc_n_k.GetLength(I1); - - const auto BK0 = K / K1; - - return transform_tensor_descriptor( - b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0, K1)), make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - - using AGridDesc_K0_M_K1 = decltype(DeviceOp::MakeAGridDescriptor_K0_M_K1(AGridDesc_M_K{})); - using BGridDesc_K0_N_K1 = decltype(DeviceOp::MakeBGridDescriptor_K0_N_K1(BGridDesc_N_K{})); + using AGridDesc = decltype(DeviceOp::MakeAGridDescriptor({}, {})); + using BGridDesc = decltype(DeviceOp::MakeBGridDescriptor({}, {})); // GridwiseOp - using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle< + using GridwiseOp = GridwiseGemmMultipleD_Wmma< // DataType Family ADataType, BDataType, @@ -496,8 +549,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle DsDataType, EDataType, // InMemory Data Descriptor - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, + AGridDesc, + BGridDesc, DsGridDesc_M_N, EGridDesc_M_N, // ElementwiseOp Family @@ -508,9 +561,9 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle // Tiling Family MPerBlock, NPerBlock, - K0PerBlock, - MPerWMMA, - NPerWMMA, + KPerBlock, + MPerWmma, + NPerWmma, K1, MRepeat, NRepeat, @@ -523,6 +576,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, // AThreadTransferSrcResetCoordinateAfterRun, + AEnableLds, ABlockLdsAddExtraM, 
BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, @@ -531,6 +585,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, // BThreadTransferSrcResetCoordinateAfterRun, + BEnableLds, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, @@ -564,16 +619,14 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle p_b_grid_{static_cast(p_b_grid)}, p_ds_grid_{}, p_e_grid_{static_cast(p_e_grid)}, - a_grid_desc_m_k_{}, - b_grid_desc_n_k_{}, + a_grid_desc_{}, + b_grid_desc_{}, ds_grid_desc_m_n_{}, e_grid_desc_m_n_{}, ds_grid_desc_g_m_n_{ DeviceOp::MakeDsGridDescriptor_G_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides)}, e_grid_desc_g_m_n_{ DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)}, - a_grid_desc_k0_m_k1_{}, - b_grid_desc_k0_n_k1_{}, ds_grid_desc_mblock_mperblock_nblock_nperblock{}, e_grid_desc_mblock_mperblock_nblock_nperblock{}, block_2_ctile_map_{}, @@ -600,10 +653,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle p_ds_grid_(i) = static_cast(p_ds_grid[i]); }); - a_grid_desc_m_k_ = - DeviceOp::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - b_grid_desc_n_k_ = - DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); + a_grid_desc_ = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + b_grid_desc_ = DeviceOp::MakeBGridDescriptor(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); ds_grid_desc_m_n_ = DeviceOp::MakeDsGridDescriptor_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides); @@ -611,9 +662,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle e_grid_desc_m_n_ = DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); - a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(a_grid_desc_m_k_); - b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(b_grid_desc_n_k_); - block_2_ctile_map_ = GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01); ds_grid_desc_mblock_mperblock_nblock_nperblock = @@ -644,16 +692,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle EDataType* p_e_grid_; // Tensor Descriptors - AGridDesc_M_K a_grid_desc_m_k_; - BGridDesc_N_K b_grid_desc_n_k_; + AGridDesc a_grid_desc_; + BGridDesc b_grid_desc_; DsGridDesc_M_N ds_grid_desc_m_n_; EGridDesc_M_N e_grid_desc_m_n_; DsGridDesc_G_M_N ds_grid_desc_g_m_n_; EGridDesc_G_M_N e_grid_desc_g_m_n_; - AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; - BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; - typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock; typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock @@ -686,6 +731,11 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle // Batch Offset ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + // for checking vector load/store + // index_t MRaw_; + // index_t NRaw_; + // index_t KRaw_; }; // Invoker @@ -700,8 +750,17 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle const index_t grid_size = arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G; - const auto K = - arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + const auto K = [&]() { + if constexpr(AEnableLds) + { + return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2); + } + else + { + return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) * + arg.a_grid_desc_.GetLength(I4) * 
arg.a_grid_desc_.GetLength(I6); + } + }(); auto launch_kernel = [&](auto has_main_k_block_loop) { constexpr bool has_main_loop = has_main_k_block_loop.value; @@ -712,8 +771,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle BDataType, typename GridwiseOp::DsGridPointer, EDataType, - DeviceOp::AGridDesc_K0_M_K1, - DeviceOp::BGridDesc_K0_N_K1, + DeviceOp::AGridDesc, + DeviceOp::BGridDesc, typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, AElementwiseOperation, @@ -733,8 +792,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle arg.p_ds_grid_, arg.p_e_grid_, G, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, + arg.a_grid_desc_, + arg.b_grid_desc_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock, arg.e_grid_desc_mblock_mperblock_nblock_nperblock, arg.a_element_op_, @@ -774,6 +833,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle { if constexpr(!(is_same_v || is_same_v)) { + printf("DeviceOp: Arch check failure\n"); return false; } } @@ -782,12 +842,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle return false; } - if(!GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, + if(!GridwiseOp::CheckValidity(arg.a_grid_desc_, + arg.b_grid_desc_, arg.ds_grid_desc_m_n_, arg.e_grid_desc_m_n_, arg.block_2_ctile_map_)) { + printf("GridwiseOp: Validity check failure\n"); return false; } @@ -800,16 +861,18 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle if constexpr(ABlockTransferSrcVectorDim == 1) { if(!(arg.a_mz_stride_ == 1 && - arg.a_grid_desc_k0_m_k1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0)) + arg.a_grid_desc_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0)) { + printf("DeviceOp: Vector Access A-m check failure\n"); return false; } } else { if(!(arg.a_kz_stride_ == 1 && - arg.a_grid_desc_k0_m_k1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0)) + arg.a_grid_desc_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0)) { + printf("DeviceOp: Vector Access A-k check failure\n"); return false; } } @@ -818,16 +881,18 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle if constexpr(BBlockTransferSrcVectorDim == 1) { if(!(arg.b_nz_stride_ == 1 && - arg.b_grid_desc_k0_n_k1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0)) + arg.b_grid_desc_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0)) { + printf("DeviceOp: Vector Access B-n check failure\n"); return false; } } else { if(!(arg.b_kz_stride_ == 1 && - arg.b_grid_desc_k0_n_k1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0)) + arg.b_grid_desc_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0)) { + printf("DeviceOp: Vector Access B-k check failure\n"); return false; } } @@ -841,6 +906,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0)) { + printf("DeviceOp: Vector Access D-n check failure\n"); valid_d_access = false; } }); @@ -857,6 +923,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle 0) || CDEShuffleBlockTransferScalarPerVector_NPerBlock == 1)) { + printf("DeviceOp: Vector Access E-n check failure\n"); return false; } @@ -967,14 +1034,18 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle << BlockSize << ", " << MPerBlock << ", " << NPerBlock << ", " - << K0PerBlock << ", " + << KPerBlock << ", " << K1 << ", " - << MPerWMMA << ", " - << NPerWMMA << ", " + << MPerWmma << ", " + << NPerWmma << ", " << MRepeat << ", " << NRepeat << ">" - << " 
NumPrefetch: " + << " AEnableLds: " + << AEnableLds << ", " + << "BEnableLds: " + << BEnableLds << ", " + << "NumPrefetch: " << NumPrefetch << ", " << "LoopScheduler: " << LoopSchedToString[LoopSched] << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp new file mode 100644 index 0000000000..e218ee5c15 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp @@ -0,0 +1,1729 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp" +#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_softmax_gemm_wmma_cshuffle(const ADataType* __restrict__ p_a_grid, + const B0DataType* __restrict__ p_b0_grid, + const B1DataType* __restrict__ p_b1_grid, + CDataType* __restrict__ p_c_grid, + index_t M, + index_t N, + index_t K, + index_t O, + index_t G0, + index_t G1, + float alpha, + bool input_permute, + bool output_permute) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) + + // clang-format off +// *************************************************** +// Make Tensor Descriptors + constexpr index_t array_size = 4; + std::array a_gs_ms_ks_lengths{G0, G1, M, K}; + std::array a_gs_ms_ks_strides = + input_permute + ? std::array{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K] + : std::array{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::array b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::array b0_gs_ns_ks_strides = + input_permute + ? std::array{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K] + : std::array{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] + + std::array b1_gs_os_ns_lengths{G0, G1, O, N}; + std::array b1_gs_os_ns_strides = + input_permute + ? std::array{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O] + : std::array{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] + + std::array c_gs_ms_os_lengths{G0, G1, M, O}; + std::array c_gs_ms_os_strides = + output_permute + ? 
std::array{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] + : std::array{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + const auto a_element_op = AElementwiseOperation{}; + const auto b0_element_op = B0ElementwiseOperation{}; + const auto acc0_element_op = AccElementwiseOperation{alpha}; + const auto b1_element_op = B1ElementwiseOperation{}; + const auto c_element_op = CElementwiseOperation{}; + // fail to reuse DeviceOp::MakeArgument() because of the __device__ function required. + + const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto a_grid_desc_g_m_k = + DeviceOp::Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc_g_l_k = + DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc_g_n_l = + DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto compute_base_ptr_of_batch = + typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n}; + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})}; + + // clang-format on + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB0BasePtr(g_idx))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + + GridwiseOp::template Run(p_a_grid + a_batch_offset, + p_b0_grid + b0_batch_offset, + p_b1_grid + b1_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_grid_desc, + b0_grid_desc, + b1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b0_element_op, + acc0_element_op, + b1_element_op, + c_element_op, + c0_matrix_mask, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b0_grid; + ignore = p_b1_grid; + ignore = p_c_grid; + ignore = M; + ignore = N; + ignore = K; + ignore = O; + ignore = G0; + ignore = G1; + ignore = input_permute; + ignore = output_permute; +#endif // end of if (defined(__gfx11__)) +} + +// Self-Attention +template +__global__ void +#if 
CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_wmma_self_attention_forward(const QKVDataType* __restrict__ p_qkv_grid, + ODataType* __restrict__ p_out_grid, + index_t batch_size, + index_t sequence_length, + index_t head_count, + index_t head_size, + float alpha) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) + + // clang-format off +// *************************************************** +// Make Tensor Descriptors +// o Self-attention(packed QKV): [batchSize, sequenceLength, headCount, 3, headSize] + constexpr index_t array_size = 4; + std::array qk_gs_ms_ks_lengths{batch_size, head_count, sequence_length, head_size}; + std::array qk_gs_ms_ks_strides{sequence_length * head_count * 3 * head_size, 3 * head_size, head_count * 3 * head_size, 1}; + + std::array v_gs_os_ns_lengths{batch_size, head_count, head_size, sequence_length}; + std::array v_gs_os_ns_strides{sequence_length * head_count * 3 * head_size, 3 * head_size, 1, head_count * 3 * head_size}; + + std::array c_gs_ms_os_lengths{batch_size, head_count, sequence_length, head_size}; + std::array c_gs_ms_os_strides{sequence_length * head_count * head_size, head_size, head_count * head_size, 1}; + + + const auto a_element_op = AElementwiseOperation{}; + const auto b0_element_op = B0ElementwiseOperation{}; + const auto acc0_element_op = AccElementwiseOperation{alpha}; + const auto b1_element_op = B1ElementwiseOperation{}; + const auto c_element_op = CElementwiseOperation{}; + + const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(v_gs_os_ns_lengths, v_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto a_grid_desc_g_m_k = + DeviceOp::Transform::MakeAGridDescriptor_G_M_K(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides); + const auto b0_grid_desc_g_l_k = + DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides); + const auto b1_grid_desc_g_n_l = + DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(v_gs_os_ns_lengths, v_gs_os_ns_strides); + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto compute_base_ptr_of_batch = + typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n}; + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})}; + + // clang-format on + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane( + 
static_cast(compute_base_ptr_of_batch.GetB0BasePtr(g_idx))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + + const index_t qkv_gap = __builtin_amdgcn_readfirstlane(head_size); +#ifdef CK_SELF_ATTN_DEBUG + if(get_thread_global_1d_id() == 0) + { + printf("batch_size: %d\n", batch_size); + printf("sequence_length: %d\n", sequence_length); + printf("head_count: %d\n", head_count); + printf("head_size: %d\n", head_size); + printf("qkv_gap: %d\n", qkv_gap); + printf("get_grid_size(): %d\n", get_grid_size()); + printf("batch_count: %d\n", batch_count); + printf("blockid: %d\n", get_block_1d_id()); + printf("num_blocks_per_batch: %d\n", num_blocks_per_batch); + printf("g_idx: %d\n", g_idx); + printf("a_batch_offset: %ld\n", a_batch_offset); + printf("b0_batch_offset: %ld\n", b0_batch_offset); + printf("b1_batch_offset: %ld\n", b1_batch_offset); + } +#endif + GridwiseOp::template Run(p_qkv_grid + 0 * qkv_gap + a_batch_offset, + p_qkv_grid + 1 * qkv_gap + b0_batch_offset, + p_qkv_grid + 2 * qkv_gap + b1_batch_offset, + p_out_grid + c_batch_offset, + p_shared, + a_grid_desc, + b0_grid_desc, + b1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b0_element_op, + acc0_element_op, + b1_element_op, + c_element_op, + c0_matrix_mask, + block_2_ctile_map); +#else + ignore = p_qkv_grid; + ignore = p_out_grid; + ignore = batch_size; + ignore = sequence_length; + ignore = head_count; + ignore = head_size; + ignore = alpha; +#endif // end of if (defined(__gfx11__)) +} +// Cross-Attention +// Self-Attention +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_wmma_cross_attention_forward(const QDataType* __restrict__ p_q_grid, + const KVDataType* __restrict__ p_kv_grid, + ODataType* __restrict__ p_out_grid, + index_t batch_size, + index_t q_sequence_length, + index_t kv_sequence_length, + index_t head_count, + index_t head_size, + float alpha) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) + + // clang-format off +// *************************************************** +// Make Tensor Descriptors +// o Self-attention(packed QKV): [batchSize, sequenceLength, headCount, 3, headSize] + constexpr index_t array_size = 4; + std::array q_gs_ms_ks_lengths{batch_size, head_count, q_sequence_length, head_size}; + std::array q_gs_ms_ks_strides{q_sequence_length * head_count * head_size, head_size, head_count * head_size, 1}; + + std::array k_gs_ms_ks_lengths{batch_size, head_count, kv_sequence_length, head_size}; + std::array k_gs_ms_ks_strides{kv_sequence_length * head_count * 2 * head_size, 2 * head_size, head_count * 2 * head_size, 1}; + + std::array v_gs_os_ns_lengths{batch_size, head_count, head_size, kv_sequence_length}; + std::array v_gs_os_ns_strides{kv_sequence_length * head_count * 2 * head_size, 2 * head_size, 1, head_count * 2 * head_size}; + + std::array c_gs_ms_os_lengths{batch_size, head_count, q_sequence_length, head_size}; + std::array c_gs_ms_os_strides{q_sequence_length * head_count * head_size, head_size, head_count * head_size, 1}; + + + const auto a_element_op = AElementwiseOperation{}; + const auto b0_element_op = B0ElementwiseOperation{}; + const auto acc0_element_op = AccElementwiseOperation{alpha}; + const auto b1_element_op = B1ElementwiseOperation{}; + 
const auto c_element_op = CElementwiseOperation{}; + + const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(k_gs_ms_ks_lengths, k_gs_ms_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(v_gs_os_ns_lengths, v_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto a_grid_desc_g_m_k = + DeviceOp::Transform::MakeAGridDescriptor_G_M_K(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); + const auto b0_grid_desc_g_l_k = + DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(k_gs_ms_ks_lengths, k_gs_ms_ks_strides); + const auto b1_grid_desc_g_n_l = + DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(v_gs_os_ns_lengths, v_gs_os_ns_strides); + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto compute_base_ptr_of_batch = + typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n}; + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})}; + + // clang-format on + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB0BasePtr(g_idx))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + + const index_t kv_gap = __builtin_amdgcn_readfirstlane(head_size); +#ifdef CK_SELF_ATTN_DEBUG + if(get_thread_global_1d_id() == 0) + { + printf("batch_size: %d\n", batch_size); + printf("q_sequence_length: %d\n", q_sequence_length); + printf("k_sequence_length: %d\n", kv_sequence_length); + printf("head_count: %d\n", head_count); + printf("head_size: %d\n", head_size); + printf("kv_gap: %d\n", kv_gap); + printf("get_grid_size(): %d\n", get_grid_size()); + printf("batch_count: %d\n", batch_count); + printf("blockid: %d\n", get_block_1d_id()); + printf("num_blocks_per_batch: %d\n", num_blocks_per_batch); + printf("g_idx: %d\n", g_idx); + printf("a_batch_offset: %ld\n", a_batch_offset); + printf("b0_batch_offset: %ld\n", b0_batch_offset); + printf("b1_batch_offset: %ld\n", b1_batch_offset); + } +#endif + GridwiseOp::template Run(p_q_grid + a_batch_offset, + p_kv_grid + 0 * kv_gap + b0_batch_offset, + p_kv_grid + 1 * kv_gap + b1_batch_offset, + p_out_grid + c_batch_offset, + p_shared, + a_grid_desc, + b0_grid_desc, + b1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b0_element_op, + acc0_element_op, + b1_element_op, + 
c_element_op, + c0_matrix_mask, + block_2_ctile_map); +#else + ignore = p_q_grid; + ignore = p_kv_grid; + ignore = p_out_grid; + ignore = batch_size; + ignore = q_sequence_length; + ignore = kv_sequence_length; + ignore = head_count; + ignore = head_size; + ignore = alpha; +#endif // end of if (defined(__gfx11__)) +} +// Computes C = A * B0 * B1 +// MN = MK * KL * LN +// ^^^^^^ (Acc0) +// ^^^^^^^^^^^ (Acc1) +template +struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle + : public DeviceBatchedGemmSoftmaxGemmPermute +{ + static_assert(NumDimG > 0 && NumDimM > 0 && NumDimL > 0 && NumDimK > 0 && NumDimN > 0, + "Number of dimension must be greater than 0"); + + static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); + static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); + + // TODO ANT: implement bias combination + static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); + + static constexpr index_t NumDimGemm0M = NumDimM; + static constexpr index_t NumDimGemm0N = NumDimL; + static constexpr index_t NumDimGemm0K = NumDimK; + static constexpr index_t NumDimGemm1M = NumDimM; + static constexpr index_t NumDimGemm1N = NumDimN; + static constexpr index_t NumDimGemm1K = NumDimL; + + using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + + static constexpr auto WmmaK = 16; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + + static constexpr auto AEnableLds_auto = LWaves == 1 ? false : true; + static constexpr auto B0EnableLds_auto = MWaves == 1 ? false : true; + static constexpr auto B1EnableLds_auto = MWaves == 1 ? 
false : true; + + static constexpr auto AEnableLds_manu = false; + static constexpr auto B0EnableLds_manu = true; + static constexpr auto B1EnableLds_manu = true; + + static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); + static constexpr auto B0EnableLds = B0EnableLds_auto || B0EnableLds_manu || (NumPrefetch > 1); + static constexpr auto B1EnableLds = B1EnableLds_auto || B1EnableLds_manu || (NumPrefetch > 1); + + using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma< + Sequence, + Sequence, + GemmSpec, + ASpec, + B0Spec, + B1Spec, + CSpec>; + + __host__ __device__ static auto MakeAGridDescriptor( + const std::array& a_gs_ms_ks_lengths_vec, + const std::array& a_gs_ms_ks_strides_vec) + { + if constexpr(AEnableLds) + { + return Transform::MakeAGridDescriptor_AK0_M_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, + a_gs_ms_ks_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + __host__ __device__ static auto MakeB0GridDescriptor( + const std::array& b0_gs_ls_ks_lengths_vec, + const std::array& b0_gs_ls_ks_strides_vec) + { + if constexpr(B0EnableLds) + { + return Transform::MakeB0GridDescriptor_BK0_N_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, + b0_gs_ls_ks_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, + b0_gs_ls_ks_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + __host__ __device__ static auto MakeB1GridDescriptor( + const std::array& b1_gs_ns_ls_lengths_vec, + const std::array& b1_gs_ns_ls_strides_vec) + { + if constexpr(B1EnableLds) + { + return Transform::MakeB1GridDescriptor_BK0_N_BK1( + Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, + b1_gs_ns_ls_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1( + Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, + b1_gs_ns_ls_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + using AGridDesc = decltype(MakeAGridDescriptor({}, {})); + using B0GridDesc = decltype(MakeB0GridDescriptor({}, {})); + using B1GridDesc = decltype(MakeB1GridDescriptor({}, {})); + using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); + using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})); + using B0GridDesc_G_L_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); + using B1GridDesc_G_N_L = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); + using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); + + __host__ __device__ constexpr static auto make_MaskOutPredicate() + { + if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled) + { + return MaskDisabledPredicate{}; + } + else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) + { + return MaskOutUpperTrianglePredicate{}; + } + } + using C0MatrixMask = C0MatrixMask_impl; + + struct ComputeBasePtrOfStridedBatch + { + __host__ __device__ ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, + const 
B0GridDesc_G_L_K& b0_grid_desc_g_l_k, + const B1GridDesc_G_N_L& b1_grid_desc_g_n_l, + const CGridDesc_G_M_N& c_grid_desc_g_m_n) + : a_grid_desc_g_m_k_(a_grid_desc_g_m_k), + b0_grid_desc_g_l_k_(b0_grid_desc_g_l_k), + b1_grid_desc_g_n_l_(b1_grid_desc_g_n_l), + c_grid_desc_g_m_n_(c_grid_desc_g_m_n) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const + { + return b0_grid_desc_g_l_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const + { + return b1_grid_desc_g_n_l_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const + { + return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + private: + AGridDesc_G_M_K a_grid_desc_g_m_k_; + B0GridDesc_G_L_K b0_grid_desc_g_l_k_; + B1GridDesc_G_N_L b1_grid_desc_g_n_l_; + CGridDesc_G_M_N c_grid_desc_g_m_n_; + }; + + // GridwiseOp + using GridwiseOp = GridwiseBatchedGemmSoftmaxGemm_Wmma< + // DataType Family + ADataType, + B0DataType, + Acc0DataType, + B1DataType, + Acc1DataType, + CShuffleDataType, + CDataType, + // ElementwiseOp Family + AElementwiseOperation, + B0ElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + // InMemory Data Descriptor + AGridDesc, + B0GridDesc, + B1GridDesc, + CGridDesc_M_N, + // Tiling Family + MPerBlock, + LPerBlock, + KPerBlock, + AK1, + BK1, + NPerBlock, + LTilePerBlock, + L1, + MPerWmma, + LPerWmma, + NPerWmma, + MRepeat, + LRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + true, + AEnableLds, + ABlockLdsAddExtraM, + B0BlockTransferThreadClusterLengths_K0_L_K1, + B0BlockTransferThreadClusterArrangeOrder, + B0BlockTransferSrcAccessOrder, + B0BlockTransferSrcVectorDim, + B0BlockTransferSrcScalarPerVector, + B0BlockTransferDstScalarPerVector_K1, + true, + B0EnableLds, + B0BlockLdsAddExtraL, + B1BlockTransferThreadClusterLengths_L0_N_L1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_L1, + false, + B1EnableLds, + B1BlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + Transform::matrix_padder.PadN, + MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, + NumPrefetch, + LoopSched, + PipelineVer>; + + struct RawArg : public BaseArgument + { + RawArg(const ADataType* p_a_grid, + const B0DataType* p_b0_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t O, + index_t G0, + index_t G1, + float alpha, + bool input_permute, + bool output_permute) + : p_a_grid_{p_a_grid}, + p_b0_grid_{p_b0_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + M_{M}, + N_{N}, + K_{K}, + O_{O}, + G0_{G0}, + G1_{G1}, + alpha_{alpha}, + input_permute_{input_permute}, + output_permute_{output_permute} + { + } + // Pointers + 
const ADataType* p_a_grid_; + const B0DataType* p_b0_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + + // Raw Problem Size + index_t M_; + index_t N_; + index_t K_; + index_t O_; + index_t G0_; + index_t G1_; + float alpha_; + bool input_permute_; + bool output_permute_; + }; + + static auto MakeArgument(const ADataType* p_a, + const B0DataType* p_b0, + const B1DataType* p_b1, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t O, + index_t G0, + index_t G1, + float alpha, + bool input_permute, + bool output_permute) + { + return RawArg{ + p_a, p_b0, p_b1, p_c, M, N, K, O, G0, G1, alpha, input_permute, output_permute}; + } + + static bool IsSupportedArgument(const RawArg& arg) + { + if(ck::is_navi3_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc0 Type err"); + return false; + } + + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc1 Type err"); + return false; + } + } + else + { + printf("DeviceOp: Arch err"); + return false; + } + + constexpr index_t array_size = 4; + ck::index_t G0 = arg.G0_; + ck::index_t G1 = arg.G1_; + ck::index_t M = arg.M_; + ck::index_t N = arg.N_; + ck::index_t K = arg.K_; + ck::index_t O = arg.O_; + bool input_permute = arg.input_permute_; + bool output_permute = arg.output_permute_; + + std::array a_gs_ms_ks_lengths{G0, G1, M, K}; + std::array a_gs_ms_ks_strides = + input_permute ? std::array{M * G1 * K, K, G1 * K, 1} + // A layout [G0, M, G1, K] + : std::array{ + G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::array b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::array b0_gs_ns_ks_strides = + input_permute ? std::array{N * G1 * K, K, G1 * K, 1} + // B0 layout [G0, N, G1, K] + : std::array{ + G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] + + std::array b1_gs_os_ns_lengths{G0, G1, O, N}; + std::array b1_gs_os_ns_strides = + input_permute ? std::array{N * G1 * O, O, 1, G1 * O} + // B1 layout [G0, N, G1, O] + : std::array{ + G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] + + std::array c_gs_ms_os_lengths{G0, G1, M, O}; + std::array c_gs_ms_os_strides = + output_permute ? 
std::array{M * G1 * O, O, G1 * O, 1} + // C layout [G0, M, G1, O] + : std::array{ + G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + const auto a_grid_desc = + DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + + if(!GridwiseOp::CheckValidity( + a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n, block_2_ctile_map)) + { + return false; + } + + // Check if C permute dimension matches GEMM + GEMM shape + const index_t c_g = c_grid_desc_g_m_n.GetLength(I0); // unpadded + + if(!(c_g == batch_count)) + { + printf("DeviceOp: BatchCount err"); + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O + const auto MzRaw = M; + const auto LzRaw = N; + const auto KzRaw = K; + const auto NzRaw = O; + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw; + const auto c_extent_lowest = NzRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + printf("DeviceOp: Data Transfer Vector scalar err"); + return false; + } + + std::array a_mz_kz_strides_{ + a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}; + std::array b0_lz_kz_strides_{ + b0_gs_ns_ks_strides[NumDimG + NumDimL - 1], + b0_gs_ns_ks_strides[NumDimG + NumDimL + NumDimK - 1]}; + std::array b1_nz_lz_strides_{ + b1_gs_os_ns_strides[NumDimG + NumDimN - 1], + b1_gs_os_ns_strides[NumDimG + NumDimN + NumDimL - 1]}; + std::array c_mz_nz_strides_{ + c_gs_ms_os_strides[NumDimG + NumDimM - 1], + c_gs_ms_os_strides[NumDimG + NumDimM + NumDimN - 1]}; + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? a_mz_kz_strides_[1] : a_mz_kz_strides_[0]; + const auto b0_stride_lowest = + B0BlockTransferSrcVectorDim == 2 ? b0_lz_kz_strides_[1] : b0_lz_kz_strides_[0]; + const auto b1_stride_lowest = + B1BlockTransferSrcVectorDim == 2 ? 
b1_nz_lz_strides_[1] : b1_nz_lz_strides_[0]; + const auto c_stride_lowest = c_mz_nz_strides_[1]; + + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + printf("DeviceOp: Data Vectorize transfer err"); + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + struct SelfAttnArg : public BaseArgument + { + SelfAttnArg(const ADataType* p_qkv_grid, + CDataType* p_out_grid, + index_t batch_size, + index_t sequence_length, + index_t head_count, + index_t head_size, + float alpha) + : p_qkv_grid_{p_qkv_grid}, + p_out_grid_{p_out_grid}, + batch_size_{batch_size}, + sequence_length_{sequence_length}, + head_count_{head_count}, + head_size_{head_size}, + alpha_{alpha} + { + } + // Pointers + const ADataType* p_qkv_grid_; + CDataType* p_out_grid_; + + // Raw Problem Size + index_t batch_size_; + index_t sequence_length_; + index_t head_count_; + index_t head_size_; + float alpha_; + }; + + static auto MakeSelfAttnArgument(const ADataType* p_qkv_grid, + CDataType* p_out_grid, + index_t batch_size, + index_t sequence_length, + index_t head_count, + index_t head_size, + float alpha) + { + return SelfAttnArg{ + p_qkv_grid, p_out_grid, batch_size, sequence_length, head_count, head_size, alpha}; + } + + struct CrossAttnArg : public BaseArgument + { + CrossAttnArg(const ADataType* p_q_grid, + const B0DataType* p_kv_grid, + CDataType* p_out_grid, + index_t batch_size, + index_t q_sequence_length, + index_t kv_sequence_length, + index_t head_count, + index_t head_size, + float alpha) + : p_q_grid_{p_q_grid}, + p_kv_grid_{p_kv_grid}, + p_out_grid_{p_out_grid}, + batch_size_{batch_size}, + q_sequence_length_{q_sequence_length}, + kv_sequence_length_{kv_sequence_length}, + head_count_{head_count}, + head_size_{head_size}, + alpha_{alpha} + { + } + // Pointers + const ADataType* p_q_grid_; + const B0DataType* p_kv_grid_; + CDataType* p_out_grid_; + + // Raw Problem Size + index_t batch_size_; + index_t q_sequence_length_; + index_t kv_sequence_length_; + index_t head_count_; + index_t head_size_; + float alpha_; + }; + + static auto MakeCrossAttnArgument(const ADataType* p_q_grid, + const B0DataType* p_kv_grid, + CDataType* p_out_grid, + index_t batch_size, + index_t q_sequence_length, + index_t kv_sequence_length, + index_t head_count, + index_t head_size, + float alpha) + { + return CrossAttnArg{p_q_grid, + p_kv_grid, + p_out_grid, + batch_size, + q_sequence_length, + kv_sequence_length, + head_count, + head_size, + alpha}; + } + + // Argument + struct Argument : public BaseArgument + { + Argument( + const ADataType* p_a_grid, + const B0DataType* p_b0_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::array& a_gs_ms_ks_lengths, + const std::array& a_gs_ms_ks_strides, + const std::array& b0_gs_ls_ks_lengths, + const std::array& b0_gs_ls_ks_strides, + const std::array& b1_gs_ns_ls_lengths, + const std::array& b1_gs_ns_ls_strides, + const std::array& c_gs_ms_ns_lengths, + const std::array& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + const index_t M01, + const index_t N01, + AElementwiseOperation a_element_op, + 
B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b0_grid_{p_b0_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc{DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, + b0_grid_desc{ + DeviceOp::MakeB0GridDescriptor(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)}, + b1_grid_desc{ + DeviceOp::MakeB1GridDescriptor(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)}, + c_grid_desc_m_n_{ + Transform::MakeCGridDescriptor_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)}, + a_grid_desc_g_m_k_{ + Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, + b0_grid_desc_g_l_k_{ + Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)}, + b1_grid_desc_g_n_l_{ + Transform::MakeB1GridDescriptor_G_N_K(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)}, + c_grid_desc_g_m_n_{ + Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)}, + a_element_op_{a_element_op}, + b0_element_op_{b0_element_op}, + acc_element_op_{acc_element_op}, + b1_element_op_{b1_element_op}, + c_element_op_{c_element_op}, + c0_matrix_mask_{b0_grid_desc_g_l_k_.GetLength(I1)}, + raw_lengths_mz_lz_kz_nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], + b0_gs_ls_ks_lengths[NumDimG + NumDimL - 1], + b0_gs_ls_ks_lengths[NumDimG + NumDimL + NumDimK - 1], + b1_gs_ns_ls_lengths[NumDimG + NumDimN - 1]}, + a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}, + b0_lz_kz_strides_{b0_gs_ls_ks_strides[NumDimG + NumDimL - 1], + b0_gs_ls_ks_strides[NumDimG + NumDimL + NumDimK - 1]}, + b1_nz_lz_strides_{b1_gs_ns_ls_strides[NumDimG + NumDimN - 1], + b1_gs_ns_ls_strides[NumDimG + NumDimN + NumDimL - 1]}, + c_mz_nz_strides_{c_gs_ms_ns_strides[NumDimG + NumDimM - 1], + c_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1]}, + batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}, + compute_ptr_offset_of_batch_{ + a_grid_desc_g_m_k_, b0_grid_desc_g_l_k_, b1_grid_desc_g_n_l_, c_grid_desc_g_m_n_} + { + // TODO ANT: implement bias addition + ignore = p_acc0_biases; + ignore = p_acc1_biases; + ignore = acc0_biases_gs_ms_ls_lengths; + ignore = acc0_biases_gs_ms_ls_strides; + ignore = acc1_biases_gs_ms_ns_lengths; + ignore = acc1_biases_gs_ms_ns_strides; + + if(GridwiseOp::CheckValidity( + a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n_, block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + // Pointers + const ADataType* p_a_grid_; + const B0DataType* p_b0_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + + // Tensor Descriptors + AGridDesc a_grid_desc; + B0GridDesc b0_grid_desc; + B1GridDesc b1_grid_desc; + CGridDesc_M_N c_grid_desc_m_n_; + + AGridDesc_G_M_K a_grid_desc_g_m_k_; + B0GridDesc_G_L_K b0_grid_desc_g_l_k_; + B1GridDesc_G_N_L b1_grid_desc_g_n_l_; + CGridDesc_G_M_N c_grid_desc_g_m_n_; + + typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + + // Block to Tile mapping + typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_; + + // ElementwiseOp + AElementwiseOperation a_element_op_; + B0ElementwiseOperation b0_element_op_; + 
AccElementwiseOperation acc_element_op_; + B1ElementwiseOperation b1_element_op_; + CElementwiseOperation c_element_op_; + + // check C0 masking and padding + C0MatrixMask c0_matrix_mask_; + + // Strides for the last M/N/K dimensions of A/B0/B1/C + // for sanity check of vector load/store + std::array raw_lengths_mz_lz_kz_nz_; + std::array a_mz_kz_strides_; + std::array b0_lz_kz_strides_; + std::array b1_nz_lz_strides_; + std::array c_mz_nz_strides_; + + index_t batch_count_; + // Batch Offset + ComputeBasePtrOfStridedBatch compute_ptr_offset_of_batch_; + }; + + // Invoker + struct SelfAttnInvoker : public BaseInvoker + { + using Argument = DeviceOp::SelfAttnArg; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto M0 = math::integer_divide_ceil(arg.sequence_length_, MPerBlock); + const auto N0 = math::integer_divide_ceil(arg.head_size_, NPerBlock); + + const index_t grid_size = arg.batch_size_ * arg.head_count_ * M0 * N0; + const auto K = arg.head_size_; + + auto launch_kernel = [&](auto has_main_k_block_loop) { + const auto kernel = kernel_wmma_self_attention_forward; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_qkv_grid_, + arg.p_out_grid_, + arg.batch_size_, + arg.sequence_length_, + arg.head_count_, + arg.head_size_, + arg.alpha_); + }; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static auto MakeSelfAttnInvoker() { return SelfAttnInvoker{}; } + + // Invoker + struct CrossAttnInvoker : public BaseInvoker + { + using Argument = DeviceOp::CrossAttnArg; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto M0 = math::integer_divide_ceil(arg.q_sequence_length_, MPerBlock); + const auto N0 = math::integer_divide_ceil(arg.head_size_, NPerBlock); + + const index_t grid_size = arg.batch_size_ * arg.head_count_ * M0 * N0; + const auto K = arg.head_size_; + + auto launch_kernel = [&](auto has_main_k_block_loop) { + const auto kernel = kernel_wmma_cross_attention_forward; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_q_grid_, + arg.p_kv_grid_, + arg.p_out_grid_, + arg.batch_size_, + arg.q_sequence_length_, + arg.kv_sequence_length_, + arg.head_count_, + arg.head_size_, + arg.alpha_); + }; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static auto MakeCrossAttnInvoker() { return CrossAttnInvoker{}; } + + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::RawArg; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto M0 = math::integer_divide_ceil(arg.M_, MPerBlock); + const auto N0 = math::integer_divide_ceil(arg.O_, NPerBlock); + + const index_t grid_size = arg.G0_ * arg.G1_ * M0 * N0; + const auto K = arg.K_; + // printf("HasKBlockLoop: %d\n", GridwiseOp::CalculateHasMainKBlockLoop(K)); + auto 
launch_kernel = [&](auto has_main_k_block_loop) { + const auto kernel = + kernel_batched_gemm_softmax_gemm_wmma_cshuffle; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b0_grid_, + arg.p_b1_grid_, + arg.p_c_grid_, + arg.M_, + arg.N_, + arg.K_, + arg.O_, + arg.G0_, + arg.G1_, + arg.alpha_, + arg.input_permute_, + arg.output_permute_); + }; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } +#if 0 + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::is_navi3_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc0 Type err"); + return false; + } + + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc1 Type err"); + return false; + } + } + else + { + printf("DeviceOp: Arch err"); + return false; + } + + if(!GridwiseOp::CheckValidity(arg.a_grid_desc, + arg.b0_grid_desc, + arg.b1_grid_desc, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + return false; + } + + // Check if C permute dimension matches GEMM + GEMM shape + const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded + + if(!(c_g == arg.batch_count_)) + { + printf("DeviceOp: BatchCount err"); + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O + const auto MzRaw = arg.raw_lengths_mz_lz_kz_nz_[0]; + const auto LzRaw = arg.raw_lengths_mz_lz_kz_nz_[1]; + const auto KzRaw = arg.raw_lengths_mz_lz_kz_nz_[2]; + const auto NzRaw = arg.raw_lengths_mz_lz_kz_nz_[3]; + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw; + const auto c_extent_lowest = NzRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + printf("DeviceOp: Data Transfer Vector scalar err"); + return false; + } + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0]; + const auto b0_stride_lowest = + B0BlockTransferSrcVectorDim == 2 ? arg.b0_lz_kz_strides_[1] : arg.b0_lz_kz_strides_[0]; + const auto b1_stride_lowest = + B1BlockTransferSrcVectorDim == 2 ? 
arg.b1_nz_lz_strides_[1] : arg.b1_nz_lz_strides_[0]; + const auto c_stride_lowest = arg.c_mz_nz_strides_[1]; + + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + printf("DeviceOp: Data Vectorize transfer err"); + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const ADataType* p_a, + const B0DataType* p_b0, + const B1DataType* p_b1, + CDataType* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::array& a_gs_ms_ks_lengths, + const std::array& a_gs_ms_ks_strides, + const std::array& b0_gs_ls_ks_lengths, + const std::array& b0_gs_ls_ks_strides, + const std::array& b1_gs_ns_ls_lengths, + const std::array& b1_gs_ns_ls_strides, + const std::array& c_gs_ms_ns_lengths, + const std::array& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b0, + p_b1, + p_c, + p_acc0_biases, + p_acc1_biases, + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b0_gs_ls_ks_lengths, + b0_gs_ls_ks_strides, + b1_gs_ns_ls_lengths, + b1_gs_ns_ls_strides, + c_gs_ms_ns_lengths, + c_gs_ms_ns_strides, + acc0_biases_gs_ms_ls_lengths, + acc0_biases_gs_ms_ls_strides, + acc1_biases_gs_ms_ns_lengths, + acc1_biases_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op}; + } +#endif + + // polymorphic + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b0, + const void* p_b1, + void* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b0_gs_ls_ks_lengths, + const std::vector& b0_gs_ls_ks_strides, + const std::vector& b1_gs_ns_ls_lengths, + const std::vector& b1_gs_ns_ls_strides, + const std::vector& c_gs_ms_ns_lengths, + const std::vector& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) override + { + std::array a_lengths; + std::array a_strides; + std::array b0_lengths; + std::array b0_strides; + std::array b1_lengths; + std::array b1_strides; + std::array c_lengths; + std::array c_strides; + std::transform(a_gs_ms_ks_lengths.begin(), + a_gs_ms_ks_lengths.end(), + a_lengths.begin(), + [](index_t i) { return i; }); + std::transform(a_gs_ms_ks_strides.begin(), + a_gs_ms_ks_strides.end(), + a_strides.begin(), + [](index_t i) { return i; }); + std::transform(b0_gs_ls_ks_lengths.begin(), + b0_gs_ls_ks_lengths.end(), + b0_lengths.begin(), + [](index_t i) { return i; }); + std::transform(b0_gs_ls_ks_strides.begin(), + 
b0_gs_ls_ks_strides.end(), + b0_strides.begin(), + [](index_t i) { return i; }); + std::transform(b1_gs_ns_ls_lengths.begin(), + b1_gs_ns_ls_lengths.end(), + b1_lengths.begin(), + [](index_t i) { return i; }); + std::transform(b1_gs_ns_ls_strides.begin(), + b1_gs_ns_ls_strides.end(), + b1_strides.begin(), + [](index_t i) { return i; }); + std::transform(c_gs_ms_ns_lengths.begin(), + c_gs_ms_ns_lengths.end(), + c_lengths.begin(), + [](index_t i) { return i; }); + std::transform(c_gs_ms_ns_strides.begin(), + c_gs_ms_ns_strides.end(), + c_strides.begin(), + [](index_t i) { return i; }); + return std::make_unique(static_cast(p_a), + static_cast(p_b0), + static_cast(p_b1), + static_cast(p_c), + p_acc0_biases, + p_acc1_biases, + a_lengths, + a_strides, + b0_lengths, + b0_strides, + b1_lengths, + b1_strides, + c_lengths, + c_strides, + acc0_biases_gs_ms_ls_lengths, + acc0_biases_gs_ms_ls_strides, + acc1_biases_gs_ms_ns_lengths, + acc1_biases_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op); + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << LPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << LTilePerBlock << ", " + << L1 << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << "ASpec" << getTensorSpecializationString(ASpec) << ", " + << "B0Spec" << getTensorSpecializationString(B0Spec) << ", " + << "B1Spec" << getTensorSpecializationString(B1Spec) << ", " + << "CSpec" << getTensorSpecializationString(CSpec) << ", " + << getMaskingSpecializationString(MaskingSpec) + << ">" + << " AEnableLds: " + << AEnableLds << ", " + << "B0EnableLds: " + << B0EnableLds << ", " + << "B1EnableLds: " + << B1EnableLds << ", " + << "NumPrefetch: " + << NumPrefetch << ", " + << "LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp new file mode 100644 index 0000000000..4385d64c19 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp @@ -0,0 +1,714 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
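+
+// Reference semantics (illustrative sketch only; this loop is not part of the header and
+// the array names A, B, scale, C are hypothetical): B is stored as an integer type and is
+// dequantized on the fly with a per-column scale before the usual GEMM accumulation.
+//
+//     for(int m = 0; m < M; ++m)
+//         for(int n = 0; n < N; ++n)
+//         {
+//             float acc = 0.f;
+//             for(int k = 0; k < K; ++k)
+//                 acc += float(A[m][k]) * (float(B[k][n]) * float(scale[n]));
+//             C[m][n] = acc;
+//         }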
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dequantB.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// 1. DequantB(K, N) = int2fp(B(K, N)) * scale(1, N) +// 2. C(M, N) = A(M, K) * DequantB(K, N) + +template +struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + // K1 = Max Vector Access Pixels + static constexpr auto K1Number = Number{}; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 32 : 16; + + static constexpr auto AEnableLds_auto = + (NWaves == 1 && is_same::value) ? false : true; + static constexpr auto BEnableLds_auto = + (MWaves == 1 && is_same::value) ? false : true; + + // If true, LDS is used unconditionally + // LDS bypass feature not implemented for dequantization pipeline. + static constexpr auto AEnableLds_manu = true; + static constexpr auto BEnableLds_manu = true; + + static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); + static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1); + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + using DeviceOp = DeviceFpAintBGemm_Wmma_CShuffle; + + // Describe how data read from Global memory + static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + const auto a_grid_desc_mraw_kraw = + make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(StrideA, I1)); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + else if constexpr(is_same::value) + { + const auto a_grid_desc_mraw_kraw = + make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(I1, StrideA)); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + }(); + + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + assert(K % K1 == 0); + + if constexpr(AEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto A_KRow = 2; + constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number; + const auto A_KWmma = K / WmmaK; + + const auto M0 = M / MPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + 
a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple( + A_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(M0 * MRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + } + + static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_n_k = [&]() { + if constexpr(is_same::value) + { + const auto b_grid_desc_nraw_kraw = + make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB)); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + else if constexpr(is_same_v) + { + const auto b_grid_desc_nraw_kraw = + make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1)); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + }(); + + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + assert(K % K1 == 0); + + if constexpr(BEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto B_KRow = 2; + constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number; + const auto B_KWmma = K / WmmaK; + + const auto N0 = N / NPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple( + B_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(N0 * NRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + } + + static auto MakeScaleGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB = 0) + { + // assume Scale is [1, N] + const auto scale_grid_desc_n_k = [&]() { + const auto scale_grid_desc_nraw_kraw = + make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB)); + + return matrix_padder.PadBDescriptor_N_K(scale_grid_desc_nraw_kraw); + }(); + + const auto N = scale_grid_desc_n_k.GetLength(I0); + const auto K = scale_grid_desc_n_k.GetLength(I1); + // When K = 1, it might be scale tensor. 
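+        // The scale is logically a [1, N] row vector, but it is not described with K = 1
+        // (the assert below rejects that). Instead the caller passes the full GEMM K with a
+        // K-stride of 0 (StrideB defaults to 0 above), so every k index of a column maps to
+        // the same scale element and the scale can be traversed with B-style coordinates;
+        // in the LDS path below the K1 extent is then reduced to 1.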
+ assert(K % K1 == 0 && K != 1); + + if constexpr(BEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + scale_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(K0, 1)), // Reduce K1 = 1 + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto B_KRow = 2; + constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number; + const auto B_KWmma = K / WmmaK; + + const auto N0 = N / NPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + scale_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple( + B_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(N0 * NRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw); + } + + // Gridwise descriptor, mapping to whole given provblem. + using AGridDesc = decltype(MakeAGridDescriptor(1, 1, 1)); + using BGridDesc = decltype(MakeBGridDescriptor(1, 1, 1)); + using ScaleGridDesc = decltype(MakeScaleGridDescriptor(1, 1, 0)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseFpAintBGemm_Wmma< + BlockSize, + ADataType, + BDataType, + ScaleDataType, + AccDataType, + CShuffleDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc, + BGridDesc, + ScaleGridDesc, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + KPerBlock, + MPerWmma, + NPerWmma, + K1, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + AEnableLds, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BEnableLds, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + NumPrefetch, + LoopSched, + PipelineVer>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + const ScaleDataType* p_scale_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + 
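// [Editor's note, not part of the original patch] Typical host-side use of this device op,
// sketched from the MakeArgument/MakeInvoker/Run definitions later in this file (template
// arguments omitted; pointer and size variables are placeholders):
//
//   auto device_op = DeviceFpAintBGemm_Wmma_CShuffle</* ...tuning parameters... */>{};
//   auto argument  = device_op.MakeArgument(p_a, p_b, p_scale, p_c,
//                                           M, N, K, StrideA, StrideB, StrideC,
//                                           AElementwiseOperation{},
//                                           BElementwiseOperation{},
//                                           CElementwiseOperation{});
//   if(device_op.IsSupportedArgument(argument))
//   {
//       auto invoker = device_op.MakeInvoker();
//       float avg_ms = invoker.Run(argument, StreamConfig{});
//   }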
p_scale_grid_{p_scale_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_{}, + b_grid_desc_{}, + scale_grid_desc_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + MRaw_{M}, + NRaw_{N}, + KRaw_{K} + { + a_grid_desc_ = DeviceOp::MakeAGridDescriptor(M, K, StrideA); + b_grid_desc_ = DeviceOp::MakeBGridDescriptor(K, N, StrideB); + scale_grid_desc_ = DeviceOp::MakeScaleGridDescriptor(K, N, 0); + c_grid_desc_m_n_ = DeviceOp::MakeCGridDescriptor_M_N(M, N, StrideC); + + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + if(GridwiseGemm::CheckValidity( + a_grid_desc_, b_grid_desc_, c_grid_desc_m_n_, block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + const ScaleDataType* p_scale_grid_; + CDataType* p_c_grid_; + AGridDesc a_grid_desc_; + BGridDesc b_grid_desc_; + ScaleGridDesc scale_grid_desc_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + // for checking vector load/store + index_t MRaw_; + index_t NRaw_; + index_t KRaw_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, + arg.b_grid_desc_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_k0mk1_k0nk1_m0nm1_wmma_v1r1 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = [&]() { + if constexpr(AEnableLds) + { + return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2); + } + else + { + return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) * + arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6); + } + }(); + auto launch_kernel = [&](auto has_main_k_block_loop) { + const auto kernel = kernel_fpAintB_gemm_wmma< + GridwiseGemm, + ADataType, + BDataType, + ScaleDataType, + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + has_main_k_block_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_scale_grid_, + arg.p_c_grid_, + arg.a_grid_desc_, + arg.b_grid_desc_, + arg.scale_grid_desc_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::is_navi3_supported()) + { + if constexpr(!(is_same_v || is_same_v || + is_same_v)) + { + printf("DeviceOp err: AccDataType"); + return false; + } + } + else + { + printf("DeviceOp err: Arch"); + return false; + } + + // check vector load/store + { + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + // check vector load of A + if constexpr(is_same_v && ABlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && ABlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector laod of B + if constexpr(is_same_v && BBlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && BBlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector store of C + // only support RowMajor for now + if constexpr(is_same_v) + { + if(arg.NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + return false; + } + } + + return GridwiseGemm::CheckValidity( + arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_m_n_, arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const 
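// [Editor's note, not part of the original patch] The vector-access rule checked above, in
// short: the dimension being vector-loaded must be the contiguous one, and its raw extent
// must be divisible by the chosen ScalarPerVector. With assumed values (illustrative only):
//   ALayout = RowMajor, ABlockTransferSrcVectorDim = 2, ABlockTransferSrcScalarPerVector = 8
//     -> requires arg.KRaw_ % 8 == 0   (K is the contiguous dimension of row-major A)
//   CLayout = RowMajor, CShuffleBlockTransferScalarPerVector_NPerBlock = 8
//     -> requires arg.NRaw_ % 8 == 0   (N is the contiguous dimension of row-major C)
// Any other layout/vector-dim combination falls through to `return false`, so callers are
// expected to query IsSupportedArgument() and skip this instance rather than risk a wrong
// or faulting launch.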
BDataType* p_b, + const ScaleDataType* p_scale, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_scale, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_scale, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_scale), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{ + {PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}, + {PipelineVersion::weight_only, "weight_only"}}; + + // clang-format off + str << "DeviceFpAintBGemm_Wmma_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << K1 << ", " + << MPerWmma << ", " + << NPerWmma << ", " + << MRepeat << ", " + << NRepeat + << ">" + << " AEnableLds: " + << AEnableLds << ", " + << "BEnableLds: " + << BEnableLds << ", " + << "NumPrefetch: " + << NumPrefetch << ", " + << "LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp index fd90c7f1ea..a2af5d6a85 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp @@ -16,6 +16,7 @@ #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" namespace ck { namespace tensor_operation { @@ -27,21 +28,22 @@ template struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; // K1 = Max Vector Access Pixels static constexpr auto K1Number = Number{}; - static constexpr auto matrix_padder = - MatrixPadder{MPerBlock, NPerBlock, K0PerBlock* K1}; + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * 
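// [Editor's note, not part of the original patch] The *_auto flags defined around here
// encode a simple heuristic for bypassing LDS. A can skip LDS only when a single wave
// spans the N tile (NWaves == 1); otherwise several waves of the same block would re-read
// the same A rows from global memory instead of sharing them through LDS. The symmetric
// condition (MWaves == 1) applies to B. The layout test next to it (its template arguments
// were lost in this patch text) presumably requires the K-contiguous layout, i.e. row-major
// A and column-major B, so the direct global-to-register loads stay vectorized. The
// effective rule is roughly:
//
//   AEnableLds = !(NWaves == 1 && A is K-contiguous) || AEnableLds_manu || (NumPrefetch > 1);
//   BEnableLds = !(MWaves == 1 && B is K-contiguous) || BEnableLds_manu || (NumPrefetch > 1);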
NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 32 : 16; - static auto MakeAGridDescriptor_K0_M_K1(index_t MRaw, index_t KRaw, index_t StrideA) + static constexpr auto AEnableLds_auto = + (NWaves == 1 && is_same::value) ? false : true; + static constexpr auto BEnableLds_auto = + (MWaves == 1 && is_same::value) ? false : true; + + // If true, LDS is used unconditionally + static constexpr auto AEnableLds_manu = false; + static constexpr auto BEnableLds_manu = false; + + static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); + static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1); + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + // Describe how data read from Global memory + static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA) { - const auto a_grid_desc_mraw_kraw = [&]() { - if constexpr(is_same_v) + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), - make_tuple(StrideA, I1)); + const auto a_grid_desc_mraw_kraw = + make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(StrideA, I1)); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); } - else if constexpr(is_same_v) + else if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), - make_tuple(I1, StrideA)); + const auto a_grid_desc_mraw_kraw = + make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(I1, StrideA)); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); } }(); - const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); - const auto M = a_grid_desc_m_k.GetLength(I0); - const auto K = a_grid_desc_m_k.GetLength(I1); + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); assert(K % K1 == 0); - const index_t K0 = K / K1; - return transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + if constexpr(AEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto A_KRow = 2; + constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number; + const auto A_KWmma = K / WmmaK; + + const auto M0 = M / MPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple( + A_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(M0 * MRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } } - static auto MakeBGridDescriptor_K0_N_K1(index_t KRaw, index_t NRaw, index_t StrideB) + static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB) { - const auto b_grid_desc_nraw_kraw = [&]() { - if constexpr(is_same_v) + const auto b_grid_desc_n_k = [&]() { + if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(I1, 
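// [Editor's note, not part of the original patch] Shape-level view of the LDS branch above.
// transform_tensor_descriptor() only re-indexes the existing M x K view; no data is moved.
// With assumed sizes M = 256, K = 64, K1 = 8 (illustrative only):
//   input view  : (M, K)        = (256, 64)
//   unmerge K   : K -> (K0, K1) = (8, 8)
//   output view : (K0, M, K1)   = (8, 256, 8)
// K1 is the fast-varying part of K, so a blockwise copy can fetch K1 consecutive elements
// as one vectorized access, while K0 becomes the per-block k-loop trip count.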
StrideB)); + const auto b_grid_desc_nraw_kraw = + make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB)); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); } else if constexpr(is_same_v) { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(StrideB, I1)); + const auto b_grid_desc_nraw_kraw = + make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1)); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); } }(); - const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); - const auto N = b_grid_desc_n_k.GetLength(I0); - const auto K = b_grid_desc_n_k.GetLength(I1); + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); assert(K % K1 == 0); - const index_t K0 = K / K1; - return transform_tensor_descriptor( - b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + if constexpr(BEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto B_KRow = 2; + constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number; + const auto B_KWmma = K / WmmaK; + + const auto N0 = N / NPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple( + B_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(N0 * NRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } } template @@ -180,13 +252,13 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD; - using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); + using AGridDesc = decltype(MakeAGridDescriptor(1, 1, 1)); + using BGridDesc = decltype(MakeBGridDescriptor(1, 1, 1)); + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); // GridwiseOp - using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle< + using GridwiseOp = GridwiseGemmMultipleD_Wmma< // DataType Family ADataType, BDataType, @@ -195,8 +267,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD(p_b_grid)}, p_ds_grid_{}, p_e_grid_{static_cast(p_e_grid)}, - a_grid_desc_k0_m_k1_{}, - b_grid_desc_k0_n_k1_{}, + a_grid_desc{}, + b_grid_desc{}, ds_grid_desc_m_n_{}, e_grid_desc_m_n_{}, ds_grid_desc_mblock_mperblock_nblock_nperblock{}, @@ -278,8 +352,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD{}([&](auto i) { using DLayout = remove_cvref_t>; using DDataType = remove_cvref_t>; @@ -295,8 +369,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD, + remove_reference_t, + remove_reference_t< + typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + remove_reference_t< + typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + remove_reference_t, + has_main_k_block_loop>; // Last Option is W/O + + 
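// [Editor's note, not part of the original patch] The launch_kernel lambda defined here is
// instantiated once per value of has_main_k_block_loop, so whether the GEMM main loop is
// emitted at all is decided at compile time rather than branched on inside the kernel.
// Host-side dispatch then reduces to the two calls visible just below:
//
//   if(GridwiseOp::CalculateHasMainKBlockLoop(K))
//       launch_kernel(integral_constant<bool, true>{});
//   else
//       launch_kernel(integral_constant<bool, false>{});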
return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_grid_desc, + arg.b_grid_desc, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.block_2_ctile_map_); + }; if(GridwiseOp::CalculateHasMainKBlockLoop(K)) { - const auto kernel = kernel_gemm_mupltipe_d_wmma_cshuffle< - GridwiseOp, - ADataType, - BDataType, - typename GridwiseOp::DsGridPointer, - EDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, - remove_reference_t< - typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - remove_reference_t, - true>; // Last Option is W/O - - ave_time = - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_ds_grid_, - arg.p_e_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.ds_grid_desc_mblock_mperblock_nblock_nperblock, - arg.e_grid_desc_mblock_mperblock_nblock_nperblock, - arg.a_element_op_, - arg.b_element_op_, - arg.cde_element_op_, - arg.block_2_ctile_map_); + return launch_kernel(integral_constant{}); } else { - const auto kernel = kernel_gemm_mupltipe_d_wmma_cshuffle< - GridwiseOp, - ADataType, - BDataType, - typename GridwiseOp::DsGridPointer, - EDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, - remove_reference_t< - typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - remove_reference_t, - false>; - - ave_time = - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_ds_grid_, - arg.p_e_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.ds_grid_desc_mblock_mperblock_nblock_nperblock, - arg.e_grid_desc_mblock_mperblock_nblock_nperblock, - arg.a_element_op_, - arg.b_element_op_, - arg.cde_element_op_, - arg.block_2_ctile_map_); + return launch_kernel(integral_constant{}); } - - return ave_time; } // polymorphic @@ -575,8 +606,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD" - << " NumPrefetch: " + << " AEnableLds: " + << AEnableLds << ", " + << "BEnableLds: " + << BEnableLds << ", " + << "NumPrefetch: " << NumPrefetch << ", " << "LoopScheduler: " << LoopSchedToString[LoopSched] << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp index 98d14caa6d..a7f2305291 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp @@ -16,6 +16,7 @@ #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" namespace ck { namespace tensor_operation { @@ -33,13 +34,14 @@ template struct DeviceGemmWmma_CShuffle : public DeviceGemm{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; + static 
constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; // K1 = Max Vector Access Pixels static constexpr auto K1Number = Number{}; - static constexpr auto matrix_padder = - MatrixPadder{MPerBlock, NPerBlock, K0PerBlock* K1}; + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 32 : 16; - static auto MakeAGridDescriptor_K0_M_K1(index_t MRaw, index_t KRaw, index_t StrideA) + static constexpr auto AEnableLds_auto = + (NWaves == 1 && is_same::value) ? false : true; + static constexpr auto BEnableLds_auto = + (MWaves == 1 && is_same::value) ? false : true; + + // If true, LDS is used unconditionally + static constexpr auto AEnableLds_manu = false; + static constexpr auto BEnableLds_manu = false; + + static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); + static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1); + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + // Describe how data read from Global memory + static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA) { - const auto a_grid_desc_mraw_kraw = [&]() { - if constexpr(is_same_v) + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), - make_tuple(StrideA, I1)); + const auto a_grid_desc_mraw_kraw = + make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(StrideA, I1)); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); } - else if constexpr(is_same_v) + else if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), - make_tuple(I1, StrideA)); + const auto a_grid_desc_mraw_kraw = + make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(I1, StrideA)); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); } }(); - const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); - const auto M = a_grid_desc_m_k.GetLength(I0); - const auto K = a_grid_desc_m_k.GetLength(I1); + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); assert(K % K1 == 0); - const index_t K0 = K / K1; - return transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + if constexpr(AEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto A_KRow = 2; + constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number; + const auto A_KWmma = K / WmmaK; + + const auto M0 = M / MPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple( + A_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(M0 * MRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + 
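// [Editor's note, not part of the original patch] Shape-level view of this LDS-bypass
// branch, with assumed sizes (illustrative only): M = 256, K = 64, K1 = 8 (so WmmaK = 16,
// A_KRow = 2, A_K0PerWmma = 1), MPerBlock = 128, MRepeat = 2, MPerWmma = 16, hence
// MWaves = 4 and M0 = M / MPerBlock = 2. The (M, K) view is re-indexed into
//   (A_KWmma, M0 * MRepeat, MWaves, A_K0PerWmma, A_KRow, MPerWmma, A_K1)
// = (4,       4,            4,      1,           2,      16,       8)
// and every factor of the original extents is accounted for:
//   M = 256 = (M0 * MRepeat) * MWaves * MPerWmma    = 4 * 4 * 16
//   K =  64 = A_KWmma * A_K0PerWmma * A_KRow * A_K1 = 4 * 1 * 2 * 8
// which is the per-thread access pattern the WMMA instruction consumes, letting the block
// read A straight from global memory into registers.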
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } } - static auto MakeBGridDescriptor_K0_N_K1(index_t KRaw, index_t NRaw, index_t StrideB) + static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB) { - const auto b_grid_desc_nraw_kraw = [&]() { - if constexpr(is_same_v) + const auto b_grid_desc_n_k = [&]() { + if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(I1, StrideB)); + const auto b_grid_desc_nraw_kraw = + make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB)); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); } else if constexpr(is_same_v) { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(StrideB, I1)); + const auto b_grid_desc_nraw_kraw = + make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1)); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); } }(); - const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); - const auto N = b_grid_desc_n_k.GetLength(I0); - const auto K = b_grid_desc_n_k.GetLength(I1); + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); assert(K % K1 == 0); - const index_t K0 = K / K1; - return transform_tensor_descriptor( - b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + if constexpr(BEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto B_KRow = 2; + constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number; + const auto B_KWmma = K / WmmaK; + + const auto N0 = N / NPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple( + B_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(N0 * NRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } } static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) @@ -159,56 +230,58 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm; + using GridwiseGemm = + GridwiseGemm_Wmma; // Argument struct Argument : public BaseArgument @@ -230,7 +303,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + has_main_k_block_loop>; - float ave_time = 0; + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + }; if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { - const auto kernel = kernel_gemm_wmma< - GridwiseGemm, - ADataType, - BDataType, - CDataType, - 
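// [Editor's note, not part of the original patch] Launch shape used by these Run() methods:
// the grid holds one workgroup per output C tile, i.e.
//   grid_size = block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n)
// with dim3(BlockSize) threads per workgroup. The literal 0 passed to
// launch_and_time_kernel is the dynamic LDS byte count; it is zero because the gridwise
// kernels size their LDS statically inside the kernel body.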
remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - true>; // Last Option is W/O - - ave_time = launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); + return launch_kernel(integral_constant{}); } else { - const auto kernel = kernel_gemm_wmma< - GridwiseGemm, - ADataType, - BDataType, - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - false>; - - ave_time = launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); + return launch_kernel(integral_constant{}); } - - return ave_time; } // polymorphic @@ -413,13 +445,16 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm || is_same_v)) + if constexpr(!(is_same_v || is_same_v || + is_same_v)) { + printf("DeviceOp err: AccDataType"); return false; } } else { + printf("DeviceOp err: Arch"); return false; } @@ -485,7 +520,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm" - << " NumPrefetch: " + << " AEnableLds: " + << AEnableLds << ", " + << "BEnableLds: " + << BEnableLds << ", " + << "NumPrefetch: " << NumPrefetch << ", " << "LoopScheduler: " << LoopSchedToString[LoopSched] << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp index 0b3de153c3..b0e0e6da76 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp @@ -196,7 +196,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle using EGridDesc_M_N = remove_cvref_t>; // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle< + using GridwiseGemm = GridwiseGemmMultipleD_Wmma< // DataType Family ADataType, BDataType, @@ -217,7 +217,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle // Tiling Family MPerBlock, NPerBlock, - K0PerBlock, + KPerBlock, MPerWMMA, NPerWMMA, K1, @@ -232,6 +232,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, + true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, @@ -240,6 +241,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, + true, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp 
b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp index 8850b13d0a..e440eb82a4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp @@ -393,12 +393,14 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle using BGridDesc_K0_N_K1 = remove_cvref_t; using CGridDesc_M_N = remove_cvref_t; - using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle< + using CShuffleDataType = AccDataType; + + using GridwiseGemm = GridwiseGemmMultipleD_Wmma< // DataType Family ADataType, BDataType, AccDataType, - CDataType, + CShuffleDataType, Tuple<>, CDataType, // InMemory Data Descriptor @@ -414,7 +416,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle // Tiling Family MPerBlock, NPerBlock, - K0PerBlock, + KPerBlock, MPerWMMA, NPerWMMA, K1, @@ -429,6 +431,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, + true, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, @@ -437,6 +440,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, + true, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index ba2a4b0f7a..d70d462e24 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -52,22 +52,23 @@ template struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle @@ -109,11 +109,31 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle static constexpr index_t NumDTensor = DsDataType::Size(); - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr index_t KPerBlock = K0PerBlock * K1; + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + // K1 = Max Vector Access Pixels + static constexpr auto K1Number = Number{}; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = 16; + + static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true; + static constexpr auto BEnableLds_auto = MWaves == 1 ? 
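// [Editor's note, not part of the original patch] In this grouped forward-convolution
// device op the GEMM extents come from the usual implicit-GEMM mapping; for a 2-D
// convolution, per group:
//   GemmM = N * Ho * Wo   (batch x output spatial)    -> rows of A and E
//   GemmN = K             (output channels)           -> columns of B and E
//   GemmK = C * Y * X     (input channels x filter)   -> reduction dimension
// MakeAGridDescriptor below first builds that M x K view through conv_to_gemm_transformer
// and matrix_padder, then applies the same K0/K1 or WMMA-direct split used by the plain
// GEMM device ops above.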
false : true; + + // If true, LDS is used unconditionally + static constexpr auto AEnableLds_manu = true; + static constexpr auto BEnableLds_manu = true; + + static constexpr auto AEnableLds = + AEnableLds_auto || AEnableLds_manu || (NumGemmKPrefetchStage > 1); + static constexpr auto BEnableLds = + BEnableLds_auto || BEnableLds_manu || (NumGemmKPrefetchStage > 1); static constexpr auto conv_to_gemm_transformer = TransformConvFwdToGemm{}; @@ -122,17 +142,16 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; template - static auto - MakeAGridDescriptor_M_K(const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides, - const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads) + static auto MakeAGridDescriptor(const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) { const auto in_gemmmraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, @@ -149,13 +168,44 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); - return in_gemmm_gemmk_desc; + const auto M = in_gemmm_gemmk_desc.GetLength(I0); + const auto K = in_gemmm_gemmk_desc.GetLength(I1); + assert(K % K1 == 0); + + if constexpr(AEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto A_KRow = 2; + constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number; + const auto A_KWmma = K / WmmaK; + + const auto M0 = M / MPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple( + A_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(M0 * MRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } } template - static auto - MakeBGridDescriptor_N_K(const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides) + static auto MakeBGridDescriptor(const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides) { const auto wei_gemmnraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeBDescriptor_N_K(b_g_k_c_xs_lengths, @@ -164,7 +214,39 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle const auto wei_gemmn_gemmk_desc = matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); - return wei_gemmn_gemmk_desc; + const auto N = wei_gemmn_gemmk_desc.GetLength(I0); + const auto K = wei_gemmn_gemmk_desc.GetLength(I1); + assert(K % K1 == 0); + + if constexpr(BEnableLds) + { + const 
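// [Editor's note, not part of the original patch] matrix_padder rounds the raw GEMM extents
// up to multiples of the block tile (as allowed by the chosen GemmSpecialization) before the
// descriptors are split. For example, with an assumed MRaw = 1000 and MPerBlock = 128 the
// padded M becomes 1024. The K % K1 == 0 assertions above therefore see the padded extents,
// not the user's raw sizes, and the pad transforms mark the extra elements so the copy
// stages skip them instead of reading or writing out of range.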
index_t K0 = K / K1; + + return transform_tensor_descriptor( + wei_gemmn_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto B_KRow = 2; + constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number; + const auto B_KWmma = K / WmmaK; + + const auto N0 = N / NPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + wei_gemmn_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple( + B_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(N0 * NRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } } template @@ -197,53 +279,14 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle } // desc for problem definition - using AGridDesc_M_K = remove_cvref_t( - {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; - using BGridDesc_N_K = remove_cvref_t({}, {}))>; + using AGridDesc = + decltype(DeviceOp::MakeAGridDescriptor({}, {}, {}, {}, {}, {}, {}, {}, {}, {})); + using BGridDesc = decltype(DeviceOp::MakeBGridDescriptor({}, {})); using DsGridDesc_M_N = remove_cvref_t; using EGridDesc_M_N = remove_cvref_t({}, {}))>; - // A desc for source in blockwise copy - template - __host__ __device__ static constexpr auto - MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k) - { - const auto M = a_grid_desc_m_k.GetLength(I0); - const auto K = a_grid_desc_m_k.GetLength(I1); - - const auto AK1 = K1; - const auto AK0 = K / AK1; - - return transform_tensor_descriptor(a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - - // B desc for source in blockwise copy - template - __host__ __device__ static constexpr auto - MakeBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k) - { - const auto N = b_grid_desc_n_k.GetLength(I0); - const auto K = b_grid_desc_n_k.GetLength(I1); - - const auto BK1 = K1; - const auto BK0 = K / BK1; - - return transform_tensor_descriptor(b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - - using AGridDesc_AK0_M_AK1 = decltype(DeviceOp::MakeAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{})); - using BGridDesc_BK0_N_BK1 = decltype(DeviceOp::MakeBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{})); - // GridwiseOp - using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle< + using GridwiseOp = GridwiseGemmMultipleD_Wmma< // DataType Family ADataType, BDataType, @@ -252,8 +295,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle DsDataType, EDataType, // InMemory Data Descriptor - AGridDesc_AK0_M_AK1, - BGridDesc_BK0_N_BK1, + AGridDesc, + BGridDesc, DsGridDesc_M_N, EGridDesc_M_N, // ElementwiseOp Family @@ -264,9 +307,9 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle // Tiling Family MPerBlock, NPerBlock, - K0PerBlock, - MPerWMMA, - NPerWMMA, + KPerBlock, + MPerWmma, + NPerWmma, K1, MRepeat, NRepeat, @@ -279,6 +322,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, + AEnableLds, ABlockLdsExtraM, 
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, @@ -287,6 +331,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, + BEnableLds, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, @@ -327,23 +372,21 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle p_ds_grid_{}, p_e_grid_{static_cast(p_e)}, num_group_{a_g_n_c_wis_lengths[0]}, - a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads)}, - b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_g_k_c_xs_lengths, - b_g_k_c_xs_strides)}, ds_grid_desc_m_n_{}, e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, e_g_n_k_wos_strides)}, - a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, - b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + a_grid_desc_{DeviceOp::MakeAGridDescriptor(a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads)}, + b_grid_desc_{ + DeviceOp::MakeBGridDescriptor(b_g_k_c_xs_lengths, b_g_k_c_xs_strides)}, ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, e_grid_desc_mblock_mperblock_nblock_nperblock_{}, block_2_etile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01)}, @@ -395,8 +438,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle void Print() const { - std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; - std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; + std::cout << "A[M, K]: " << a_grid_desc_ << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_ << std::endl; static_for<0, NumDTensor, 1>{}( [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; @@ -411,14 +454,12 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle // tensor descriptors for problem definiton index_t num_group_; - AGridDesc_M_K a_grid_desc_m_k_; - BGridDesc_N_K b_grid_desc_n_k_; DsGridDesc_M_N ds_grid_desc_m_n_; EGridDesc_M_N e_grid_desc_m_n_; // tensor descriptors for block/thread-wise copy - AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; - BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + AGridDesc a_grid_desc_; + BGridDesc b_grid_desc_; typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_; typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock @@ -465,8 +506,17 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle const index_t grid_size = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_; - const auto K = - arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + const auto K = [&]() { + if constexpr(AEnableLds) + { + return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2); + } + else + { + return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) * + arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6); + } + }(); auto launch_kernel = [&](auto has_main_k_block_loop) { constexpr bool has_main_loop = has_main_k_block_loop.value; @@ -480,8 +530,8 @@ 
struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::AGridDesc, + DeviceOp::BGridDesc, typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, remove_reference_t, @@ -501,8 +551,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle arg.b_element_op_, arg.cde_element_op_, arg.a_g_n_c_wis_lengths_[0], // Group count - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, + arg.a_grid_desc_, + arg.b_grid_desc_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.block_2_etile_map_, @@ -670,8 +720,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle } // check Gridwise GEMM - return GridwiseOp::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, + return GridwiseOp::CheckValidity(arg.a_grid_desc_, + arg.b_grid_desc_, arg.ds_grid_desc_m_n_, arg.e_grid_desc_m_n_, arg.block_2_etile_map_); @@ -790,9 +840,19 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle << KPerBlock << ", " << getConvForwardSpecializationString(ConvForwardSpecialization) << ", " << K1 << ", " + << MPerWmma << ", " + << NPerWmma << ", " + << MRepeat << ", " + << NRepeat + << ">" + << " AEnableLds: " + << AEnableLds << ", " + << "BEnableLds: " + << BEnableLds << ", " + << "ABlockTransferSrcScalarPerVector: " << ABlockTransferSrcScalarPerVector << ", " - << BBlockTransferSrcScalarPerVector - << ">"; + << "BBlockTransferSrcScalarPerVector: " + << BBlockTransferSrcScalarPerVector; // clang-format on return str.str(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp new file mode 100644 index 0000000000..84ad48d4c7 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp @@ -0,0 +1,1254 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp" +#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Multi-Query Attention (MQA) kernel implementation +// Assume number of head of K,V is 1. 
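// [Editor's note, not part of the original patch] "Grouped-query" here means several query
// heads share one K/V head: with G1 query heads and QueryGroupNumber K/V heads, query head h
// is served by K/V head h * QueryGroupNumber / G1 (integer division), which is how the
// b0/b1 batch offsets are derived from the flattened (batch x head) block index further
// down. Assumed example with G1 = 8 and QueryGroupNumber = 2 (group size 4):
//   query head h : 0 1 2 3 4 5 6 7
//   K/V head     : 0 0 0 0 1 1 1 1
// Multi-query attention is the special case QueryGroupNumber == 1.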
+// Q [G0, G1, M, K] * K [G0, 1, K, N] = P [G0, G1, M, N] +// P [G0, G1, M, N] * V [G0, 1, N, O] = Out [G0, G1, M, O] +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_query_attention_wmma(const ADataType* __restrict__ p_a_grid, + const B0DataType* __restrict__ p_b0_grid, + const B1DataType* __restrict__ p_b1_grid, + CDataType* __restrict__ p_c_grid, + index_t M, // SequenceQ + index_t N, // SequenceK + index_t K, // HeadDim + index_t O, // SequenceK + index_t G0, // Batch + index_t G1, // HeadNum + float alpha, + bool input_permute, + bool output_permute) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) + + // clang-format off +// *************************************************** + const auto q_head = G1; + const auto kv_head = QueryGroupNumber; +// Make Tensor Descriptors + constexpr index_t array_size = 4; + std::array a_gs_ms_ks_lengths{G0, q_head, M, K}; + std::array a_gs_ms_ks_strides = + input_permute + ? std::array{M * q_head * K, K, q_head * K, 1} // A layout [G0, M, G1, K] + : std::array{q_head * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::array b0_gs_ns_ks_lengths{G0, kv_head, N, K}; + std::array b0_gs_ns_ks_strides = + input_permute + ? std::array{N * kv_head * K, K, kv_head * K, 1} // B0 layout [G0, N, 1, K] + : std::array{kv_head * N * K, N * K, K, 1}; // B0 layout [G0, 1, N, K] + + std::array b1_gs_os_ns_lengths{G0, kv_head, O, N}; + std::array b1_gs_os_ns_strides = + input_permute + ? std::array{N * kv_head * O, O, 1, kv_head * O} // B1 layout [G0, N, 1, O] + : std::array{kv_head * N * O, N * O, 1, O}; // B1 layout [G0, 1, N, O] + + std::array c_gs_ms_os_lengths{G0, q_head, M, O}; + std::array c_gs_ms_os_strides = + output_permute + ? std::array{M * q_head * O, O, q_head * O, 1} // C layout [G0, M, G1, O] + : std::array{q_head * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + const auto a_element_op = AElementwiseOperation{}; + const auto b0_element_op = B0ElementwiseOperation{}; + const auto acc0_element_op = AccElementwiseOperation{alpha}; + const auto b1_element_op = B1ElementwiseOperation{}; + const auto c_element_op = CElementwiseOperation{}; + // fail to reuse DeviceOp::MakeArgument() because of the __device__ function required. 
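// [Editor's note, not part of the original patch] The two stride sets built above describe
// the same logical [G0, G1, M, K] tensor for the two common attention memory formats:
//   input_permute == true  -> stored as [G0, M, G1, K] (batch, sequence, head, head-dim),
//                             strides {M * G1 * K, K, G1 * K, 1}
//   input_permute == false -> stored as [G0, G1, M, K] (batch, head, sequence, head-dim),
//                             strides {G1 * M * K, M * K, K, 1}
// Assumed example G0 = 2, G1 = 8, M = 512, K = 64: permuted strides {262144, 64, 512, 1},
// non-permuted strides {262144, 32768, 64, 1}. The K/V tensors use the same scheme with
// kv_head in place of G1, and output_permute plays the same role for the result.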
+ + const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto a_grid_desc_g_m_k = + DeviceOp::Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc_g_l_k = + DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc_g_n_l = + DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto compute_base_ptr_of_batch = + typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n}; + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})}; + + // clang-format on + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane(static_cast( + compute_base_ptr_of_batch.GetB0BasePtr(g_idx * QueryGroupNumber / G1))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(static_cast( + compute_base_ptr_of_batch.GetB1BasePtr(g_idx * QueryGroupNumber / G1))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + + GridwiseOp::template Run(p_a_grid + a_batch_offset, + p_b0_grid + b0_batch_offset, + p_b1_grid + b1_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_grid_desc, + b0_grid_desc, + b1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b0_element_op, + acc0_element_op, + b1_element_op, + c_element_op, + c0_matrix_mask, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b0_grid; + ignore = p_b1_grid; + ignore = p_c_grid; + ignore = M; + ignore = N; + ignore = K; + ignore = O; + ignore = G0; + ignore = G1; + ignore = input_permute; + ignore = output_permute; +#endif // end of if (defined(__gfx11__)) +} + +// Computes C = A * B0 * B1 +// MN = MK * KL * LN +// ^^^^^^ (Acc0) +// ^^^^^^^^^^^ (Acc1) +template +struct DeviceGroupedQueryAttentionForward_Wmma + : public DeviceBatchedGemmSoftmaxGemmPermute +{ + static_assert(NumDimG > 0 && NumDimM > 0 && NumDimL > 0 && NumDimK > 0 && NumDimN > 0, + "Number of dimension must be greater than 0"); + + static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); + static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); + + // TODO ANT: implement 
bias combination + static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); + + static constexpr index_t NumDimGemm0M = NumDimM; + static constexpr index_t NumDimGemm0N = NumDimL; + static constexpr index_t NumDimGemm0K = NumDimK; + static constexpr index_t NumDimGemm1M = NumDimM; + static constexpr index_t NumDimGemm1N = NumDimN; + static constexpr index_t NumDimGemm1K = NumDimL; + + using DeviceOp = DeviceGroupedQueryAttentionForward_Wmma; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + + static constexpr auto WmmaK = 16; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + + static constexpr auto AEnableLds_auto = LWaves == 1 ? false : true; + static constexpr auto B0EnableLds_auto = MWaves == 1 ? false : true; + static constexpr auto B1EnableLds_auto = MWaves == 1 ? false : true; + + static constexpr auto AEnableLds_manu = false; + static constexpr auto B0EnableLds_manu = true; + static constexpr auto B1EnableLds_manu = true; + + static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); + static constexpr auto B0EnableLds = B0EnableLds_auto || B0EnableLds_manu || (NumPrefetch > 1); + static constexpr auto B1EnableLds = B1EnableLds_auto || B1EnableLds_manu || (NumPrefetch > 1); + + using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma< + Sequence, + Sequence, + GemmSpec, + ASpec, + B0Spec, + B1Spec, + CSpec>; + + __host__ __device__ static auto MakeAGridDescriptor( + const std::array& a_gs_ms_ks_lengths_vec, + const std::array& a_gs_ms_ks_strides_vec) + { + if constexpr(AEnableLds) + { + return Transform::MakeAGridDescriptor_AK0_M_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, + a_gs_ms_ks_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + __host__ __device__ static auto MakeB0GridDescriptor( + const std::array& b0_gs_ls_ks_lengths_vec, + const std::array& b0_gs_ls_ks_strides_vec) + { + if constexpr(B0EnableLds) + { + return Transform::MakeB0GridDescriptor_BK0_N_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, + b0_gs_ls_ks_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, + b0_gs_ls_ks_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + __host__ __device__ static auto MakeB1GridDescriptor( + const std::array& b1_gs_ns_ls_lengths_vec, + const std::array& b1_gs_ns_ls_strides_vec) + { + if constexpr(B1EnableLds) + { + return Transform::MakeB1GridDescriptor_BK0_N_BK1( + Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, + b1_gs_ns_ls_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1( + 
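// [Editor's note, not part of the original patch] Dimension naming used throughout this op:
// the attention is decomposed into two chained GEMMs with a softmax in between, matching
// the "MN = MK * KL * LN" comment above:
//   Gemm0 : S[M, L]   = Q[M, K] * K^T[K, L]   (M = query length, L = key length, K = head dim)
//   softmax over L (optionally masked via MaskingSpecialization)
//   Gemm1 : Out[M, N] = P[M, L] * V[L, N]     (N = value/output head dim)
// So "K" in this file is the head dimension of the first GEMM, while the second GEMM
// reduces over L, which is why the B1 descriptors below are built over the L (key-length)
// axis rather than K.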
Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, + b1_gs_ns_ls_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + using AGridDesc = decltype(MakeAGridDescriptor({}, {})); + using B0GridDesc = decltype(MakeB0GridDescriptor({}, {})); + using B1GridDesc = decltype(MakeB1GridDescriptor({}, {})); + using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); + using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})); + using B0GridDesc_G_L_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); + using B1GridDesc_G_N_L = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); + using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); + + __host__ __device__ constexpr static auto make_MaskOutPredicate() + { + if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled) + { + return MaskDisabledPredicate{}; + } + else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) + { + return MaskOutUpperTrianglePredicate{}; + } + } + using C0MatrixMask = C0MatrixMask_impl; + + struct ComputeBasePtrOfStridedBatch + { + __host__ __device__ ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, + const B0GridDesc_G_L_K& b0_grid_desc_g_l_k, + const B1GridDesc_G_N_L& b1_grid_desc_g_n_l, + const CGridDesc_G_M_N& c_grid_desc_g_m_n) + : a_grid_desc_g_m_k_(a_grid_desc_g_m_k), + b0_grid_desc_g_l_k_(b0_grid_desc_g_l_k), + b1_grid_desc_g_n_l_(b1_grid_desc_g_n_l), + c_grid_desc_g_m_n_(c_grid_desc_g_m_n) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const + { + return b0_grid_desc_g_l_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const + { + return b1_grid_desc_g_n_l_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const + { + return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + private: + AGridDesc_G_M_K a_grid_desc_g_m_k_; + B0GridDesc_G_L_K b0_grid_desc_g_l_k_; + B1GridDesc_G_N_L b1_grid_desc_g_n_l_; + CGridDesc_G_M_N c_grid_desc_g_m_n_; + }; + + // GridwiseOp + using GridwiseOp = GridwiseBatchedGemmSoftmaxGemm_Wmma< + // DataType Family + ADataType, + B0DataType, + Acc0DataType, + B1DataType, + Acc1DataType, + CShuffleDataType, + CDataType, + // ElementwiseOp Family + AElementwiseOperation, + B0ElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + // InMemory Data Descriptor + AGridDesc, + B0GridDesc, + B1GridDesc, + CGridDesc_M_N, + // Tiling Family + MPerBlock, + LPerBlock, + KPerBlock, + AK1, + BK1, + NPerBlock, + LTilePerBlock, + L1, + MPerWmma, + LPerWmma, + NPerWmma, + MRepeat, + LRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + true, + AEnableLds, + ABlockLdsAddExtraM, + B0BlockTransferThreadClusterLengths_K0_L_K1, + B0BlockTransferThreadClusterArrangeOrder, + B0BlockTransferSrcAccessOrder, + B0BlockTransferSrcVectorDim, + B0BlockTransferSrcScalarPerVector, + 
B0BlockTransferDstScalarPerVector_K1, + true, + B0EnableLds, + B0BlockLdsAddExtraL, + B1BlockTransferThreadClusterLengths_L0_N_L1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_L1, + false, + B1EnableLds, + B1BlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + Transform::matrix_padder.PadN, + MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, + NumPrefetch, + LoopSched, + PipelineVer>; + + struct RawArg : public BaseArgument + { + RawArg(const ADataType* p_a_grid, + const B0DataType* p_b0_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t O, + index_t G0, + index_t G1, + float alpha, + bool input_permute, + bool output_permute) + : p_a_grid_{p_a_grid}, + p_b0_grid_{p_b0_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + M_{M}, + N_{N}, + K_{K}, + O_{O}, + G0_{G0}, + G1_{G1}, + alpha_{alpha}, + input_permute_{input_permute}, + output_permute_{output_permute} + { + } + // Pointers + const ADataType* p_a_grid_; + const B0DataType* p_b0_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + + // Raw Problem Size + index_t M_; + index_t N_; + index_t K_; + index_t O_; + index_t G0_; + index_t G1_; + float alpha_; + bool input_permute_; + bool output_permute_; + }; + + static auto MakeArgument(const ADataType* p_a, + const B0DataType* p_b0, + const B1DataType* p_b1, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t O, + index_t G0, + index_t G1, + float alpha, + bool input_permute, + bool output_permute) + { + return RawArg{ + p_a, p_b0, p_b1, p_c, M, N, K, O, G0, G1, alpha, input_permute, output_permute}; + } + + static bool IsSupportedArgument(const RawArg& arg) + { + if(ck::is_navi3_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc0 Type err"); + return false; + } + + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc1 Type err"); + return false; + } + } + else + { + printf("DeviceOp: Arch err"); + return false; + } + + if(arg.G1_ % QueryGroupNumber != 0) + { + return false; + } + + constexpr index_t array_size = 4; + ck::index_t G0 = arg.G0_; + ck::index_t G1 = arg.G1_; + ck::index_t M = arg.M_; + ck::index_t N = arg.N_; + ck::index_t K = arg.K_; + ck::index_t O = arg.O_; + bool input_permute = arg.input_permute_; + bool output_permute = arg.output_permute_; + + std::array a_gs_ms_ks_lengths{G0, G1, M, K}; + std::array a_gs_ms_ks_strides = + input_permute ? std::array{M * G1 * K, K, G1 * K, 1} + // A layout [G0, M, G1, K] + : std::array{ + G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::array b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::array b0_gs_ns_ks_strides = + input_permute ? std::array{N * G1 * K, K, G1 * K, 1} + // B0 layout [G0, N, G1, K] + : std::array{ + G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] + + std::array b1_gs_os_ns_lengths{G0, G1, O, N}; + std::array b1_gs_os_ns_strides = + input_permute ? std::array{N * G1 * O, O, 1, G1 * O} + // B1 layout [G0, N, G1, O] + : std::array{ + G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] + + std::array c_gs_ms_os_lengths{G0, G1, M, O}; + std::array c_gs_ms_os_strides = + output_permute ? 
std::array{M * G1 * O, O, G1 * O, 1} + // C layout [G0, M, G1, O] + : std::array{ + G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + const auto a_grid_desc = + DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + + if(!GridwiseOp::CheckValidity( + a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n, block_2_ctile_map)) + { + return false; + } + + // Check if C permute dimension matches GEMM + GEMM shape + const index_t c_g = c_grid_desc_g_m_n.GetLength(I0); // unpadded + + if(!(c_g == batch_count)) + { + printf("DeviceOp: BatchCount err"); + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O + const auto MzRaw = M; + const auto LzRaw = N; + const auto KzRaw = K; + const auto NzRaw = O; + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw; + const auto c_extent_lowest = NzRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + printf("DeviceOp: Data Transfer Vector scalar err"); + return false; + } + + std::array a_mz_kz_strides_{ + a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}; + std::array b0_lz_kz_strides_{ + b0_gs_ns_ks_strides[NumDimG + NumDimL - 1], + b0_gs_ns_ks_strides[NumDimG + NumDimL + NumDimK - 1]}; + std::array b1_nz_lz_strides_{ + b1_gs_os_ns_strides[NumDimG + NumDimN - 1], + b1_gs_os_ns_strides[NumDimG + NumDimN + NumDimL - 1]}; + std::array c_mz_nz_strides_{ + c_gs_ms_os_strides[NumDimG + NumDimM - 1], + c_gs_ms_os_strides[NumDimG + NumDimM + NumDimN - 1]}; + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? a_mz_kz_strides_[1] : a_mz_kz_strides_[0]; + const auto b0_stride_lowest = + B0BlockTransferSrcVectorDim == 2 ? b0_lz_kz_strides_[1] : b0_lz_kz_strides_[0]; + const auto b1_stride_lowest = + B1BlockTransferSrcVectorDim == 2 ? 
b1_nz_lz_strides_[1] : b1_nz_lz_strides_[0]; + const auto c_stride_lowest = c_mz_nz_strides_[1]; + + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + printf("DeviceOp: Data Vectorize transfer err"); + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + // Argument + struct Argument : public BaseArgument + { + Argument( + const ADataType* p_a_grid, + const B0DataType* p_b0_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::array& a_gs_ms_ks_lengths, + const std::array& a_gs_ms_ks_strides, + const std::array& b0_gs_ls_ks_lengths, + const std::array& b0_gs_ls_ks_strides, + const std::array& b1_gs_ns_ls_lengths, + const std::array& b1_gs_ns_ls_strides, + const std::array& c_gs_ms_ns_lengths, + const std::array& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + const index_t M01, + const index_t N01, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b0_grid_{p_b0_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc{DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, + b0_grid_desc{ + DeviceOp::MakeB0GridDescriptor(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)}, + b1_grid_desc{ + DeviceOp::MakeB1GridDescriptor(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)}, + c_grid_desc_m_n_{ + Transform::MakeCGridDescriptor_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)}, + a_grid_desc_g_m_k_{ + Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, + b0_grid_desc_g_l_k_{ + Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)}, + b1_grid_desc_g_n_l_{ + Transform::MakeB1GridDescriptor_G_N_K(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)}, + c_grid_desc_g_m_n_{ + Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)}, + a_element_op_{a_element_op}, + b0_element_op_{b0_element_op}, + acc_element_op_{acc_element_op}, + b1_element_op_{b1_element_op}, + c_element_op_{c_element_op}, + c0_matrix_mask_{b0_grid_desc_g_l_k_.GetLength(I1)}, + raw_lengths_mz_lz_kz_nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], + b0_gs_ls_ks_lengths[NumDimG + NumDimL - 1], + b0_gs_ls_ks_lengths[NumDimG + NumDimL + NumDimK - 1], + b1_gs_ns_ls_lengths[NumDimG + NumDimN - 1]}, + a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}, + b0_lz_kz_strides_{b0_gs_ls_ks_strides[NumDimG + NumDimL - 1], + b0_gs_ls_ks_strides[NumDimG + NumDimL + NumDimK - 1]}, + b1_nz_lz_strides_{b1_gs_ns_ls_strides[NumDimG + NumDimN - 1], + b1_gs_ns_ls_strides[NumDimG + NumDimN + NumDimL - 1]}, + c_mz_nz_strides_{c_gs_ms_ns_strides[NumDimG + NumDimM - 1], + c_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1]}, + batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}, + compute_ptr_offset_of_batch_{ + 
a_grid_desc_g_m_k_, b0_grid_desc_g_l_k_, b1_grid_desc_g_n_l_, c_grid_desc_g_m_n_} + { + // TODO ANT: implement bias addition + ignore = p_acc0_biases; + ignore = p_acc1_biases; + ignore = acc0_biases_gs_ms_ls_lengths; + ignore = acc0_biases_gs_ms_ls_strides; + ignore = acc1_biases_gs_ms_ns_lengths; + ignore = acc1_biases_gs_ms_ns_strides; + + if(GridwiseOp::CheckValidity( + a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n_, block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + // Pointers + const ADataType* p_a_grid_; + const B0DataType* p_b0_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + + // Tensor Descriptors + AGridDesc a_grid_desc; + B0GridDesc b0_grid_desc; + B1GridDesc b1_grid_desc; + CGridDesc_M_N c_grid_desc_m_n_; + + AGridDesc_G_M_K a_grid_desc_g_m_k_; + B0GridDesc_G_L_K b0_grid_desc_g_l_k_; + B1GridDesc_G_N_L b1_grid_desc_g_n_l_; + CGridDesc_G_M_N c_grid_desc_g_m_n_; + + typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + + // Block to Tile mapping + typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_; + + // ElementwiseOp + AElementwiseOperation a_element_op_; + B0ElementwiseOperation b0_element_op_; + AccElementwiseOperation acc_element_op_; + B1ElementwiseOperation b1_element_op_; + CElementwiseOperation c_element_op_; + + // check C0 masking and padding + C0MatrixMask c0_matrix_mask_; + + // Strides for the last M/N/K dimensions of A/B0/B1/C + // for sanity check of vector load/store + std::array raw_lengths_mz_lz_kz_nz_; + std::array a_mz_kz_strides_; + std::array b0_lz_kz_strides_; + std::array b1_nz_lz_strides_; + std::array c_mz_nz_strides_; + + index_t batch_count_; + // Batch Offset + ComputeBasePtrOfStridedBatch compute_ptr_offset_of_batch_; + }; + + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::RawArg; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto M0 = math::integer_divide_ceil(arg.M_, MPerBlock); + const auto N0 = math::integer_divide_ceil(arg.O_, NPerBlock); + + const index_t grid_size = arg.G0_ * arg.G1_ * M0 * N0; + const auto K = arg.K_; + // printf("HasKBlockLoop: %d\n", GridwiseOp::CalculateHasMainKBlockLoop(K)); + auto launch_kernel = [&](auto has_main_k_block_loop) { + const auto kernel = kernel_grouped_query_attention_wmma; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b0_grid_, + arg.p_b1_grid_, + arg.p_c_grid_, + arg.M_, + arg.N_, + arg.K_, + arg.O_, + arg.G0_, + arg.G1_, + arg.alpha_, + arg.input_permute_, + arg.output_permute_); + }; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } +#if 0 + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::is_navi3_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc0 Type err"); + return false; + } + + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc1 
Type err"); + return false; + } + } + else + { + printf("DeviceOp: Arch err"); + return false; + } + + if(!GridwiseOp::CheckValidity(arg.a_grid_desc, + arg.b0_grid_desc, + arg.b1_grid_desc, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + return false; + } + + // Check if C permute dimension matches GEMM + GEMM shape + const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded + + if(!(c_g == arg.batch_count_)) + { + printf("DeviceOp: BatchCount err"); + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O + const auto MzRaw = arg.raw_lengths_mz_lz_kz_nz_[0]; + const auto LzRaw = arg.raw_lengths_mz_lz_kz_nz_[1]; + const auto KzRaw = arg.raw_lengths_mz_lz_kz_nz_[2]; + const auto NzRaw = arg.raw_lengths_mz_lz_kz_nz_[3]; + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw; + const auto c_extent_lowest = NzRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + printf("DeviceOp: Data Transfer Vector scalar err"); + return false; + } + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0]; + const auto b0_stride_lowest = + B0BlockTransferSrcVectorDim == 2 ? arg.b0_lz_kz_strides_[1] : arg.b0_lz_kz_strides_[0]; + const auto b1_stride_lowest = + B1BlockTransferSrcVectorDim == 2 ? 
arg.b1_nz_lz_strides_[1] : arg.b1_nz_lz_strides_[0]; + const auto c_stride_lowest = arg.c_mz_nz_strides_[1]; + + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + printf("DeviceOp: Data Vectorize transfer err"); + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const ADataType* p_a, + const B0DataType* p_b0, + const B1DataType* p_b1, + CDataType* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::array& a_gs_ms_ks_lengths, + const std::array& a_gs_ms_ks_strides, + const std::array& b0_gs_ls_ks_lengths, + const std::array& b0_gs_ls_ks_strides, + const std::array& b1_gs_ns_ls_lengths, + const std::array& b1_gs_ns_ls_strides, + const std::array& c_gs_ms_ns_lengths, + const std::array& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b0, + p_b1, + p_c, + p_acc0_biases, + p_acc1_biases, + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b0_gs_ls_ks_lengths, + b0_gs_ls_ks_strides, + b1_gs_ns_ls_lengths, + b1_gs_ns_ls_strides, + c_gs_ms_ns_lengths, + c_gs_ms_ns_strides, + acc0_biases_gs_ms_ls_lengths, + acc0_biases_gs_ms_ls_strides, + acc1_biases_gs_ms_ns_lengths, + acc1_biases_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op}; + } +#endif + + // polymorphic + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b0, + const void* p_b1, + void* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b0_gs_ls_ks_lengths, + const std::vector& b0_gs_ls_ks_strides, + const std::vector& b1_gs_ns_ls_lengths, + const std::vector& b1_gs_ns_ls_strides, + const std::vector& c_gs_ms_ns_lengths, + const std::vector& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) override + { + std::array a_lengths; + std::array a_strides; + std::array b0_lengths; + std::array b0_strides; + std::array b1_lengths; + std::array b1_strides; + std::array c_lengths; + std::array c_strides; + std::transform(a_gs_ms_ks_lengths.begin(), + a_gs_ms_ks_lengths.end(), + a_lengths.begin(), + [](index_t i) { return i; }); + std::transform(a_gs_ms_ks_strides.begin(), + a_gs_ms_ks_strides.end(), + a_strides.begin(), + [](index_t i) { return i; }); + std::transform(b0_gs_ls_ks_lengths.begin(), + b0_gs_ls_ks_lengths.end(), + b0_lengths.begin(), + [](index_t i) { return i; }); + std::transform(b0_gs_ls_ks_strides.begin(), + 
b0_gs_ls_ks_strides.end(), + b0_strides.begin(), + [](index_t i) { return i; }); + std::transform(b1_gs_ns_ls_lengths.begin(), + b1_gs_ns_ls_lengths.end(), + b1_lengths.begin(), + [](index_t i) { return i; }); + std::transform(b1_gs_ns_ls_strides.begin(), + b1_gs_ns_ls_strides.end(), + b1_strides.begin(), + [](index_t i) { return i; }); + std::transform(c_gs_ms_ns_lengths.begin(), + c_gs_ms_ns_lengths.end(), + c_lengths.begin(), + [](index_t i) { return i; }); + std::transform(c_gs_ms_ns_strides.begin(), + c_gs_ms_ns_strides.end(), + c_strides.begin(), + [](index_t i) { return i; }); + return std::make_unique(static_cast(p_a), + static_cast(p_b0), + static_cast(p_b1), + static_cast(p_c), + p_acc0_biases, + p_acc1_biases, + a_lengths, + a_strides, + b0_lengths, + b0_strides, + b1_lengths, + b1_strides, + c_lengths, + c_strides, + acc0_biases_gs_ms_ls_lengths, + acc0_biases_gs_ms_ls_strides, + acc1_biases_gs_ms_ns_lengths, + acc1_biases_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op); + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGroupedQueryAttentionForward_Wmma, " + << "QueryGroupNumber: " + << QueryGroupNumber << ", " + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << LPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << LTilePerBlock << ", " + << L1 << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << "ASpec" << getTensorSpecializationString(ASpec) << ", " + << "B0Spec" << getTensorSpecializationString(B0Spec) << ", " + << "B1Spec" << getTensorSpecializationString(B1Spec) << ", " + << "CSpec" << getTensorSpecializationString(CSpec) << ", " + << getMaskingSpecializationString(MaskingSpec) + << ">" + << " AEnableLds: " + << AEnableLds << ", " + << "B0EnableLds: " + << B0EnableLds << ", " + << "B1EnableLds: " + << B1EnableLds << ", " + << "NumPrefetch: " + << NumPrefetch << ", " + << "LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp new file mode 100644 index 0000000000..b7551e78a2 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp @@ -0,0 +1,1244 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
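Editor's note (illustrative sketch, not part of this patch): before the multi-query variant below, it is worth spelling out the grouped-query K/V mapping used by the DeviceGroupedQueryAttentionForward_Wmma kernel above. Reading the offset arithmetic GetB0BasePtr(g_idx * QueryGroupNumber / G1) together with the G1 % QueryGroupNumber divisibility check, QueryGroupNumber appears to be the number of K/V heads (query groups) per batch, so each K/V head serves G1 / QueryGroupNumber consecutive query heads. The sizes in the standalone sketch are made up purely for illustration.

    // Minimal standalone sketch (plain int instead of ck::index_t to stay self-contained).
    #include <cstdio>
    int main()
    {
        const int G1               = 8; // query heads per batch (example value)
        const int QueryGroupNumber = 2; // K/V heads (query groups) per batch (example value)
        // g_idx enumerates (batch, q_head) pairs: g_idx = batch * G1 + q_head.
        for(int g_idx = 0; g_idx < 2 * G1; ++g_idx) // two batches
        {
            // Same integer arithmetic as GetB0BasePtr()/GetB1BasePtr() in the kernel above.
            const int kv_idx = g_idx * QueryGroupNumber / G1;
            std::printf("g_idx %2d (batch %d, q_head %d) -> K/V page %d\n",
                        g_idx, g_idx / G1, g_idx % G1, kv_idx);
        }
        // With these numbers: q_heads 0-3 of batch 0 share K/V page 0, q_heads 4-7 share
        // page 1, batch 1 uses pages 2 and 3, and so on.
        return 0;
    }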
+ +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp" +#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Multi-Query Attention (MQA) kernel implementation +// Assume number of head of K,V is 1. +// Q [G0, G1, M, K] * K [G0, 1, K, N] = P [G0, G1, M, N] +// P [G0, G1, M, N] * V [G0, 1, N, O] = Out [G0, G1, M, O] +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_multi_query_attention_wmma(const ADataType* __restrict__ p_a_grid, + const B0DataType* __restrict__ p_b0_grid, + const B1DataType* __restrict__ p_b1_grid, + CDataType* __restrict__ p_c_grid, + index_t M, // SequenceQ + index_t N, // SequenceK + index_t K, // HeadDim + index_t O, // SequenceK + index_t G0, // Batch + index_t G1, // HeadNum + float alpha, + bool input_permute, + bool output_permute) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) + + // clang-format off +// *************************************************** + const auto q_head = G1; + const auto kv_head = 1; +// Make Tensor Descriptors + constexpr index_t array_size = 4; + std::array a_gs_ms_ks_lengths{G0, q_head, M, K}; + std::array a_gs_ms_ks_strides = + input_permute + ? std::array{M * q_head * K, K, q_head * K, 1} // A layout [G0, M, G1, K] + : std::array{q_head * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::array b0_gs_ns_ks_lengths{G0, kv_head, N, K}; + std::array b0_gs_ns_ks_strides = + input_permute + ? std::array{N * kv_head * K, K, kv_head * K, 1} // B0 layout [G0, N, 1, K] + : std::array{kv_head * N * K, N * K, K, 1}; // B0 layout [G0, 1, N, K] + + std::array b1_gs_os_ns_lengths{G0, kv_head, O, N}; + std::array b1_gs_os_ns_strides = + input_permute + ? std::array{N * kv_head * O, O, 1, kv_head * O} // B1 layout [G0, N, 1, O] + : std::array{kv_head * N * O, N * O, 1, O}; // B1 layout [G0, 1, N, O] + + std::array c_gs_ms_os_lengths{G0, q_head, M, O}; + std::array c_gs_ms_os_strides = + output_permute + ? std::array{M * q_head * O, O, q_head * O, 1} // C layout [G0, M, G1, O] + : std::array{q_head * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + const auto a_element_op = AElementwiseOperation{}; + const auto b0_element_op = B0ElementwiseOperation{}; + const auto acc0_element_op = AccElementwiseOperation{alpha}; + const auto b1_element_op = B1ElementwiseOperation{}; + const auto c_element_op = CElementwiseOperation{}; + // fail to reuse DeviceOp::MakeArgument() because of the __device__ function required. 
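Editor's note (illustrative sketch, not part of this patch): the length/stride tables above encode two host layouts per tensor, selected by input_permute/output_permute. The permuted case stores the sequence dimension outside the head dimension (e.g. Q as [G0, M, G1, K]); the non-permuted case stores [G0, G1, M, K]. For K and V the same pattern applies with kv_head == 1, so every query head of a batch addresses the same K/V data. The sizes below are assumptions chosen only to show how an element is addressed.

    // Minimal standalone sketch of the permuted Q addressing (strides {M*G1*K, K, G1*K, 1}).
    #include <array>
    #include <cstdio>
    int main()
    {
        const long G1 = 16, M = 512, K = 64; // example sizes (G0, the batch count, does not enter the strides)
        const std::array<long, 4> q_strides{M * G1 * K, K, G1 * K, 1}; // storage layout [G0, M, G1, K]
        // Logical index (g0, g1, m, k), as used by the op, mapped to a linear element offset:
        auto offset = [&](long g0, long g1, long m, long k) {
            return g0 * q_strides[0] + g1 * q_strides[1] + m * q_strides[2] + k * q_strides[3];
        };
        std::printf("Q(1, 3, 10, 5) -> linear offset %ld\n", offset(1, 3, 10, 5)); // 534725
        return 0;
    }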
+ + const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto a_grid_desc_g_m_k = + DeviceOp::Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc_g_l_k = + DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc_g_n_l = + DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto compute_base_ptr_of_batch = + typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n}; + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})}; + + // clang-format on + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB0BasePtr(g_idx / G1))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx / G1))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + + GridwiseOp::template Run(p_a_grid + a_batch_offset, + p_b0_grid + b0_batch_offset, + p_b1_grid + b1_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_grid_desc, + b0_grid_desc, + b1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b0_element_op, + acc0_element_op, + b1_element_op, + c_element_op, + c0_matrix_mask, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b0_grid; + ignore = p_b1_grid; + ignore = p_c_grid; + ignore = M; + ignore = N; + ignore = K; + ignore = O; + ignore = G0; + ignore = G1; + ignore = input_permute; + ignore = output_permute; +#endif // end of if (defined(__gfx11__)) +} + +// Computes C = A * B0 * B1 +// MN = MK * KL * LN +// ^^^^^^ (Acc0) +// ^^^^^^^^^^^ (Acc1) +template +struct DeviceMultiQueryAttentionForward_Wmma + : public DeviceBatchedGemmSoftmaxGemmPermute +{ + static_assert(NumDimG > 0 && NumDimM > 0 && NumDimL > 0 && NumDimK > 0 && NumDimN > 0, + "Number of dimension must be greater than 0"); + + static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); + static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); + + // TODO ANT: implement bias combination + 
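Editor's note (illustrative sketch, not part of this patch): in the multi-query kernel above, K and V carry a single head (kv_head == 1), so the B0/B1 batch offsets are taken at g_idx / G1 while A and C use g_idx directly; every query head of a batch therefore reads the same K/V pages. A standalone sketch of that mapping with assumed sizes follows (contrast with the grouped-query variant earlier, which scales by QueryGroupNumber instead of dividing down to one head).

    #include <cstdio>
    int main()
    {
        const int G0 = 4, G1 = 8; // example: 4 batches, 8 query heads per batch
        for(int g_idx = 0; g_idx < G0 * G1; ++g_idx) // one entry per (batch, q_head)
        {
            const int kv_batch = g_idx / G1; // same integer math as GetB0BasePtr()/GetB1BasePtr() above
            if(g_idx % G1 == 0)
                std::printf("batch %d: q_heads 0..%d all read K/V of batch %d\n",
                            g_idx / G1, G1 - 1, kv_batch);
        }
        return 0;
    }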
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); + + static constexpr index_t NumDimGemm0M = NumDimM; + static constexpr index_t NumDimGemm0N = NumDimL; + static constexpr index_t NumDimGemm0K = NumDimK; + static constexpr index_t NumDimGemm1M = NumDimM; + static constexpr index_t NumDimGemm1N = NumDimN; + static constexpr index_t NumDimGemm1K = NumDimL; + + using DeviceOp = DeviceMultiQueryAttentionForward_Wmma; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + + static constexpr auto WmmaK = 16; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + + static constexpr auto AEnableLds_auto = LWaves == 1 ? false : true; + static constexpr auto B0EnableLds_auto = MWaves == 1 ? false : true; + static constexpr auto B1EnableLds_auto = MWaves == 1 ? false : true; + + static constexpr auto AEnableLds_manu = false; + static constexpr auto B0EnableLds_manu = true; + static constexpr auto B1EnableLds_manu = true; + + static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); + static constexpr auto B0EnableLds = B0EnableLds_auto || B0EnableLds_manu || (NumPrefetch > 1); + static constexpr auto B1EnableLds = B1EnableLds_auto || B1EnableLds_manu || (NumPrefetch > 1); + + using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma< + Sequence, + Sequence, + GemmSpec, + ASpec, + B0Spec, + B1Spec, + CSpec>; + + __host__ __device__ static auto MakeAGridDescriptor( + const std::array& a_gs_ms_ks_lengths_vec, + const std::array& a_gs_ms_ks_strides_vec) + { + if constexpr(AEnableLds) + { + return Transform::MakeAGridDescriptor_AK0_M_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, + a_gs_ms_ks_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + __host__ __device__ static auto MakeB0GridDescriptor( + const std::array& b0_gs_ls_ks_lengths_vec, + const std::array& b0_gs_ls_ks_strides_vec) + { + if constexpr(B0EnableLds) + { + return Transform::MakeB0GridDescriptor_BK0_N_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, + b0_gs_ls_ks_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, + b0_gs_ls_ks_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + __host__ __device__ static auto MakeB1GridDescriptor( + const std::array& b1_gs_ns_ls_lengths_vec, + const std::array& b1_gs_ns_ls_strides_vec) + { + if constexpr(B1EnableLds) + { + return Transform::MakeB1GridDescriptor_BK0_N_BK1( + Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, + b1_gs_ns_ls_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1( + Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, + 
b1_gs_ns_ls_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + using AGridDesc = decltype(MakeAGridDescriptor({}, {})); + using B0GridDesc = decltype(MakeB0GridDescriptor({}, {})); + using B1GridDesc = decltype(MakeB1GridDescriptor({}, {})); + using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); + using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})); + using B0GridDesc_G_L_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); + using B1GridDesc_G_N_L = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); + using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); + + __host__ __device__ constexpr static auto make_MaskOutPredicate() + { + if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled) + { + return MaskDisabledPredicate{}; + } + else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) + { + return MaskOutUpperTrianglePredicate{}; + } + } + using C0MatrixMask = C0MatrixMask_impl; + + struct ComputeBasePtrOfStridedBatch + { + __host__ __device__ ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, + const B0GridDesc_G_L_K& b0_grid_desc_g_l_k, + const B1GridDesc_G_N_L& b1_grid_desc_g_n_l, + const CGridDesc_G_M_N& c_grid_desc_g_m_n) + : a_grid_desc_g_m_k_(a_grid_desc_g_m_k), + b0_grid_desc_g_l_k_(b0_grid_desc_g_l_k), + b1_grid_desc_g_n_l_(b1_grid_desc_g_n_l), + c_grid_desc_g_m_n_(c_grid_desc_g_m_n) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const + { + return b0_grid_desc_g_l_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const + { + return b1_grid_desc_g_n_l_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const + { + return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + private: + AGridDesc_G_M_K a_grid_desc_g_m_k_; + B0GridDesc_G_L_K b0_grid_desc_g_l_k_; + B1GridDesc_G_N_L b1_grid_desc_g_n_l_; + CGridDesc_G_M_N c_grid_desc_g_m_n_; + }; + + // GridwiseOp + using GridwiseOp = GridwiseBatchedGemmSoftmaxGemm_Wmma< + // DataType Family + ADataType, + B0DataType, + Acc0DataType, + B1DataType, + Acc1DataType, + CShuffleDataType, + CDataType, + // ElementwiseOp Family + AElementwiseOperation, + B0ElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + // InMemory Data Descriptor + AGridDesc, + B0GridDesc, + B1GridDesc, + CGridDesc_M_N, + // Tiling Family + MPerBlock, + LPerBlock, + KPerBlock, + AK1, + BK1, + NPerBlock, + LTilePerBlock, + L1, + MPerWmma, + LPerWmma, + NPerWmma, + MRepeat, + LRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + true, + AEnableLds, + ABlockLdsAddExtraM, + B0BlockTransferThreadClusterLengths_K0_L_K1, + B0BlockTransferThreadClusterArrangeOrder, + B0BlockTransferSrcAccessOrder, + B0BlockTransferSrcVectorDim, + B0BlockTransferSrcScalarPerVector, + B0BlockTransferDstScalarPerVector_K1, + true, + B0EnableLds, + 
B0BlockLdsAddExtraL, + B1BlockTransferThreadClusterLengths_L0_N_L1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_L1, + false, + B1EnableLds, + B1BlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + Transform::matrix_padder.PadN, + MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, + NumPrefetch, + LoopSched, + PipelineVer>; + + struct RawArg : public BaseArgument + { + RawArg(const ADataType* p_a_grid, + const B0DataType* p_b0_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t O, + index_t G0, + index_t G1, + float alpha, + bool input_permute, + bool output_permute) + : p_a_grid_{p_a_grid}, + p_b0_grid_{p_b0_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + M_{M}, + N_{N}, + K_{K}, + O_{O}, + G0_{G0}, + G1_{G1}, + alpha_{alpha}, + input_permute_{input_permute}, + output_permute_{output_permute} + { + } + // Pointers + const ADataType* p_a_grid_; + const B0DataType* p_b0_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + + // Raw Problem Size + index_t M_; + index_t N_; + index_t K_; + index_t O_; + index_t G0_; + index_t G1_; + float alpha_; + bool input_permute_; + bool output_permute_; + }; + + static auto MakeArgument(const ADataType* p_a, + const B0DataType* p_b0, + const B1DataType* p_b1, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t O, + index_t G0, + index_t G1, + float alpha, + bool input_permute, + bool output_permute) + { + return RawArg{ + p_a, p_b0, p_b1, p_c, M, N, K, O, G0, G1, alpha, input_permute, output_permute}; + } + + static bool IsSupportedArgument(const RawArg& arg) + { + if(ck::is_navi3_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc0 Type err"); + return false; + } + + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc1 Type err"); + return false; + } + } + else + { + printf("DeviceOp: Arch err"); + return false; + } + + constexpr index_t array_size = 4; + ck::index_t G0 = arg.G0_; + ck::index_t G1 = arg.G1_; + ck::index_t M = arg.M_; + ck::index_t N = arg.N_; + ck::index_t K = arg.K_; + ck::index_t O = arg.O_; + bool input_permute = arg.input_permute_; + bool output_permute = arg.output_permute_; + + std::array a_gs_ms_ks_lengths{G0, G1, M, K}; + std::array a_gs_ms_ks_strides = + input_permute ? std::array{M * G1 * K, K, G1 * K, 1} + // A layout [G0, M, G1, K] + : std::array{ + G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::array b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::array b0_gs_ns_ks_strides = + input_permute ? std::array{N * G1 * K, K, G1 * K, 1} + // B0 layout [G0, N, G1, K] + : std::array{ + G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] + + std::array b1_gs_os_ns_lengths{G0, G1, O, N}; + std::array b1_gs_os_ns_strides = + input_permute ? std::array{N * G1 * O, O, 1, G1 * O} + // B1 layout [G0, N, G1, O] + : std::array{ + G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] + + std::array c_gs_ms_os_lengths{G0, G1, M, O}; + std::array c_gs_ms_os_strides = + output_permute ? 
std::array{M * G1 * O, O, G1 * O, 1} + // C layout [G0, M, G1, O] + : std::array{ + G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + const auto a_grid_desc = + DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + + if(!GridwiseOp::CheckValidity( + a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n, block_2_ctile_map)) + { + return false; + } + + // Check if C permute dimension matches GEMM + GEMM shape + const index_t c_g = c_grid_desc_g_m_n.GetLength(I0); // unpadded + + if(!(c_g == batch_count)) + { + printf("DeviceOp: BatchCount err"); + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O + const auto MzRaw = M; + const auto LzRaw = N; + const auto KzRaw = K; + const auto NzRaw = O; + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw; + const auto c_extent_lowest = NzRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + printf("DeviceOp: Data Transfer Vector scalar err"); + return false; + } + + std::array a_mz_kz_strides_{ + a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}; + std::array b0_lz_kz_strides_{ + b0_gs_ns_ks_strides[NumDimG + NumDimL - 1], + b0_gs_ns_ks_strides[NumDimG + NumDimL + NumDimK - 1]}; + std::array b1_nz_lz_strides_{ + b1_gs_os_ns_strides[NumDimG + NumDimN - 1], + b1_gs_os_ns_strides[NumDimG + NumDimN + NumDimL - 1]}; + std::array c_mz_nz_strides_{ + c_gs_ms_os_strides[NumDimG + NumDimM - 1], + c_gs_ms_os_strides[NumDimG + NumDimM + NumDimN - 1]}; + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? a_mz_kz_strides_[1] : a_mz_kz_strides_[0]; + const auto b0_stride_lowest = + B0BlockTransferSrcVectorDim == 2 ? b0_lz_kz_strides_[1] : b0_lz_kz_strides_[0]; + const auto b1_stride_lowest = + B1BlockTransferSrcVectorDim == 2 ? 
b1_nz_lz_strides_[1] : b1_nz_lz_strides_[0]; + const auto c_stride_lowest = c_mz_nz_strides_[1]; + + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + printf("DeviceOp: Data Vectorize transfer err"); + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + // Argument + struct Argument : public BaseArgument + { + Argument( + const ADataType* p_a_grid, + const B0DataType* p_b0_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::array& a_gs_ms_ks_lengths, + const std::array& a_gs_ms_ks_strides, + const std::array& b0_gs_ls_ks_lengths, + const std::array& b0_gs_ls_ks_strides, + const std::array& b1_gs_ns_ls_lengths, + const std::array& b1_gs_ns_ls_strides, + const std::array& c_gs_ms_ns_lengths, + const std::array& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + const index_t M01, + const index_t N01, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b0_grid_{p_b0_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc{DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, + b0_grid_desc{ + DeviceOp::MakeB0GridDescriptor(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)}, + b1_grid_desc{ + DeviceOp::MakeB1GridDescriptor(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)}, + c_grid_desc_m_n_{ + Transform::MakeCGridDescriptor_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)}, + a_grid_desc_g_m_k_{ + Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, + b0_grid_desc_g_l_k_{ + Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)}, + b1_grid_desc_g_n_l_{ + Transform::MakeB1GridDescriptor_G_N_K(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)}, + c_grid_desc_g_m_n_{ + Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)}, + a_element_op_{a_element_op}, + b0_element_op_{b0_element_op}, + acc_element_op_{acc_element_op}, + b1_element_op_{b1_element_op}, + c_element_op_{c_element_op}, + c0_matrix_mask_{b0_grid_desc_g_l_k_.GetLength(I1)}, + raw_lengths_mz_lz_kz_nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], + b0_gs_ls_ks_lengths[NumDimG + NumDimL - 1], + b0_gs_ls_ks_lengths[NumDimG + NumDimL + NumDimK - 1], + b1_gs_ns_ls_lengths[NumDimG + NumDimN - 1]}, + a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}, + b0_lz_kz_strides_{b0_gs_ls_ks_strides[NumDimG + NumDimL - 1], + b0_gs_ls_ks_strides[NumDimG + NumDimL + NumDimK - 1]}, + b1_nz_lz_strides_{b1_gs_ns_ls_strides[NumDimG + NumDimN - 1], + b1_gs_ns_ls_strides[NumDimG + NumDimN + NumDimL - 1]}, + c_mz_nz_strides_{c_gs_ms_ns_strides[NumDimG + NumDimM - 1], + c_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1]}, + batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}, + compute_ptr_offset_of_batch_{ + 
a_grid_desc_g_m_k_, b0_grid_desc_g_l_k_, b1_grid_desc_g_n_l_, c_grid_desc_g_m_n_} + { + // TODO ANT: implement bias addition + ignore = p_acc0_biases; + ignore = p_acc1_biases; + ignore = acc0_biases_gs_ms_ls_lengths; + ignore = acc0_biases_gs_ms_ls_strides; + ignore = acc1_biases_gs_ms_ns_lengths; + ignore = acc1_biases_gs_ms_ns_strides; + + if(GridwiseOp::CheckValidity( + a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n_, block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + // Pointers + const ADataType* p_a_grid_; + const B0DataType* p_b0_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + + // Tensor Descriptors + AGridDesc a_grid_desc; + B0GridDesc b0_grid_desc; + B1GridDesc b1_grid_desc; + CGridDesc_M_N c_grid_desc_m_n_; + + AGridDesc_G_M_K a_grid_desc_g_m_k_; + B0GridDesc_G_L_K b0_grid_desc_g_l_k_; + B1GridDesc_G_N_L b1_grid_desc_g_n_l_; + CGridDesc_G_M_N c_grid_desc_g_m_n_; + + typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + + // Block to Tile mapping + typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_; + + // ElementwiseOp + AElementwiseOperation a_element_op_; + B0ElementwiseOperation b0_element_op_; + AccElementwiseOperation acc_element_op_; + B1ElementwiseOperation b1_element_op_; + CElementwiseOperation c_element_op_; + + // check C0 masking and padding + C0MatrixMask c0_matrix_mask_; + + // Strides for the last M/N/K dimensions of A/B0/B1/C + // for sanity check of vector load/store + std::array raw_lengths_mz_lz_kz_nz_; + std::array a_mz_kz_strides_; + std::array b0_lz_kz_strides_; + std::array b1_nz_lz_strides_; + std::array c_mz_nz_strides_; + + index_t batch_count_; + // Batch Offset + ComputeBasePtrOfStridedBatch compute_ptr_offset_of_batch_; + }; + + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::RawArg; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto M0 = math::integer_divide_ceil(arg.M_, MPerBlock); + const auto N0 = math::integer_divide_ceil(arg.O_, NPerBlock); + + const index_t grid_size = arg.G0_ * arg.G1_ * M0 * N0; + const auto K = arg.K_; + // printf("HasKBlockLoop: %d\n", GridwiseOp::CalculateHasMainKBlockLoop(K)); + auto launch_kernel = [&](auto has_main_k_block_loop) { + const auto kernel = kernel_multi_query_attention_wmma; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b0_grid_, + arg.p_b1_grid_, + arg.p_c_grid_, + arg.M_, + arg.N_, + arg.K_, + arg.O_, + arg.G0_, + arg.G1_, + arg.alpha_, + arg.input_permute_, + arg.output_permute_); + }; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } +#if 0 + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::is_navi3_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc0 Type err"); + return false; + } + + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc1 
Type err"); + return false; + } + } + else + { + printf("DeviceOp: Arch err"); + return false; + } + + if(!GridwiseOp::CheckValidity(arg.a_grid_desc, + arg.b0_grid_desc, + arg.b1_grid_desc, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + return false; + } + + // Check if C permute dimension matches GEMM + GEMM shape + const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded + + if(!(c_g == arg.batch_count_)) + { + printf("DeviceOp: BatchCount err"); + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O + const auto MzRaw = arg.raw_lengths_mz_lz_kz_nz_[0]; + const auto LzRaw = arg.raw_lengths_mz_lz_kz_nz_[1]; + const auto KzRaw = arg.raw_lengths_mz_lz_kz_nz_[2]; + const auto NzRaw = arg.raw_lengths_mz_lz_kz_nz_[3]; + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw; + const auto c_extent_lowest = NzRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + printf("DeviceOp: Data Transfer Vector scalar err"); + return false; + } + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0]; + const auto b0_stride_lowest = + B0BlockTransferSrcVectorDim == 2 ? arg.b0_lz_kz_strides_[1] : arg.b0_lz_kz_strides_[0]; + const auto b1_stride_lowest = + B1BlockTransferSrcVectorDim == 2 ? 
arg.b1_nz_lz_strides_[1] : arg.b1_nz_lz_strides_[0]; + const auto c_stride_lowest = arg.c_mz_nz_strides_[1]; + + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + printf("DeviceOp: Data Vectorize transfer err"); + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const ADataType* p_a, + const B0DataType* p_b0, + const B1DataType* p_b1, + CDataType* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::array& a_gs_ms_ks_lengths, + const std::array& a_gs_ms_ks_strides, + const std::array& b0_gs_ls_ks_lengths, + const std::array& b0_gs_ls_ks_strides, + const std::array& b1_gs_ns_ls_lengths, + const std::array& b1_gs_ns_ls_strides, + const std::array& c_gs_ms_ns_lengths, + const std::array& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b0, + p_b1, + p_c, + p_acc0_biases, + p_acc1_biases, + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b0_gs_ls_ks_lengths, + b0_gs_ls_ks_strides, + b1_gs_ns_ls_lengths, + b1_gs_ns_ls_strides, + c_gs_ms_ns_lengths, + c_gs_ms_ns_strides, + acc0_biases_gs_ms_ls_lengths, + acc0_biases_gs_ms_ls_strides, + acc1_biases_gs_ms_ns_lengths, + acc1_biases_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op}; + } +#endif + + // polymorphic + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b0, + const void* p_b1, + void* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b0_gs_ls_ks_lengths, + const std::vector& b0_gs_ls_ks_strides, + const std::vector& b1_gs_ns_ls_lengths, + const std::vector& b1_gs_ns_ls_strides, + const std::vector& c_gs_ms_ns_lengths, + const std::vector& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) override + { + std::array a_lengths; + std::array a_strides; + std::array b0_lengths; + std::array b0_strides; + std::array b1_lengths; + std::array b1_strides; + std::array c_lengths; + std::array c_strides; + std::transform(a_gs_ms_ks_lengths.begin(), + a_gs_ms_ks_lengths.end(), + a_lengths.begin(), + [](index_t i) { return i; }); + std::transform(a_gs_ms_ks_strides.begin(), + a_gs_ms_ks_strides.end(), + a_strides.begin(), + [](index_t i) { return i; }); + std::transform(b0_gs_ls_ks_lengths.begin(), + b0_gs_ls_ks_lengths.end(), + b0_lengths.begin(), + [](index_t i) { return i; }); + std::transform(b0_gs_ls_ks_strides.begin(), + 
b0_gs_ls_ks_strides.end(), + b0_strides.begin(), + [](index_t i) { return i; }); + std::transform(b1_gs_ns_ls_lengths.begin(), + b1_gs_ns_ls_lengths.end(), + b1_lengths.begin(), + [](index_t i) { return i; }); + std::transform(b1_gs_ns_ls_strides.begin(), + b1_gs_ns_ls_strides.end(), + b1_strides.begin(), + [](index_t i) { return i; }); + std::transform(c_gs_ms_ns_lengths.begin(), + c_gs_ms_ns_lengths.end(), + c_lengths.begin(), + [](index_t i) { return i; }); + std::transform(c_gs_ms_ns_strides.begin(), + c_gs_ms_ns_strides.end(), + c_strides.begin(), + [](index_t i) { return i; }); + return std::make_unique(static_cast(p_a), + static_cast(p_b0), + static_cast(p_b1), + static_cast(p_c), + p_acc0_biases, + p_acc1_biases, + a_lengths, + a_strides, + b0_lengths, + b0_strides, + b1_lengths, + b1_strides, + c_lengths, + c_strides, + acc0_biases_gs_ms_ls_lengths, + acc0_biases_gs_ms_ls_strides, + acc1_biases_gs_ms_ns_lengths, + acc1_biases_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op); + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceMultiQueryAttentionForward_Wmma" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << LPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << LTilePerBlock << ", " + << L1 << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << "ASpec" << getTensorSpecializationString(ASpec) << ", " + << "B0Spec" << getTensorSpecializationString(B0Spec) << ", " + << "B1Spec" << getTensorSpecializationString(B1Spec) << ", " + << "CSpec" << getTensorSpecializationString(CSpec) << ", " + << getMaskingSpecializationString(MaskingSpec) + << ">" + << " AEnableLds: " + << AEnableLds << ", " + << "B0EnableLds: " + << B0EnableLds << ", " + << "B1EnableLds: " + << B1EnableLds << ", " + << "NumPrefetch: " + << NumPrefetch << ", " + << "LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/masking_specialization.hpp b/include/ck/tensor_operation/gpu/device/masking_specialization.hpp index d6d6f74abd..0ec55984bc 100644 --- a/include/ck/tensor_operation/gpu/device/masking_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/masking_specialization.hpp @@ -53,7 +53,10 @@ struct MaskOutUpperTrianglePredicate template struct C0MatrixMask_impl { - C0MatrixMask_impl(index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {} + __host__ __device__ C0MatrixMask_impl(index_t NRaw) + : NRaw_(NRaw), predicate_(MaskOutPredicate{}) + { + } __host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ index_t n) const { diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 33c2cb6c6d..c6d933893e 100644 --- 
a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -123,6 +123,12 @@ struct PassThrough y = type_convert(x); } + template <> + __host__ __device__ void operator()(uint8_t& y, const uint8_t& x) const + { + y = x; + } + template <> __host__ __device__ void operator()(int8_t& y, const int32_t& x) const { @@ -663,6 +669,76 @@ struct Elu const float alpha_; }; +// support fastconvert of int8 to fp16 + +template +struct FastNumericArrayConverter +{ +}; + +template <> +struct FastNumericArrayConverter +{ + using InputArray = vector_type; + using OutputArray = vector_type; + + __device__ static OutputArray convert(InputArray const& Input) + { + OutputArray Output; + + uint32_t* half_2 = reinterpret_cast(&Output); + uint32_t const uint8_4 = reinterpret_cast(Input); + + static constexpr uint32_t byte_selector_01 = 0x05010500; + static constexpr uint32_t byte_selector_23 = 0x05030502; + static constexpr uint32_t fp16_adder = 0x64646464; + half_2[0] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_01); + half_2[1] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_23); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]" + : "=v"(half_2[0]) + : "v"(half_2[0]), "s"(I8s_TO_F16s_MAGIC_NUM)); + asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]" + : "=v"(half_2[1]) + : "v"(half_2[1]), "s"(I8s_TO_F16s_MAGIC_NUM)); + + return Output; + } + + __device__ OutputArray operator()(InputArray const& Input) { return convert(Input); } +}; + +template +struct FastNumericArrayConverter +{ + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); + + using InputArray = vector_type; + using OutputArray = vector_type; + + __device__ static OutputArray convert(InputArray const& Input) + { + FastNumericArrayConverter converter; + + OutputArray Output; + + using Vec_InputArray = vector_type; + using Vec_OutputArray = vector_type; + + Vec_OutputArray* half_4_ptr = reinterpret_cast(&Output); + Vec_InputArray const* uint8_4_ptr = reinterpret_cast(&Input); + + static_for<0, N / VEC_WIDTH, 1>{}( + [&](auto i) { half_4_ptr[i] = converter(uint8_4_ptr[i]); }); + + return Output; + } + + __device__ OutputArray operator()(InputArray const& Input) { return convert(Input); } +}; + } // namespace element_wise } // namespace tensor_operation } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp index a0924ae3b0..42f7c2a33f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp @@ -116,7 +116,7 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; // ck::Tuple static constexpr auto MakeD0sGridPointer() diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp new file mode 100644 index 0000000000..16717ff819 --- /dev/null +++ 
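The FastNumericArrayConverter added above relies on a standard fp16 bit trick: v_perm_b32 drops each source byte into the low half of a 16-bit lane whose upper byte is 0x64, so the lane holds the half pattern 0x64XX, which encodes exactly 1024 + XX (the fp16 spacing across [1024, 2048) is 1); the packed v_pk_add_f16 with its negate modifiers then subtracts the magic constant 0x6480 (1152.0), leaving XX - 128 in each lane without any integer-to-float instruction. A scalar host-side sketch of the same arithmetic, for illustration only (the function name is mine, and float stands in for fp16):

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    // Model one lane of the packed conversion: encode 0x6400 | x, decode the
    // fp16 pattern by hand, then remove the 0x6480 (1152.0) bias.
    float fast_byte_to_half_model(uint8_t x)
    {
        uint16_t bits = static_cast<uint16_t>(0x6400u | x); // encodes 1024 + x
        int      expo = (bits >> 10) & 0x1F;
        float    mant = 1.0f + static_cast<float>(bits & 0x3FF) / 1024.0f;
        return std::ldexp(mant, expo - 15) - 1152.0f;
    }

    int main()
    {
        for(int x : {0, 1, 127, 128, 255})
            std::printf("byte %3d -> %6.1f\n", x, fast_byte_to_half_model(static_cast<uint8_t>(x)));
        return 0;
    }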
b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp @@ -0,0 +1,1596 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp" + +namespace ck { + +// Gemm0: A [M x K] x B0 [K x L] = Acc [M x L] +// Gemm1: Acc [M x L] x B1 [L x N] = C [M x N] +template +struct GridwiseBatchedGemmSoftmaxGemm_Wmma +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + static constexpr auto AK1 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto BK1 = Number{}; + + static constexpr auto L0PerBlock = LTilePerBlock / L1Value; + static constexpr auto AL0 = Number{}; + static constexpr auto AL1 = Number{}; + static constexpr auto BL0 = Number{}; + static constexpr auto BL1 = Number{}; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = 16; + static constexpr auto WmmaL = 16; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = + remove_cvref_t())>; + + __host__ __device__ static constexpr auto MakeABlockDescriptor() + { + constexpr auto a_block_desc = [&]() { + if constexpr(AEnableLds) + { + // K0->M->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / AK1; + constexpr auto max_lds_align = AK1; + + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, AK1), max_lds_align); + } + } + else + { + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / 2 / AK1; + // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + AK1), + make_tuple(Number{} * Number{} * AK1, + Number{} * AK1, + Number{} * AK1, + AK1, + AK1, + AK1, + I1)); + } + }(); + + return a_block_desc; + } + + __host__ __device__ static constexpr auto MakeB0BlockDescriptor() + { + constexpr auto b0_block_desc = [&]() { + if constexpr(B0EnableLds) + { + // K0->L->BK1 Per Block + constexpr auto K0PerBlock = KPerBlock / BK1; + constexpr auto max_lds_align = BK1; + + if constexpr(B0BlockLdsExtraL) + { 
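A note on the ABlockLdsExtraM / B0BlockLdsExtraL switches used by these block descriptor builders: when enabled, the tile is stored in LDS with a slightly larger row pitch instead of the tightly packed, aligned layout, so that successive rows of a column do not all start in the same LDS bank. A toy model of the effect, assuming 4-byte elements, 32 LDS banks and a pad of one element per row (illustrative values, not the kernel's actual tile sizes):

    #include <cstdio>

    int main()
    {
        constexpr int banks     = 32; // 32 LDS banks, one dword wide each
        constexpr int row_elems = 64; // elements per tile row (illustrative)

        for(int pad : {0, 1}) // pad = 1 models the "Extra" descriptor variants
        {
            std::printf("pad=%d, bank hit by column 0 of rows 0..7:", pad);
            for(int row = 0; row < 8; ++row)
                std::printf(" %2d", (row * (row_elems + pad)) % banks);
            std::printf("\n"); // pad=0: always bank 0; pad=1: the banks rotate
        }
        return 0;
    }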
+ return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, BK1), max_lds_align); + } + } + else + { + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / 2 / BK1; + // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + BK1), + make_tuple(Number{} * Number{} * BK1, + Number{} * BK1, + Number{} * BK1, + BK1, + BK1, + BK1, + I1)); + } + }(); + + return b0_block_desc; + } + + __host__ __device__ static constexpr auto MakeB1BlockDescriptor() + { + constexpr auto b1_block_desc = [&]() { + if constexpr(B1EnableLds) + { + // L0->N->BL1 Per Block + constexpr auto max_lds_align = BL1; + + if constexpr(B1BlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, BL1), + make_tuple(Number{} * BL1, BL1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, BL1), max_lds_align); + } + } + else + { + constexpr auto LWmmaPerblock = LPerBlock / WmmaL; + constexpr auto L0PerWmma = WmmaL / 2 / BL1; + // LWmma->NRepeat->MWave->L0PerWmma->LRow->MPerWmma->L1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + BL1), + make_tuple(Number{} * Number{} * BL1, + Number{} * BL1, + Number{} * BL1, + BL1, + BL1, + BL1, + I1)); + } + }(); + + return b1_block_desc; + } + + __host__ __device__ static constexpr auto MakeABlockSliceCopyStep() + { + constexpr auto a_block_copy_step = [&]() { + if constexpr(AEnableLds) + { + constexpr auto K0PerBlock = KPerBlock / AK1; + + return make_multi_index(K0PerBlock, 0, 0); + } + else + { + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return a_block_copy_step; + } + + __host__ __device__ static constexpr auto MakeB0BlockSliceCopyStep() + { + constexpr auto b0_block_copy_step = [&]() { + if constexpr(B0EnableLds) + { + constexpr auto K0PerBlock = KPerBlock / BK1; + + return make_multi_index(K0PerBlock, 0, 0); + } + else + { + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return b0_block_copy_step; + } + + __host__ __device__ static constexpr auto MakeB1BlockSliceCopyStep() + { + constexpr auto b1_block_copy_step = [&]() { + if constexpr(B1EnableLds) + { + return make_multi_index(L0PerBlock, 0, 0); + } + else + { + constexpr auto LWmmaPerBlock = LTilePerBlock / WmmaL; + + return make_multi_index(LWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return b1_block_copy_step; + } + + // Describe how data read from (LDS/VGPR) buffer + template + __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&) + { + + constexpr auto a_wave_desc = [&]() { + if constexpr(AEnableLds) + { + // AK0_M_AK1 -> AK0_MRepeat_Mwaves_MPerWmma_AK1 + constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); + constexpr auto A_KRow = I1; + return transform_tensor_descriptor( + ABlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + 
make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = ABlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3); + constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6); + + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return a_wave_desc; + } + + template + __host__ __device__ static constexpr auto MakeB0WaveDescriptor(const B0BlockDesc_&) + { + + constexpr auto b0_wave_desc = [&]() { + if constexpr(B0EnableLds) + { + // BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1 + constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2); + constexpr auto B_KRow = I1; + return transform_tensor_descriptor( + B0BlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = B0BlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = B0BlockDesc_{}.GetLength(I3); + constexpr auto B_KRow = B0BlockDesc_{}.GetLength(I4); + constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I6); + + // Workaround, Freeze transform + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return b0_wave_desc; + } + + template + __host__ __device__ static constexpr auto + MakeA1WaveDescriptor_L0_M0_M1_M2_L1(const A1BlockDesc_AL0_M_AL1&) + { + constexpr index_t A_L0 = A1BlockDesc_AL0_M_AL1{}.GetLength(I0); + constexpr index_t A_L1 = A1BlockDesc_AL0_M_AL1{}.GetLength(I2); + constexpr auto A_LRow = I1; + return transform_tensor_descriptor( + A1BlockDesc_AL0_M_AL1{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_LRow)), + make_unmerge_transform(make_tuple(Number{}, I1, I1)), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + + template + __host__ __device__ static constexpr auto MakeB1WaveDescriptor(const B1BlockDesc_&) + { + + constexpr auto b1_wave_desc = [&]() { + if constexpr(B1EnableLds) + { + // BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1 + constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0); + constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2); + constexpr auto B_LRow = I1; + return transform_tensor_descriptor( + B1BlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_LRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + constexpr auto LWmma = B1BlockDesc_{}.GetLength(I0); + constexpr auto L0PerWmma = B1BlockDesc_{}.GetLength(I3); + constexpr auto B_LRow = B1BlockDesc_{}.GetLength(I4); + constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I6); + + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return 
b1_wave_desc; + } + + __host__ __device__ static constexpr auto + // *Caution Here repeat is shuffle repeat + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma); + + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + const index_t gemm0_bytes_end = + (SharedMemTrait::a_block_space_size_aligned * sizeof(ADataType) + + SharedMemTrait::b0_block_space_size_aligned * sizeof(B0DataType)); + + const index_t gemm1_bytes_end = + (SharedMemTrait::b1_block_space_offset + + SharedMemTrait::b1_block_space_size_aligned * sizeof(B1DataType)); + + const index_t softmax_bytes_end = + SharedMemTrait::reduction_space_offset + + SharedMemTrait::reduction_space_size_aligned * sizeof(Acc0DataType); + + const index_t c_block_bytes_end = + SharedMemTrait::c_block_space_size * sizeof(CShuffleDataType); + + return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc, + const B0GridDesc& b0_grid_desc, + const B1GridDesc& b1_grid_desc, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && + (LPerBlock % (LPerWmma * LRepeat)) == 0, + "Invalid tuning param!"); + + const auto GetAProblemsizeMK = [&]() { + if constexpr(AEnableLds) + { + return make_tuple(a_grid_desc.GetLength(I1), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) * + a_grid_desc.GetLength(I5), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * + a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6)); + } + }; + + const auto GetB0ProblemsizeLK = [&]() { + if constexpr(B0EnableLds) + { + return make_tuple(b0_grid_desc.GetLength(I1), + b0_grid_desc.GetLength(I0) * b0_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(b0_grid_desc.GetLength(I1) * b0_grid_desc.GetLength(I2) * + b0_grid_desc.GetLength(I5), + b0_grid_desc.GetLength(I0) * b0_grid_desc.GetLength(I3) * + b0_grid_desc.GetLength(I4) * b0_grid_desc.GetLength(I6)); + } + }; + + const auto GetB1ProblemsizeNL = [&]() { + if constexpr(B1EnableLds) + { + return make_tuple(b1_grid_desc.GetLength(I1), + b1_grid_desc.GetLength(I0) * b1_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(b1_grid_desc.GetLength(I1) * b1_grid_desc.GetLength(I2) * + b1_grid_desc.GetLength(I5), + b1_grid_desc.GetLength(I0) * b1_grid_desc.GetLength(I3) * + b1_grid_desc.GetLength(I4) * b1_grid_desc.GetLength(I6)); + } + }; + + const auto M = GetAProblemsizeMK()[I0]; + const auto L = GetB0ProblemsizeLK()(I0); + const auto K = GetAProblemsizeMK()[I1]; + const auto N = GetB1ProblemsizeNL()(I0); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) + { + printf("GridwiseOp: M/N Length err, A_M/N = %d, %d | C_M/N = %d, %d\n", + M, + N, + c_grid_desc_m_n.GetLength(I0), + 
c_grid_desc_m_n.GetLength(I1)); + return false; + } + + if(!(M % MPerBlock == 0 && L % LPerBlock == 0 && K % KPerBlock == 0 && N % NPerBlock == 0)) + { + printf("GridwiseOp: M/L/K/N Division err, M/L/K/N = %d, %d, %d, %d | M/L/K/NPerBlock = " + "%d, %d, %d, %d\n", + M, + L, + K, + N, + MPerBlock, + LPerBlock, + KPerBlock, + NPerBlock); + return false; + } + + // check gemm0 gridwise gemm pipeline + const auto num_gemm0_k_loop = K / KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop)) + { + printf("GridwiseOp: outer loop unsupport\n"); + return false; + } + + // check gemm1 gridwise gemm pipeline + if(!(LPerBlock % LTilePerBlock == 0)) + { + printf("GridwiseOp: inner loop division, L/LTilePerblock: %d, %d\n", + LPerBlock, + LTilePerBlock); + return false; + } + + const auto num_gemm1_k_inner_loop = LPerBlock / LTilePerBlock; + if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop)) + { + printf("GridwiseOp: inner loop unsupport\n"); + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = math::integer_divide_ceil(K, KPerBlock); + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + using DefaultBlock2CTileMap = + remove_cvref_t; + + struct SharedMemTrait + { + // LDS allocation for A and B: be careful of alignment + static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), BL1); + + static constexpr auto a_block_space_size_aligned = + AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(), + max_lds_align) + : 0; + static constexpr auto b0_block_space_size_aligned = + B0EnableLds ? math::integer_least_multiple( + MakeB0BlockDescriptor().GetElementSpaceSize(), max_lds_align) + : 0; + static constexpr auto b1_block_space_size_aligned = + B1EnableLds ? 
math::integer_least_multiple( + MakeB1BlockDescriptor().GetElementSpaceSize(), max_lds_align) + : 0; + + static constexpr auto a_block_space_offset = 0; + static constexpr auto b0_block_space_offset = a_block_space_size_aligned; + static constexpr auto b1_block_space_offset = 0; + + // LDS allocation for reduction + // Feature to add, IntraThread Reduction + static constexpr index_t reduction_space_size_aligned = + math::integer_least_multiple(BlockSize, max_lds_align); + + static constexpr auto reduction_space_offset = 0; + + // LDS allocation for C shuffle in LDS + static constexpr auto c_block_space_size = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + .GetElementSpaceSize(); + }; + + template + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const B0DataType* __restrict__ p_b0_grid, + const B1DataType* __restrict__ p_b1_grid, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc& a_grid_desc, + const B0GridDesc& b0_grid_desc, + const B1GridDesc& b1_grid_desc, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation& a_element_op, + const B0ElementwiseOperation& b0_element_op, + const AccElementwiseOperation& acc_element_op, + const B1ElementwiseOperation& b1_element_op, + const CElementwiseOperation& c_element_op, + const C0MatrixMask& c0_matrix_mask, + const Block2CTileMap& block_2_ctile_map) + { + // clang-format off +/*******************************************************************************/ +// Memory buffer zone. + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc.GetElementSpaceSize()); + const auto b0_grid_buf = make_dynamic_buffer( + p_b0_grid, b0_grid_desc.GetElementSpaceSize()); + const auto b1_grid_buf = make_dynamic_buffer( + p_b1_grid, b1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + +/*******************************************************************************/ +// BlockIdx.x -> [BlockId.m, BlockId.n] + const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { return; } + + // Store BlockId into SGPR + const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + +/*******************************************************************************/ +// set up Gemm0 +/*******************************************************************************/ + +/*******************************************************************************/ +// BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy + constexpr auto a_block_desc = MakeABlockDescriptor(); + constexpr auto b0_block_desc = MakeB0BlockDescriptor(); + + auto a_block_trait = [&](){ + // A matrix blockwise copy + if constexpr(AEnableLds) + { + constexpr auto AK0PerBlock = KPerBlock/ AK1; + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::a_block_space_offset, + SharedMemTrait::a_block_space_size_aligned); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1< 
ThisThreadBlock, +/* typename SrcElementwiseOperation, */ AElementwiseOperation, +/* typename DstElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough, +/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set, +/* typename BlockSliceLengths, */ Sequence, +/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1, +/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder, +/* typename SrcData, */ ADataType, +/* typename DstData, */ ADataType, +/* typename SrcDesc, */ decltype(a_grid_desc), +/* typename DstDesc, */ decltype(a_block_desc), +/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder, +/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>, +/* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim, +/* index_t DstVectorDim, */ 2, +/* index_t SrcScalarPerVector, */ ABlockTransferSrcScalarPerVector, +/* index_t DstScalarPerVector, */ ABlockTransferDstScalarPerVector_K1, +/* index_t SrcScalarStrideInVector, */ 1, +/* index_t DstScalarStrideInVector, */ 1, +/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun, +/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, + NumGemmKPrefetchStage>( + a_grid_desc, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + return make_tuple(a_block_buf, a_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/AK1Value; + auto a_block_buf = make_static_buffer( + a_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto a_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + ABlockTransferSrcScalarPerVector, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc, + make_multi_index(0, + m_block_data_idx_on_grid/(MWaves * MPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(a_block_buf, a_blockwise_copy); + } + }; + + auto b0_block_trait = [&](){ + if constexpr(B0EnableLds) + { + auto b0_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b0_block_space_offset, + SharedMemTrait::b0_block_space_size_aligned); + + auto b0_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + B0BlockTransferThreadClusterLengths_K0_L_K1, + B0BlockTransferThreadClusterArrangeOrder, + B0DataType, + B0DataType, + decltype(b0_grid_desc), + decltype(b0_block_desc), + B0BlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + B0BlockTransferSrcVectorDim, + 2, + B0BlockTransferSrcScalarPerVector, + B0BlockTransferDstScalarPerVector_K1, + 1, + 1, + B0ThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b0_grid_desc, + make_multi_index(0, 0, 0), + b0_element_op, + b0_block_desc, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + return make_tuple(b0_block_buf, b0_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> LRepeat -> LWaves -> KRow -> LPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/BK1Value; + auto b0_block_buf = make_static_buffer( + 
b0_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto b0_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + B0BlockTransferSrcScalarPerVector, + B0ThreadTransferSrcResetCoordinateAfterRun, + true>( + b0_grid_desc, + make_multi_index(0, + 0/(LWaves * LPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(b0_block_buf, b0_blockwise_copy); + } + }; + + auto a_block_buf = a_block_trait()[I0]; + auto a_blockwise_copy = a_block_trait()[I1]; + + auto b0_block_buf = b0_block_trait()[I0]; + auto b0_blockwise_copy = b0_block_trait()[I1]; + +/*******************************************************************************/ + // Gemm0 + constexpr auto KPack = math::integer_least_multiple(math::integer_least_multiple(AK1Value,BK1Value), WmmaK); + + auto blockwise_gemm0 = BlockwiseGemmWMMA< + BlockSize, + ADataType, + B0DataType, + Acc0DataType, + decltype(MakeAWaveDescriptor(a_block_desc)), + decltype(MakeB0WaveDescriptor(b0_block_desc)), + MPerBlock, + LPerBlock, + KPerBlock, + MPerWmma, + LPerWmma, + MRepeat, + LRepeat, + KPack, + AEnableLds, + B0EnableLds, + true>{}; // C' = B' x A' + + + // Prepare Register for A*B0 matrix + auto acc0_thread_buf = blockwise_gemm0.GetCThreadBuffer(); + + constexpr auto acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = + blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(); + + constexpr auto mrepeat = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0); + constexpr auto mwave = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1); + constexpr auto mthreadpersubgroup = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2); + constexpr auto lrepeat = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3); + constexpr auto lwave = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4); + constexpr auto lsubgroup = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5); + constexpr auto laccvgprs = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6); + + constexpr auto acc0_thread_desc_l0perblock_mperblock_l1 = transform_tensor_descriptor( + acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(lrepeat, lwave, lsubgroup)), + make_merge_transform_v3_division_mod(make_tuple(mrepeat, mwave, mthreadpersubgroup)), + make_pass_through_transform(laccvgprs)), + make_tuple(Sequence<3, 4, 5>{}, Sequence<0, 1, 2>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + +/*******************************************************************************/ + // Shift Per SUB_K + constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep(); + constexpr auto b0_block_slice_copy_step = MakeB0BlockSliceCopyStep(); + + const auto a_block_reset_copy_step = [&](){ + if constexpr(AEnableLds){ + return make_multi_index(-a_grid_desc.GetLength(I0), 0, 0); + } + else{ + return make_multi_index(-a_grid_desc.GetLength(I0), 0, 0, 0, 0, 0, 
0); + } + }(); + + const auto b0_block_reset_copy_step = [&](){ + if constexpr(B0EnableLds){ + return make_multi_index(-b0_grid_desc.GetLength(I0), LPerBlock, 0); + } + else{ + return make_multi_index(-b0_grid_desc.GetLength(I0), LRepeat, 0, 0, 0, 0, 0); + } + }(); + + const auto K = [&](){ + if constexpr(AEnableLds){ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2); + } + else{ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * + a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6); + } + }(); + + const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock); +/*******************************************************************************/ +// softmax +/*******************************************************************************/ + auto workspace_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::reduction_space_offset, + SharedMemTrait::reduction_space_size_aligned); + // get acc0 7D thread cluster + constexpr auto thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = + blockwise_gemm0.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths() / + blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths(); + constexpr auto t_mrepeat = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I0); + constexpr auto t_mwave = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I1); + constexpr auto t_mthreadpersubgroup = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I2); + constexpr auto t_lrepeat = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I3); + constexpr auto t_lwave = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I4); + constexpr auto t_lsubgroup = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I5); + constexpr auto t_laccvgprs = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I6); + // get acc0 thread map + constexpr auto m0_l_m1_to_m_l_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(t_mrepeat * t_mwave, t_mthreadpersubgroup)), + make_pass_through_transform(I1)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + constexpr auto threadid_to_m0_l_m1_adaptor = make_single_stage_tensor_adaptor( + make_tuple( + make_merge_transform( + make_tuple(t_mrepeat * t_mwave, t_lrepeat * t_lwave * t_lsubgroup * t_laccvgprs, t_mthreadpersubgroup))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + const auto threadid_to_l_n_thread_cluster_adaptor = + chain_tensor_adaptors(m0_l_m1_to_m_l_adaptor, threadid_to_m0_l_m1_adaptor); + + // get acc0 2D thread cluster & 2D thread slice + constexpr auto thread_cluster_desc_m_l = make_naive_tensor_descriptor_packed( + make_tuple(t_mrepeat * t_mwave * t_mthreadpersubgroup, t_lrepeat * t_lwave * t_lsubgroup * t_laccvgprs)); + + constexpr auto thread_slice_desc_m_l = make_naive_tensor_descriptor_packed( + make_tuple(mrepeat * mwave * mthreadpersubgroup, lrepeat * lwave * lsubgroup * laccvgprs)); + + auto blockwise_softmax = BlockwiseSoftmax{}; + + // Initialize running sum and max of exponentiating row vectors + using SoftmaxBuf = typename decltype(blockwise_softmax)::BufferType; + SoftmaxBuf running_sum, 
running_sum_new, running_max, running_max_new; + running_sum = 0; + running_sum_new = 0; + running_max = NumericLimits::Lowest(); + running_max_new = NumericLimits::Lowest(); +/*******************************************************************************/ +// set up Gemm1 +/*******************************************************************************/ + // Acc0 thread buffer -> A1 thread buffer -> blockwise gemm + // A1 matrix in VGPR + constexpr auto A1ThreadSlice_L0PerBlock_MPerBlock_L1 = make_tuple( + Number{}, + Number{}, + Number{}); + + constexpr auto A1ThreadSliceL0PerBlock = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I0]; + constexpr auto A1ThreadSliceMPerBlock = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I1]; + constexpr auto A1ThreadSliceL1 = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I2]; + + constexpr auto a1_thread_desc_l0perblock_mperblock_l1 = make_naive_tensor_descriptor( + make_tuple(A1ThreadSliceL0PerBlock, A1ThreadSliceMPerBlock, A1ThreadSliceL1), + make_tuple(A1ThreadSliceMPerBlock * A1ThreadSliceL1, A1ThreadSliceL1, I1)); + + // A1 matrix blockwise copy + auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic< + Acc0DataType, + ADataType, + decltype(acc0_thread_desc_l0perblock_mperblock_l1), + decltype(a1_thread_desc_l0perblock_mperblock_l1), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2>, + 2, + laccvgprs>{tensor_operation::element_wise::PassThrough{}}; + + auto a1_thread_buf = make_static_buffer( + a1_thread_desc_l0perblock_mperblock_l1.GetElementSpaceSize()); + + constexpr auto b1_block_desc = MakeB1BlockDescriptor(); + + auto b1_block_trait = [&](){ + if constexpr(B1EnableLds) + { + auto b1_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b1_block_space_offset, + SharedMemTrait::b1_block_space_size_aligned); + + auto b1_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock, +/* typename SrcElementwiseOperation, */ B1ElementwiseOperation, +/* typename DstElementwiseOperation, */ tensor_operation::element_wise::PassThrough, +/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set, +/* typename BlockSliceLengths, */ Sequence, +/* typename ThreadClusterLengths, */ B1BlockTransferThreadClusterLengths_L0_N_L1, +/* typename ThreadClusterArrangeOrder, */ B1BlockTransferThreadClusterArrangeOrder, +/* typename SrcData, */ B1DataType, +/* typename DstData, */ B1DataType, +/* typename SrcDesc, */ decltype(b1_grid_desc), +/* typename DstDesc, */ decltype(b1_block_desc), +/* typename SrcDimAccessOrder, */ B1BlockTransferSrcAccessOrder, +/* typename DstDimAccessOrder, */ Sequence<1, 0, 2>, +/* index_t SrcVectorDim, */ B1BlockTransferSrcVectorDim, +/* index_t DstVectorDim, */ 2, +/* index_t SrcScalarPerVector, */ B1BlockTransferSrcScalarPerVector, +/* index_t DstScalarPerVector, */ B1BlockTransferDstScalarPerVector_L1, +/* index_t SrcScalarStrideInVector, */ 1, +/* index_t DstScalarStrideInVector, */ 1, +/* bool ThreadTransferSrcResetCoordinateAfterRun, */ B1ThreadTransferSrcResetCoordinateAfterRun, +/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, // DstResetCoord + NumGemmKPrefetchStage>( + b1_grid_desc, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b1_element_op, + b1_block_desc, + make_multi_index(0, 0, 0), + tensor_operation::element_wise::PassThrough{}); + + return make_tuple(b1_block_buf, b1_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1 + constexpr auto LWmmaPerBlock = 
LTilePerBlock / WmmaL; + constexpr auto L0PerWmma = WmmaL/2/L1Value; + auto b1_block_buf = make_static_buffer( + b1_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto b1_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + B1BlockTransferSrcScalarPerVector, + B1ThreadTransferSrcResetCoordinateAfterRun, + true>( + b1_grid_desc, + make_multi_index(0, + n_block_data_idx_on_grid/(NWaves * NPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(b1_block_buf, b1_blockwise_copy); + } + }; + + auto b1_block_buf = b1_block_trait()[I0]; + auto b1_blockwise_copy = b1_block_trait()[I1]; + + constexpr auto b1_block_slice_copy_step = MakeB1BlockSliceCopyStep(); + + auto blockwise_gemm1 = + BlockwiseGemmWMMA{make_tuple(0, 0, 0, 0, 0, 0)}; + + auto acc1_thread_buf = blockwise_gemm1.GetCThreadBuffer(); + + const auto L = [&](){ + if constexpr(B0EnableLds){ + return b0_grid_desc.GetLength(I1); + } + else{ + return b0_grid_desc.GetLength(I1) * b0_grid_desc.GetLength(I2) * b0_grid_desc.GetLength(I5); + } + }(); + + const index_t num_gemm1_l_block_outer_loop = L / LPerBlock; + constexpr index_t num_gemm1_l_block_inner_loop = LPerBlock / LTilePerBlock; + + // Initialize C + StaticBuffer c_thread_buf; + c_thread_buf.Clear(); + +/*******************************************************************************/ + // + // Kernel Main Stage + // + // Flash Attention + // Dao, Tri, et al. "Flashattention: Fast and memory-efficient exact attention with io-awareness." arXiv preprint arXiv:2205.14135 (2022). + index_t gemm1_l_block_outer_index = 0; + // Outer loop, along GEMM_L + // Inner loop, along GEMM_K + do{ + auto l_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(gemm1_l_block_outer_index * LPerBlock); + if(c0_matrix_mask.IsTileSkippable( + m_block_data_idx_on_grid, l_block_data_idx_on_grid, MPerBlock, LPerBlock)) + { + continue; + } + // gemm0 start, A-B swaped + GridwiseGemmPipe::template Run(a_grid_desc, + a_block_desc, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b0_grid_desc, + b0_block_desc, + b0_blockwise_copy, + b0_grid_buf, + b0_block_buf, + b0_block_slice_copy_step, + blockwise_gemm0, + acc0_thread_buf, + KBlockMainLoop); + // do MNK padding or upper triangular masking + if constexpr(MaskOutUpperTriangle || PadN) + { + // 7d thread_desc in thread scope + constexpr auto c_thread_lengths = + blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths(); + + // 7d block_desc in block scope + constexpr auto c_block_lengths = + blockwise_gemm0.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths(); + + constexpr auto MREPEAT = c_block_lengths[I0]; + constexpr auto MWAVE = c_block_lengths[I1]; + constexpr auto MTHREADSubGroup = c_block_lengths[I2]; + constexpr auto LREPEAT = c_block_lengths[I3]; + constexpr auto LWAVE = c_block_lengths[I4]; + constexpr auto LSUBGROUP = c_block_lengths[I5]; + constexpr auto LACCVGPRS = c_block_lengths[I6]; + + // works like multi-dimension static_for (static_ford), but provides both the linear + // index as well as n-d index + using Acc0TileIterator = SpaceFillingCurve< + decltype(c_thread_lengths), + typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type, + typename 
uniform_sequence_gen::type, + false>; // SnakeCurved + + auto acc0_thread_origin = blockwise_gemm0.CalculateCThreadOriginDataIndex7D( + Number<0>{}, Number<0>{}); + + constexpr auto block_idx_to_m_l_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MREPEAT, MWAVE, MTHREADSubGroup)), + make_unmerge_transform(make_tuple(LREPEAT, LWAVE, LSUBGROUP, LACCVGPRS))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{})); + + static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) { + auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin; + auto m_local = block_idx_to_m_l_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0]; + auto l_local = block_idx_to_m_l_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1]; + auto m_global = m_local + m_block_data_idx_on_grid; + auto l_global = l_local + l_block_data_idx_on_grid; + if(c0_matrix_mask.IsMaskedElement(m_global, l_global)) + { + acc0_thread_buf(i) = -ck::NumericLimits::Infinity(); + } + else + { + acc_element_op(acc0_thread_buf(i), acc0_thread_buf[i]); + } + }); + } + else + { static_for<0, acc0_thread_buf.Size(), 1>{}( + [&](auto i) { acc_element_op(acc0_thread_buf(i), acc0_thread_buf[i]); }); + } + + block_sync_lds(); + // Tiled softmax start + // softmax + SoftmaxBuf& max = blockwise_softmax.max_value_buf; + SoftmaxBuf& sum = blockwise_softmax.sum_value_buf; + + blockwise_softmax.Run(acc0_thread_buf, workspace_buf); + + // TODO: may convert to log domain + running_max_new = mathext::max(max, running_max); + running_sum_new = mathext::exp(running_max - running_max_new) * running_sum + + mathext::exp(max - running_max_new) * sum; + + // gemm1 + { + // TODO: explore using dynamic buffer for a1 thread buffer + // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(), + // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that + // the A1 source buffer is static buffer holding the output of first GEMM and + // requires constexpr offset by design. Therefore, we pass tensor coordinate offset + // explicitly in Run() below. 
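The do/while loop enclosing this second GEMM implements the FlashAttention-style streaming softmax cited above: each L tile contributes a tile-local max and exp-sum, and the output accumulated from earlier tiles is rescaled into the new normalisation before the fresh P*V partial product is added, which is exactly the shape of the c_new expression a few lines below. A minimal scalar model of that update for one query row, with every value entry set to 1 so the exact answer is 1.0 (tile values and names are illustrative only):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main()
    {
        // Two tiles of attention logits for a single query row; V is all ones.
        std::vector<std::vector<double>> tiles = {{0.5, 2.0, 1.0}, {3.0, -1.0, 0.25}};

        double running_max = -1e30, running_sum = 0.0, out = 0.0;

        for(const auto& tile : tiles)
        {
            double tile_max = -1e30, tile_sum = 0.0, tile_acc = 0.0;
            for(double x : tile) tile_max = std::max(tile_max, x);
            for(double x : tile)
            {
                double p = std::exp(x - tile_max);
                tile_sum += p;
                tile_acc += p * 1.0; // p dotted with the (all-ones) V column
            }

            double new_max = std::max(running_max, tile_max);
            double new_sum = std::exp(running_max - new_max) * running_sum +
                             std::exp(tile_max - new_max) * tile_sum;

            // Same structure as the kernel's c_new: rescale the old output and the
            // new partial product into the new normalisation, then renormalise.
            out = (running_sum * std::exp(running_max - new_max) * out +
                   std::exp(tile_max - new_max) * tile_acc) / new_sum;

            running_max = new_max;
            running_sum = new_sum;
        }
        std::printf("accumulated output = %f (expect 1.0)\n", out);
        return 0;
    }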
+ + // Initialize acc1 + acc1_thread_buf.Clear(); + + // preload data into LDS + b1_blockwise_copy.RunRead(b1_grid_desc, b1_grid_buf); + + b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc, + b1_block_slice_copy_step); + + block_sync_lds(); // wait for reduction LDS read + + b1_blockwise_copy.RunWrite(b1_block_desc, b1_block_buf); + + // main body + if constexpr(num_gemm1_l_block_inner_loop > 1) + { + static_for<0, num_gemm1_l_block_inner_loop - 1, 1>{}([&](auto i) { + // Data cast from Acc0DataType to ADataType happen here + a1_blockwise_copy.Run(acc0_thread_desc_l0perblock_mperblock_l1, + make_tuple(Number{}, I0, I0), + acc0_thread_buf, + a1_thread_desc_l0perblock_mperblock_l1, + make_tuple(I0, I0, I0), + a1_thread_buf); + + b1_blockwise_copy.RunRead(b1_grid_desc, b1_grid_buf); + + block_sync_lds(); + + blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf); + + block_sync_lds(); + + b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc, + b1_block_slice_copy_step); + + b1_blockwise_copy.RunWrite(b1_block_desc, b1_block_buf); + }); + } + // tail + { + a1_blockwise_copy.Run( + acc0_thread_desc_l0perblock_mperblock_l1, + make_tuple( + Number<(num_gemm1_l_block_inner_loop - 1) * A1ThreadSliceL0PerBlock>{}, I0, I0), + acc0_thread_buf, + a1_thread_desc_l0perblock_mperblock_l1, + make_tuple(I0, I0, I0), + a1_thread_buf); + + block_sync_lds(); + + blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf); + } + } // end gemm1 + + constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = + blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(); + constexpr auto c_mrepeat = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0); + constexpr auto c_mwave = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1); + constexpr auto c_mthreadpersubgroup = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2); + constexpr auto c_nrepeat = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3); + constexpr auto c_nwave = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4); + constexpr auto c_nsubgroup = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5); + constexpr auto c_naccvgprs = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6); + + constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed( + make_tuple(c_mrepeat * c_mwave * c_mthreadpersubgroup, + c_nrepeat * c_nwave * c_nsubgroup * c_naccvgprs)); + constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0); + constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1); + + static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) { + static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) { + auto I = Number{}; + Acc1DataType acc1 = acc1_thread_buf[I]; // P*V + Acc1DataType c = c_thread_buf[I]; // O + Acc1DataType c_new = + (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c + + math::exp(max[iM] - running_max_new[iM]) * acc1) / + running_sum_new[iM]; + + c_thread_buf(I) = c_new; // O_new + }); + }); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, + a_block_reset_copy_step); // rewind K + b0_blockwise_copy.MoveSrcSliceWindow(b0_grid_desc, + b0_block_reset_copy_step); // rewind K and 
step N + + // update before next j iteration + running_max = running_max_new; + running_sum = running_sum_new; + + block_sync_lds(); // wait for gemm1 LDS read + }while(++gemm1_l_block_outer_index < num_gemm1_l_block_outer_loop); +/*******************************************************************************/ + // write out to C, implement shuffle + { + constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = + blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(); + + // This API Provide All dimension (size) you need + constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp = + blockwise_gemm1.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(); + + constexpr auto MWave = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I1); + constexpr auto MThreadPerSubGroup = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I2); + constexpr auto NWave = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I4); + constexpr auto NSubGroup = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I5); + constexpr auto NAccVgprs = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I6); + + // LDS descriptor, shuffle and write out in MRepeat x NRepeat times + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize()); + + constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = transform_tensor_descriptor( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // MRepeat per shuffle repeat + MWave, // MWave + MThreadPerSubGroup // MThreadPerSubGroup = MPerWmma + )), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // NRepeat per shuffle repeat + NWave, // NWave + NSubGroup, + NAccVgprs))), // NSubGroup * NAccVgprs = NPerWmma + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0, 1, 2>{}, Sequence<>{}, Sequence<3, 4, 5, 6>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = blockwise_gemm1.CalculateCThreadOriginDataIndex(I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_mrepeat_mwave_mthreadpersubgroup_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MThreadPerSubGroup))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_to_nrepeat_nwave_nsubgroup_naccvgprs_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NSubGroup, NAccVgprs))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto 
m_thread_data_on_block_idx = m_thread_data_on_block_to_mrepeat_mwave_mthreadpersubgroup_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_idx = n_thread_data_on_block_to_nrepeat_nwave_nsubgroup_naccvgprs_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + 8, // vector write pixel + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs, + make_multi_index(0, + m_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + 0, + n_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for local reg & global memory + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if 
constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + // clang-format on + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp new file mode 100644 index 0000000000..67e211ef8d --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp @@ -0,0 +1,1046 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_fpAintB_gemm_wmma(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + const ScaleDataType* __restrict__ p_scale_grid, + CDataType* __restrict__ p_c_grid, + const AGridDesc a_grid_desc, + const BGridDesc b_grid_desc, + const ScaleGridDesc scale_grid_desc, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ + defined(__gfx1102__)) + __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_scale_grid, + p_c_grid, + p_shared, + a_grid_desc, + b_grid_desc, + scale_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_scale_grid; + ignore = p_c_grid; + ignore = a_grid_desc; + ignore = b_grid_desc; + ignore = scale_grid_desc; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx1100__)) +} + +// Assume B is Col-Major +template +struct GridwiseFpAintBGemm_Wmma +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = 
Number<7>{}; + + // FIX ME: To be deprecated + static constexpr auto K1 = Number{}; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 32 : 16; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = + remove_cvref_t())>; + + // Describe how data store to (LDS/VGPR) buffer from Global memory + __host__ __device__ static constexpr auto MakeABlockDescriptor() + { + constexpr auto a_block_desc = [&]() { + if constexpr(AEnableLds) + { + // K0->M->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / K1; + constexpr auto max_lds_align = K1; + + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + } + else + { + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / 2 / K1; + // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + Number{} * K1, + K1, + K1, + K1, + I1)); + } + }(); + + return a_block_desc; + } + + __host__ __device__ static constexpr auto MakeBBlockDescriptor() + { + constexpr auto b_block_desc = [&]() { + if constexpr(BEnableLds) + { + // K0->N->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / K1; + constexpr auto max_lds_align = K1; + + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + } + else + { + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / 2 / K1; + // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + Number{} * K1, + K1, + K1, + K1, + I1)); + } + }(); + + return b_block_desc; + } + + __host__ __device__ static constexpr auto MakeABlockSliceCopyStep() + { + constexpr auto a_block_copy_step = [&]() { + if constexpr(AEnableLds) + { + constexpr auto K0PerBlock = KPerBlock / K1; + + return make_multi_index(K0PerBlock, 0, 0); + } + else + { + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return a_block_copy_step; + } + + __host__ __device__ static constexpr auto MakeBBlockSliceCopyStep() + { + constexpr auto b_block_copy_step = [&]() { + if constexpr(BEnableLds) + { + constexpr auto K0PerBlock = KPerBlock / K1; + + return make_multi_index(K0PerBlock, 0, 0); + } + else + { + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return b_block_copy_step; + } + + // Describe how data read from (LDS/VGPR) buffer + template + __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&) + { + + constexpr auto a_wave_desc = [&]() { + if constexpr(AEnableLds) + { + // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 + constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); + constexpr auto A_K1 = 
ABlockDesc_{}.GetLength(I2); + constexpr auto A_KRow = I1; + return transform_tensor_descriptor( + ABlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = ABlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3); + constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6); + + // Err: merge transform cause non-constexpr issue + + // return transform_tensor_descriptor( + // ABlockDesc_{}, + // make_tuple(make_merge_transform(make_tuple(Number{}, I1)), + // make_pass_through_transform(Number{}), + // make_pass_through_transform(I1), + // make_pass_through_transform(I1), + // make_pass_through_transform(Number{})), + // make_tuple(Sequence<0, 3>{}, + // Sequence<1>{}, + // Sequence<2>{}, + // Sequence<4>{}, + // Sequence<5>{}), + // make_tuple( + // Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, + // Sequence<4>{})); + + // Workaround, Freeze transform + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return a_wave_desc; + } + + template + __host__ __device__ static constexpr auto MakeBWaveDescriptor(const BBlockDesc_&) + { + constexpr auto b_wave_desc = [&]() { + if constexpr(BEnableLds) + { + // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 + constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); + constexpr auto B_KRow = I1; + return transform_tensor_descriptor( + BBlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = BBlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3); + constexpr auto B_KRow = BBlockDesc_{}.GetLength(I4); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I6); + + // Workaround, Freeze transform + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return b_wave_desc; + } + + __host__ __device__ static constexpr auto + // *Caution Here repeat is shuffle repeat + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + { + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc, + const BGridDesc& b_grid_desc, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! 
K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerWmma)) == 0, + "Invalid tuning param!"); + + const auto GetAProblemsizeMK = [&]() { + if constexpr(AEnableLds) + { + return make_tuple(a_grid_desc.GetLength(I1), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) * + a_grid_desc.GetLength(I5), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * + a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6)); + } + }; + + const auto GetBProblemsizeNK = [&]() { + if constexpr(BEnableLds) + { + return make_tuple(b_grid_desc.GetLength(I1), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) * + b_grid_desc.GetLength(I5), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) * + b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6)); + } + }; + + const auto M = GetAProblemsizeMK()[I0]; + const auto N = GetBProblemsizeNK()[I0]; + const auto K = GetAProblemsizeMK()[I1]; + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K == GetBProblemsizeNK()[I1])) + { + printf("A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d\n", + GetAProblemsizeMK()[I0], + GetAProblemsizeMK()[I1], + GetBProblemsizeNK()[I0], + GetBProblemsizeNK()[I1], + c_grid_desc_m_n.GetLength(I0), + c_grid_desc_m_n.GetLength(I1)); + printf("GridwiseOp err: ProblemSize check"); + return false; + } + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + printf("GridwiseOp err: ProblemSize division"); + return false; + } + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + printf("GridwiseOp err: Pipeline not support this k_loop"); + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB)) + { + return false; + } + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using 
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + using DefaultBlock2CTileMap = + remove_cvref_t; + + struct SharedMemTrait + { + // LDS allocation for A and Dequantized B: be careful of DataType + // scale would not put into LDS. + using LDS_ADataType = ADataType; + using LDS_BDataType = ADataType; + using LDS_CDataType = CShuffleDataType; + static constexpr auto max_lds_align = K1; + + static constexpr auto a_block_space_size_aligned = + AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(), + max_lds_align) + : 0; + static constexpr auto b_block_space_size_aligned = + BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(), + max_lds_align) + : 0; + + static constexpr auto a_block_space_offset = 0; + // B would be dequantize to ADataType before enter LDS + // b_lds_offset = LDS size allocated for a in byte / LDS_BDataType + static constexpr auto b_block_space_offset = + (a_block_space_offset + a_block_space_size_aligned) * sizeof(LDS_ADataType) / + sizeof(LDS_BDataType); + + // LDS allocation for C shuffle in LDS + static constexpr auto c_shuffle_block_space_size = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + .GetElementSpaceSize(); + + static constexpr auto c_shuffle_block_space_offset = 0; + + static constexpr auto lds_size = + math::max(c_shuffle_block_space_size * sizeof(LDS_CDataType), + a_block_space_size_aligned * sizeof(LDS_ADataType) + + b_block_space_size_aligned * sizeof(LDS_BDataType)); + }; + + template + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + const ScaleDataType* __restrict__ p_scale_grid, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc& a_grid_desc, + const BGridDesc& b_grid_desc, + const ScaleGridDesc& scale_grid_desc, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + // clang-format off +/*******************************************************************************/ +// Memory buffer zone. 
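Editorial sketch (not part of the patch): the SharedMemTrait just above fixes the LDS budget for this kernel. The B tile enters LDS already dequantized to ADataType, so it is budgeted at ADataType width rather than at its storage width, and the epilogue's C-shuffle tile reuses the same scratchpad as the A/B input tiles, so the allocation is the maximum of the two phases rather than their sum. The standalone snippet below replays that arithmetic for one hypothetical configuration (fp16 A and shuffle data, a 128x128 tile with KPerBlock = 64); the concrete sizes are assumptions chosen for illustration, not values taken from a specific instance in this patch.

    #include <algorithm>
    #include <cstdio>

    // Minimal sketch of the SharedMemTrait sizing rule, with assumed tile sizes.
    // LDS is shared between two phases that do not overlap in time:
    //   phase 1: A tile + dequantized B tile (both stored at ADataType width)
    //   phase 2: C-shuffle tile used by the write-out epilogue
    constexpr long LdsBytes(long a_elems, long b_elems, long c_shuffle_elems,
                            long sizeof_a, long sizeof_c_shuffle)
    {
        const long input_bytes   = (a_elems + b_elems) * sizeof_a;
        const long shuffle_bytes = c_shuffle_elems * sizeof_c_shuffle;
        return std::max(input_bytes, shuffle_bytes);
    }

    int main()
    {
        // Hypothetical 128x128 tile, KPerBlock = 64, fp16 A and fp16 shuffle data.
        constexpr long a_elems         = 64 * 128; // KPerBlock x MPerBlock
        constexpr long b_elems         = 64 * 128; // KPerBlock x NPerBlock (after dequant)
        constexpr long c_shuffle_elems = 32 * 64;  // one shuffle repeat of the C tile
        std::printf("lds bytes = %ld\n", LdsBytes(a_elems, b_elems, c_shuffle_elems, 2, 2));
    }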
+ const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc.GetElementSpaceSize()); + const auto scale_grid_buf = make_dynamic_buffer( + p_scale_grid, scale_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + +/*******************************************************************************/ +// BlockIdx.x -> [BlockId.m, BlockId.n] + const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { return; } + + // Store BlockId into SGPR + const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + +/*******************************************************************************/ +// BlockLevel, A/B Matrix ThreadMapping in WMMA Source buffer, As Destinaion of BlockWise_Copy + const auto K = [&](){ + if constexpr(AEnableLds){ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2); + } + else{ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) + * a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6); + } + }(); + + constexpr auto a_block_desc = MakeABlockDescriptor(); + constexpr auto b_block_desc = MakeBBlockDescriptor(); + + auto a_block_trait = [&](){ + // A matrix blockwise copy + if constexpr(AEnableLds) + { + constexpr auto K0PerBlock = KPerBlock/ K1; + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), + SharedMemTrait::a_block_space_size_aligned); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, +/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1, +/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder, +/* typename SrcData, */ ADataType, +/* typename DstData, */ ADataType, +/* typename SrcDesc, */ decltype(a_grid_desc), +/* typename DstDesc, */ decltype(a_block_desc), +/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder, +/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>, +/* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim, +/* index_t DstVectorDim, */ 2, +/* index_t SrcScalarPerVector, */ ABlockTransferSrcScalarPerVector, +/* index_t DstScalarPerVector, */ ABlockTransferDstScalarPerVector_K1, +/* index_t SrcScalarStrideInVector, */ 1, +/* index_t DstScalarStrideInVector, */ 1, +/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun, +/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, + NumGemmKPrefetchStage>( + a_grid_desc, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + return make_tuple(a_block_buf, a_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> MRepeat -> MWaves -> K0PerWmma -> KRow -> MPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/K1Value; + auto a_block_buf = make_static_buffer( + a_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto 
a_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + ABlockTransferSrcScalarPerVector, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc, + make_multi_index(0, + m_block_data_idx_on_grid/(MWaves * MPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(a_block_buf, a_blockwise_copy); + } + }; + + auto b_block_trait = [&](){ + if constexpr(BEnableLds) + { + constexpr auto K0PerBlock = KPerBlock/ K1; + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b_block_space_offset, + SharedMemTrait::b_block_space_size_aligned); + + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1_dequant, +/* typename BlockScaleSliceLengths, */ Sequence, +/* typename ThreadClusterLengths, */ BBlockTransferThreadClusterLengths_K0_N_K1, +/* typename ThreadClusterArrangeOrder, */ BBlockTransferThreadClusterArrangeOrder, +/* typename SrcData, */ BDataType, +/* typename ScaleData, */ ScaleDataType, +/* typename DstData, */ ADataType, +/* typename SrcDesc, */ decltype(b_grid_desc), +/* typename ScaleDesc, */ decltype(scale_grid_desc), +/* typename DstDesc, */ decltype(b_block_desc), +/* typename SrcDimAccessOrder, */ BBlockTransferSrcAccessOrder, +/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>, +/* index_t SrcVectorDim, */ BBlockTransferSrcVectorDim, +/* index_t DstVectorDim, */ 2, +/* index_t SrcScalarPerVector, */ BBlockTransferSrcScalarPerVector, +/* index_t ScaleScalarPerVector, */ 1, +/* index_t DstScalarPerVector, */ BBlockTransferDstScalarPerVector_K1, +/* index_t SrcScalarStrideInVector, */ 1, +/* index_t ScaleScalarStrideInVector, */ 1, +/* index_t DstScalarStrideInVector, */ 1, +/* bool ThreadTransferSrcResetCoordinateAfterRun, */ BThreadTransferSrcResetCoordinateAfterRun, +/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, + NumGemmKPrefetchStage>( + b_grid_desc, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + scale_grid_desc, + make_multi_index(0, n_block_data_idx_on_grid, 0), + ck::tensor_operation::element_wise::PassThrough{}, + b_block_desc, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + return make_tuple(b_block_buf, b_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/K1Value; + auto b_block_buf = make_static_buffer( + b_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto b_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc, + make_multi_index(0, + n_block_data_idx_on_grid/(NWaves * NPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(b_block_buf, b_blockwise_copy); + } + }; + + auto a_block_buf = a_block_trait()[I0]; + auto a_blockwise_copy = a_block_trait()[I1]; + + auto b_block_buf = b_block_trait()[I0]; + auto b_blockwise_copy = b_block_trait()[I1]; +/*******************************************************************************/ + // GEMM + constexpr auto 
KPack = math::integer_least_multiple(K1, WmmaK); + + auto blockwise_gemm = + BlockwiseGemmWMMA{}; + + // Prepare Register for C matrix + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + +/*******************************************************************************/ + // Shift Per SUB_K + constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep(); + constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep(); + + // gridwise GEMM pipeline + const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock); + GridwiseGemmPipe::template Run(a_grid_desc, + a_block_desc, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc, + b_block_desc, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + scale_grid_desc, + scale_grid_buf, + blockwise_gemm, + c_thread_buf, + KBlockMainLoop); +/*******************************************************************************/ + // write out to C, implement shuffle + { + // C mapping in single thread. + constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + // C mapping in single block + constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp = + blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + constexpr auto MWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I1); + constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I2); + constexpr auto NWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I4); + constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I5); + constexpr auto MAccVgprs = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I6); + + // LDS descriptor, shuffle and write out in MRepeat x NRepeat times + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::c_shuffle_block_space_offset, + SharedMemTrait::c_shuffle_block_space_size); + + constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // MRepeat per shuffle repeat + MWave, // MWave + MSubGroup, // MSubGroup * MAccVgprs = MPerWmma + MAccVgprs)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // NRepeat per shuffle repeat + NWave, // NWave + NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0, 1, 2, 6>{}, Sequence<>{}, Sequence<3, 4, 5>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0); + + const index_t m_thread_data_on_block = 
c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_idx = n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + 1, // vector write pixel + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + make_multi_index(0, + m_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + 0, + n_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for local reg & global memory + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + 
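+            // Two barriers guard every space-filling-curve access because the
+            // same LDS shuffle tile is reused on each iteration: the barrier
+            // above keeps a thread from overwriting LDS that other threads are
+            // still streaming out to global memory, and the barrier after the
+            // VGPR->LDS write below keeps the block-wide copy from reading the
+            // tile before every thread has deposited its fragment.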
// each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + // clang-format on + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp index f514e3a119..82d010a99a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -45,8 +45,8 @@ __global__ void const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const index_t batch_count, - const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, - const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const AGridDesc_AK0_M_AK1 a_grid_desc, + const BGridDesc_BK0_N_BK1 b_grid_desc, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock @@ -69,7 +69,7 @@ __global__ void const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); - __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size]; DsPointer p_ds_grid_grp; @@ -84,8 +84,8 @@ __global__ void p_ds_grid_grp, p_e_grid + e_batch_offset, p_shared, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, + a_grid_desc, + b_grid_desc, ds_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock_, a_element_op, @@ -98,8 +98,8 @@ __global__ void ignore = p_ds_grid; ignore = p_e_grid; ignore = batch_count; - ignore = a_grid_desc_k0_m_k1; - ignore = b_grid_desc_k0_n_k1; + ignore = a_grid_desc; + ignore = b_grid_desc; ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_; ignore = a_element_op; @@ -115,8 +115,8 @@ template (compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( @@ -170,20 +169,16 @@ __global__ void DsPointer p_ds_grid_grp; - // printf("before allocate pointer d"); - static_for<0, NumDTensor, 1>{}( [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); - // printf("before entry"); - GridwiseOp::template Run(p_a_grid + a_batch_offset, p_b_grid + b_batch_offset, p_ds_grid_grp, p_e_grid + e_batch_offset, p_shared, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, + a_grid_desc, + b_grid_desc, ds_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock, a_element_op, @@ -199,8 +194,8 @@ __global__ void ignore = a_element_op; ignore = b_element_op; ignore = cde_element_op; - ignore = 
a_grid_desc_k0_m_k1; - ignore = b_grid_desc_k0_n_k1; + ignore = a_grid_desc; + ignore = b_grid_desc; ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; ignore = block_2_etile_map; @@ -213,8 +208,8 @@ template (p_a_grid, p_b_grid, p_ds_grid, p_e_grid, p_shared, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, + a_grid_desc, + b_grid_desc, ds_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock, a_element_op, @@ -263,8 +258,8 @@ __global__ void ignore = p_b_grid; ignore = p_ds_grid; ignore = p_e_grid; - ignore = a_grid_desc_k0_m_k1; - ignore = b_grid_desc_k0_n_k1; + ignore = a_grid_desc; + ignore = b_grid_desc; ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; ignore = a_element_op; @@ -282,8 +277,8 @@ template < // DataType Family typename DsDataType, typename EDataType, // InMemory Data Descriptor - typename AGridDesc_K0_M_K1, - typename BGridDesc_K0_N_K1, + typename AGridDesc, + typename BGridDesc, typename DsGridDesc_M_N, typename EGridDesc_M_N, // ElementwiseOp Family @@ -294,7 +289,7 @@ template < // DataType Family // Tiling Family index_t MPerBlock, index_t NPerBlock, - index_t K0PerBlock, + index_t KPerBlock, index_t MPerWmma, index_t NPerWmma, index_t K1Value, @@ -309,6 +304,7 @@ template < // DataType Family index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferDstScalarPerVector_K1, bool AThreadTransferSrcResetCoordinateAfterRun, + bool AEnableLds, bool ABlockLdsExtraM, typename BBlockTransferThreadClusterLengths_K0_N_K1, typename BBlockTransferThreadClusterArrangeOrder, @@ -317,6 +313,7 @@ template < // DataType Family index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferDstScalarPerVector_K1, bool BThreadTransferSrcResetCoordinateAfterRun, + bool BEnableLds, bool BBlockLdsExtraN, index_t CShuffleMRepeatPerShuffle, index_t CShuffleNRepeatPerShuffle, @@ -325,7 +322,7 @@ template < // DataType Family index_t NumGemmKPrefetchStage = 1, LoopScheduler LoopSched = make_default_loop_scheduler(), PipelineVersion PipelineVer = PipelineVersion::v1> -struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle +struct GridwiseGemmMultipleD_Wmma { static constexpr index_t NumDTensor = DsDataType::Size(); @@ -341,53 +338,233 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle // K1 should be Number<...> static constexpr auto K1 = Number{}; + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 
32 : 16; + using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t< - decltype(GridwiseGemmPipeline_Selector())>; + using GridwiseGemmPipe = + remove_cvref_t())>; - __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() + // Describe how data store to (LDS/VGPR) buffer from Global memory + __host__ __device__ static constexpr auto MakeABlockDescriptor() { - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() { - if constexpr(ABlockLdsExtraM) + constexpr auto a_block_desc = [&]() { + if constexpr(AEnableLds) { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); + // K0->M->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / K1; + constexpr auto max_lds_align = K1; + + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } } else { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / 2 / K1; + // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + Number{} * K1, + K1, + K1, + K1, + I1)); } }(); - return a_block_desc_k0perblock_mperblock_k1; + return a_block_desc; } - __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() + __host__ __device__ static constexpr auto MakeBBlockDescriptor() { - constexpr auto max_lds_align = K1; - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() { - if constexpr(BBlockLdsExtraN) + constexpr auto b_block_desc = [&]() { + if constexpr(BEnableLds) { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); + // K0->N->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / K1; + constexpr auto max_lds_align = K1; + + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } } else { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / 2 / K1; + // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + Number{} * K1, + K1, + K1, + K1, + I1)); } }(); - return b_block_desc_k0perblock_nperblock_k1; + return b_block_desc; + } + + __host__ __device__ static constexpr auto MakeABlockSliceCopyStep() + { + constexpr auto a_block_copy_step = [&]() { + if constexpr(AEnableLds) + { + constexpr auto K0PerBlock = KPerBlock / K1; + + return make_multi_index(K0PerBlock, 0, 0); + } + else + { + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return 
make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return a_block_copy_step; + } + + __host__ __device__ static constexpr auto MakeBBlockSliceCopyStep() + { + constexpr auto b_block_copy_step = [&]() { + if constexpr(BEnableLds) + { + constexpr auto K0PerBlock = KPerBlock / K1; + + return make_multi_index(K0PerBlock, 0, 0); + } + else + { + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return b_block_copy_step; + } + + // Describe how data read from (LDS/VGPR) buffer + template + __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&) + { + + constexpr auto a_wave_desc = [&]() { + if constexpr(AEnableLds) + { + // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 + constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); + constexpr auto A_KRow = I1; + return transform_tensor_descriptor( + ABlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = ABlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3); + constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6); + + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return a_wave_desc; + } + + template + __host__ __device__ static constexpr auto MakeBWaveDescriptor(const BBlockDesc_&) + { + constexpr auto b_wave_desc = [&]() { + if constexpr(BEnableLds) + { + // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 + constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); + constexpr auto B_KRow = I1; + return transform_tensor_descriptor( + BBlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = BBlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3); + constexpr auto B_KRow = BBlockDesc_{}.GetLength(I4); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I6); + + // Workaround, Freeze transform + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return b_wave_desc; } __host__ __device__ static constexpr auto @@ -419,43 +596,12 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle Number{}); } - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_k0perblock_mperblock_k1 = - GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); - - constexpr auto b_block_desc_k0perblock_nperblock_k1 = - GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); - - constexpr auto 
cshuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = - GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); - - constexpr auto max_lds_align = K1; - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size_aligned = math::integer_least_multiple( - b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align); - - constexpr auto c_block_space_size_aligned = math::integer_least_multiple( - cshuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize(), - max_lds_align); - - return math::max((a_block_space_size_aligned * sizeof(ADataType) + - b_block_space_size_aligned * sizeof(BDataType)), - c_block_space_size_aligned * sizeof(CShuffleDataType)); - } - - // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // CheckValidity for kernels without multi D template - __host__ __device__ static constexpr bool - CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const EGridDesc_M_N& e_grid_desc_m_n, - const Block2CTileMap& block_2_ctile_map) + __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc, + const BGridDesc& b_grid_desc, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) { static_assert(is_known_at_compile_time>::value, "wrong! K1 need to be known at compile-time"); @@ -464,20 +610,55 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle (NPerBlock % (NRepeat * NPerWmma)) == 0, "Invalid tuning param!"); - const auto M = a_grid_desc_k0_m_k1.GetLength(I1); - const auto N = b_grid_desc_k0_n_k1.GetLength(I1); - const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + const auto GetAProblemsizeMK = [&]() { + if constexpr(AEnableLds) + { + return make_tuple(a_grid_desc.GetLength(I1), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) * + a_grid_desc.GetLength(I5), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * + a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6)); + } + }; + + const auto GetBProblemsizeNK = [&]() { + if constexpr(BEnableLds) + { + return make_tuple(b_grid_desc.GetLength(I1), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) * + b_grid_desc.GetLength(I5), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) * + b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6)); + } + }; + + const auto M = GetAProblemsizeMK()[I0]; + const auto N = GetBProblemsizeNK()[I0]; + const auto K = GetAProblemsizeMK()[I1]; if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && - K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && - K1 == b_grid_desc_k0_n_k1.GetLength(I2))) + K == GetBProblemsizeNK()[I1])) + { + printf("GridwiseOp: ABE descriptor dimension cross check failure\n"); return false; + } - if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + printf("GridwiseOp: Problemsize descriptor dimension check failure\n"); return false; + } // check gridwise gemm pipeline - const auto num_k_loop = K0 / K0PerBlock; + const auto num_k_loop = K / KPerBlock; 
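Editorial sketch (not part of the patch): the rewritten CheckValidity no longer reads M and K off a fixed K0-M-K1 descriptor. When an operand bypasses LDS its grid descriptor is seven-dimensional, and the logical M and K are products of several of its lengths, as in GetAProblemsizeMK above. The snippet below redoes that arithmetic for one made-up shape so the divisibility checks above and the num_k_loop handed to the pipeline are easy to verify by hand; the lengths are assumptions chosen only so the products come out to round numbers.

    #include <array>
    #include <cstdio>

    int main()
    {
        // Assumed lengths of a 7-d A grid descriptor when AEnableLds == false:
        // (KWmma, M/(MWave*MPerWmma), MWave, K0PerWmma, KRow, MPerWmma, K1)
        constexpr std::array<long, 7> len{4, 8, 2, 1, 2, 16, 8};

        const long M = len[1] * len[2] * len[5];          // 8 * 2 * 16    = 256
        const long K = len[0] * len[3] * len[4] * len[6]; // 4 * 1 * 2 * 8 = 64

        constexpr long MPerBlock = 128, KPerBlock = 64;
        std::printf("M=%ld K=%ld, M mod MPerBlock=%ld, num_k_loop=%ld\n",
                    M, K, M % MPerBlock, K / KPerBlock);
    }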
if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { @@ -492,8 +673,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) constexpr long_index_t TwoGB = (long_index_t{1} << 31); - if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && - b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && + if(!(a_grid_desc.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) { return false; @@ -502,17 +683,57 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle return true; } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template - __host__ __device__ static constexpr bool - CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const DsGridDesc_M_N& ds_grid_desc_m_n, - const EGridDesc_M_N& e_grid_desc_m_n, - const Block2CTileMap& block_2_ctile_map) + __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc, + const BGridDesc& b_grid_desc, + const DsGridDesc_M_N& ds_grid_desc_m_n, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) { - const auto M = a_grid_desc_k0_m_k1.GetLength(I1); - const auto N = b_grid_desc_k0_n_k1.GetLength(I1); - bool valid = true; + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerWmma)) == 0, + "Invalid tuning param!"); + + const auto GetAProblemsizeMK = [&]() { + if constexpr(AEnableLds) + { + return make_tuple(a_grid_desc.GetLength(I1), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) * + a_grid_desc.GetLength(I5), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * + a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6)); + } + }; + + const auto GetBProblemsizeNK = [&]() { + if constexpr(BEnableLds) + { + return make_tuple(b_grid_desc.GetLength(I1), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) * + b_grid_desc.GetLength(I5), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) * + b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6)); + } + }; + + const auto M = GetAProblemsizeMK()[I0]; + const auto N = GetBProblemsizeNK()[I0]; + const auto K = GetAProblemsizeMK()[I1]; + + bool valid = true; + static_for<0, NumDTensor, 1>{}([&](auto i) { valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) && N == ds_grid_desc_m_n[i].GetLength(I1)); @@ -520,16 +741,52 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle if(!valid) { + printf("GridwiseOp: D descriptor dimension check failure\n"); return false; } - return CheckValidity( - a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, e_grid_desc_m_n, block_2_ctile_map); + if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && + K == GetBProblemsizeNK()[I1])) + { + printf("GridwiseOp: ABE descriptor dimension cross check failure\n"); + return false; + } + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + printf("GridwiseOp: Problemsize descriptor dimension check failure\n"); + return false; + } + + // check gridwise gemm 
pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(e_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && + e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } + + return true; } __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { - const index_t num_loop = K / (K0PerBlock * K1); + const index_t num_loop = K / KPerBlock; return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } @@ -542,9 +799,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle const auto M = e_grid_desc_m_n.GetLength(I0); const auto N = e_grid_desc_m_n.GetLength(I1); - const auto MBlock = M / MPerBlock; - const auto NBlock = N / NPerBlock; - + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( e_grid_desc_m_n, make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), @@ -575,6 +831,37 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle e_grid_desc_m_n); } + struct SharedMemTrait + { + // LDS allocation for A and B: be careful of alignment + + static constexpr auto max_lds_align = K1; + + static constexpr auto a_block_space_size_aligned = + AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(), + max_lds_align) + : 0; + static constexpr auto b_block_space_size_aligned = + BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(), + max_lds_align) + : 0; + + static constexpr auto a_block_space_offset = 0; + static constexpr auto b_block_space_offset = a_block_space_size_aligned; + + // LDS allocation for C shuffle in LDS + static constexpr auto c_shuffle_block_space_size = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + .GetElementSpaceSize(); + + static constexpr auto c_shuffle_block_space_offset = 0; + + static constexpr auto lds_size = + math::max(c_shuffle_block_space_size * sizeof(CShuffleDataType), + a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)); + }; + using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; @@ -591,8 +878,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle DsGridPointer p_ds_grid, EDataType* __restrict__ p_e_grid, void* __restrict__ p_shared, - const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const AGridDesc& a_grid_desc, + const BGridDesc& b_grid_desc, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& @@ -602,14 +889,13 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle const CDEElementwiseOperation& cde_element_op, const Block2CTileMap& block_2_ctile_map) { - // printf("safe entry"); // clang-format off /*******************************************************************************/ // Memory buffer zone. 
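Editorial sketch (not part of the patch): the E grid descriptor built above splits each output coordinate into a block index and an in-block offset with make_unmerge_transform, and the per-block origin computed just below (block_work_idx scaled by MPerBlock and NPerBlock) is the inverse of that split. The standalone snippet replays the same index arithmetic for an assumed 128x128 tile; the tile sizes and the sample coordinate are illustrative only.

    #include <cstdio>

    // Sketch: unmerge a flat (m, n) output coordinate into
    // (mblock, mperblock, nblock, nperblock), assuming 128x128 block tiles.
    struct TileCoord { long mblock, mperblock, nblock, nperblock; };

    constexpr TileCoord Unmerge(long m, long n, long MPerBlock, long NPerBlock)
    {
        return {m / MPerBlock, m % MPerBlock, n / NPerBlock, n % NPerBlock};
    }

    int main()
    {
        constexpr auto c = Unmerge(/*m=*/300, /*n=*/71, 128, 128);
        // The work-group mapped to tile (c.mblock, c.nblock) writes this element
        // at in-tile offset (c.mperblock, c.nperblock); the tile's global origin
        // is (c.mblock * 128, c.nblock * 128), mirroring m/n_block_data_idx_on_grid.
        std::printf("(%ld,%ld) (%ld,%ld)\n", c.mblock, c.mperblock, c.nblock, c.nperblock);
    }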
const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); + p_a_grid, a_grid_desc.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); + p_b_grid, b_grid_desc.GetElementSpaceSize()); const auto ds_grid_buf = generate_tuple( [&](auto i) { return make_dynamic_buffer( @@ -635,13 +921,30 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle /*******************************************************************************/ // BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy - const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); - constexpr auto max_lds_align = K1; - constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); - constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); - // A matrix blockwise copy - auto a_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock, + const auto K = [&](){ + if constexpr(AEnableLds){ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2); + } + else{ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * + a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6); + } + }(); + + constexpr auto a_block_desc = MakeABlockDescriptor(); + constexpr auto b_block_desc = MakeBBlockDescriptor(); + + auto a_block_trait = [&](){ + // A matrix blockwise copy + if constexpr(AEnableLds) + { + constexpr auto K0PerBlock = KPerBlock/ K1; + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), + a_block_desc.GetElementSpaceSize()); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, /* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim, @@ -661,92 +964,189 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle /* index_t SrcScalarStrideInVector, */ 1, /* index_t DstScalarStrideInVector, */ 1, /* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun, -/* bool ThreadTransferDstResetCoordinateAfterRun, */ true>( - a_grid_desc_k0_m_k1, +/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, + NumGemmKPrefetchStage>( + a_grid_desc, make_multi_index(0, m_block_data_idx_on_grid, 0), a_element_op, - a_block_desc_k0perblock_mperblock_k1, + a_block_desc, make_multi_index(0, 0, 0), ck::tensor_operation::element_wise::PassThrough{}); - // B matrix blockwise copy - auto b_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - BDataType, - BDataType, - decltype(b_grid_desc_k0_n_k1), - decltype(b_block_desc_k0perblock_nperblock_k1), - BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>( - b_grid_desc_k0_n_k1, - make_multi_index(0, n_block_data_idx_on_grid, 0), - b_element_op, - b_block_desc_k0perblock_nperblock_k1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); + return make_tuple(a_block_buf, a_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> MRepeat -> MWaves -> K0PerWmma -> KRow -> MPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/K1Value; + auto a_block_buf = make_static_buffer( + a_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor 
should be identical + auto a_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + ABlockTransferSrcScalarPerVector, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc, + make_multi_index(0, + m_block_data_idx_on_grid/(MWaves * MPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(a_block_buf, a_blockwise_copy); + } + }; + auto b_block_trait = [&](){ + if constexpr(BEnableLds) + { + constexpr auto K0PerBlock = KPerBlock/ K1; + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::a_block_space_size_aligned, + b_block_desc.GetElementSpaceSize()); + + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc), + decltype(b_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + return make_tuple(b_block_buf, b_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/K1Value; + auto b_block_buf = make_static_buffer( + b_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto b_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc, + make_multi_index(0, + n_block_data_idx_on_grid/(NWaves * NPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(b_block_buf, b_blockwise_copy); + } + }; + + auto a_block_buf = a_block_trait()[I0]; + auto a_blockwise_copy = a_block_trait()[I1]; + + auto b_block_buf = b_block_trait()[I0]; + auto b_blockwise_copy = b_block_trait()[I1]; /*******************************************************************************/ // GEMM - constexpr auto WmmaK = 16; constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); auto blockwise_gemm = - BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle{}; + BlockwiseGemmWMMA{}; // Prepare Register for C matrix auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); -/*******************************************************************************/ - constexpr auto a_block_space_size_aligned = math::integer_least_multiple(a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align); - // LDS allocation for A and B: be careful of alignment - auto a_block_buf = make_dynamic_buffer(static_cast(p_shared), a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize()); - auto b_block_buf = make_dynamic_buffer(static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize()); - 
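        // Note on the slice-copy steps used just below: when an operand is staged through
        // LDS the source window advances by K0PerBlock (= KPerBlock / K1) along the K0 axis
        // of its 3-D K0_M_K1 / K0_N_K1 grid descriptor; when LDS is bypassed it advances by
        // KWmmaPerBlock (= KPerBlock / WmmaK) along the leading axis of the 7-D per-thread
        // descriptor. Illustrative numbers only (not fixed by this header): KPerBlock = 64,
        // K1 = 8, WmmaK = 16 gives K0PerBlock = 8 and KWmmaPerBlock = 4.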
+/*******************************************************************************/ // Shift Per SUB_K - constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep(); + constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep(); // gridwise GEMM pipeline - const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); - GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, - a_block_desc_k0perblock_mperblock_k1, + const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock); + GridwiseGemmPipe::template Run(a_grid_desc, + a_block_desc, a_blockwise_copy, a_grid_buf, a_block_buf, a_block_slice_copy_step, - b_grid_desc_k0_n_k1, - b_block_desc_k0perblock_nperblock_k1, + b_grid_desc, + b_block_desc, b_blockwise_copy, b_grid_buf, b_block_buf, b_block_slice_copy_step, blockwise_gemm, c_thread_buf, - K0BlockMainLoop); + KBlockMainLoop); /*******************************************************************************/ // write out to C, implement shuffle { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp index ecbcb61f3e..567c42362c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp @@ -17,18 +17,21 @@ enum struct PipelineVersion v2, // v3 is only used in the Stream-K implementation. v4, + weight_only, }; template + LoopScheduler LoopSched = LoopScheduler::Default, + bool AEnableLds = true, + bool BEnableLds = true> constexpr auto GridwiseGemmPipeline_Selector() { if constexpr(PipelineVer == PipelineVersion::v1) { if constexpr(LoopSched == LoopScheduler::Default) { - return GridwiseGemmPipeline_v1{}; + return GridwiseGemmPipeline_v1{}; } else if constexpr(LoopSched == LoopScheduler::Interwave) { @@ -43,6 +46,10 @@ constexpr auto GridwiseGemmPipeline_Selector() { return GridwiseGemmPipeline_v4{}; } + else if constexpr(PipelineVer == PipelineVersion::weight_only) + { + return GridwiseGemmPipeline_v1_WeightOnly{}; + } else { std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp index 754a3e89c9..0cdb7ce2ca 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp @@ -9,12 +9,12 @@ namespace ck { -template +template struct GridwiseGemmPipeline_v1; // 1-stage prefetch template <> -struct GridwiseGemmPipeline_v1<1> +struct GridwiseGemmPipeline_v1<1, true, true> { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -108,7 +108,7 @@ struct GridwiseGemmPipeline_v1<1> // 2-stage prefetch template <> -struct GridwiseGemmPipeline_v1<2> +struct GridwiseGemmPipeline_v1<2, true, true> { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -254,6 +254,406 @@ struct GridwiseGemmPipeline_v1<2> } }; +template <> +struct GridwiseGemmPipeline_v1<1, false, true> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ 
static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0); + auto a_block_buf_switch = a_block_buf; + + // preload data into LDS + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + a_blockwise_copy.Run( + a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + a_blockwise_copy.Run( + a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf_switch); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + a_block_buf = a_block_buf_switch; + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + } + } +}; + +template <> +struct GridwiseGemmPipeline_v1<1, true, false> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0); + auto b_block_buf_switch = b_block_buf; + + // preload data into LDS + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, 
b_block_desc, b_block_origin_idx, b_block_buf_switch); + + block_sync_lds(); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + b_block_buf = b_block_buf_switch; + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + } + } +}; + +template <> +struct GridwiseGemmPipeline_v1<1, false, false> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0); + constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0); + auto b_block_buf_switch = b_block_buf; + auto a_block_buf_switch = a_block_buf; + + // preload data into LDS + a_blockwise_copy.Run( + a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf); + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + a_blockwise_copy.Run( + a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf_switch); + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch); + + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_block_buf = a_block_buf_switch; + b_block_buf = b_block_buf_switch; + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + } + } +}; + +template +struct GridwiseGemmPipeline_v1_WeightOnly; + +template <> +struct GridwiseGemmPipeline_v1_WeightOnly<1, true, true> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& 
a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const ScaleGridDesc& scale_grid_desc, + const ScaleGridBuffer& scale_grid_buf, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + // Global Prefetch Stage 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + // Scale read once + b_blockwise_copy.RunScaleRead(scale_grid_desc, scale_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + // Dequantization fused in blockwise_copy + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } +}; + template struct GridwiseGemmPipelineInterwave_v1; @@ -349,7 +749,7 @@ struct GridwiseGemmPipelineInterwave_v1<1> // Note: 2 stage prefetch not optimized for inter-wave loop scheduler template <> -struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2> +struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2, true, true> { }; @@ -359,7 +759,7 @@ constexpr auto GridwiseGemmPipeline_v1_Selector() { if constexpr(LoopSched == LoopScheduler::Default) { - return GridwiseGemmPipeline_v1{}; + return GridwiseGemmPipeline_v1{}; } else if constexpr(LoopSched == LoopScheduler::Interwave) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp index e7dc0d3eb0..0078660556 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp @@ -93,7 +93,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp index 066cfc62f2..8e4117593c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp @@ -18,11 +18,11 @@ namespace ck { template (p_a_grid, p_b_grid, p_c_grid, p_shared, - a_grid_desc_k0_m_k1, - 
b_grid_desc_k0_n_k1, + a_grid_desc, + b_grid_desc, c_grid_desc_mblock_mperblock_nblock_nperblock, a_element_op, b_element_op, @@ -67,8 +63,8 @@ __global__ void ignore = p_a_grid; ignore = p_b_grid; ignore = p_c_grid; - ignore = a_grid_desc_k0_m_k1; - ignore = b_grid_desc_k0_n_k1; + ignore = a_grid_desc; + ignore = b_grid_desc; ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; ignore = a_element_op; ignore = b_element_op; @@ -78,21 +74,21 @@ __global__ void } template -struct GridwiseGemm_k0mk1_k0nk1_mn_wmma +struct GridwiseGemm_Wmma { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -132,103 +130,277 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma static constexpr auto I6 = Number<6>{}; static constexpr auto I7 = Number<7>{}; - // K1 should be Number<...> + // FIX ME: To be deprecated static constexpr auto K1 = Number{}; + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 32 : 16; + using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t< - decltype(GridwiseGemmPipeline_Selector())>; + using GridwiseGemmPipe = + remove_cvref_t())>; - __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() + // Describe how data store to (LDS/VGPR) buffer from Global memory + __host__ __device__ static constexpr auto MakeABlockDescriptor() { - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() { - if constexpr(ABlockLdsExtraM) + constexpr auto a_block_desc = [&]() { + if constexpr(AEnableLds) { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); + // K0->M->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / K1; + constexpr auto max_lds_align = K1; + + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } } else { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / 2 / K1; + // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + Number{} * K1, + K1, + K1, + K1, + I1)); } }(); - return a_block_desc_k0perblock_mperblock_k1; + return a_block_desc; } - __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() + __host__ __device__ static constexpr auto MakeBBlockDescriptor() { - constexpr auto max_lds_align = K1; - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() { - if constexpr(BBlockLdsExtraN) + constexpr auto b_block_desc = [&]() { + if constexpr(BEnableLds) { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); + // K0->N->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / K1; + constexpr auto max_lds_align = K1; + + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * 
K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } } else { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / 2 / K1; + // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + Number{} * K1, + K1, + K1, + K1, + I1)); } }(); - return b_block_desc_k0perblock_nperblock_k1; + return b_block_desc; + } + + __host__ __device__ static constexpr auto MakeABlockSliceCopyStep() + { + constexpr auto a_block_copy_step = [&]() { + if constexpr(AEnableLds) + { + constexpr auto K0PerBlock = KPerBlock / K1; + + return make_multi_index(K0PerBlock, 0, 0); + } + else + { + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return a_block_copy_step; + } + + __host__ __device__ static constexpr auto MakeBBlockSliceCopyStep() + { + constexpr auto b_block_copy_step = [&]() { + if constexpr(BEnableLds) + { + constexpr auto K0PerBlock = KPerBlock / K1; + + return make_multi_index(K0PerBlock, 0, 0); + } + else + { + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return b_block_copy_step; + } + + // Describe how data read from (LDS/VGPR) buffer + template + __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&) + { + + constexpr auto a_wave_desc = [&]() { + if constexpr(AEnableLds) + { + // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 + constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); + constexpr auto A_KRow = I1; + return transform_tensor_descriptor( + ABlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = ABlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3); + constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6); + + // Err: merge transform cause non-constexpr issue + + // return transform_tensor_descriptor( + // ABlockDesc_{}, + // make_tuple(make_merge_transform(make_tuple(Number{}, I1)), + // make_pass_through_transform(Number{}), + // make_pass_through_transform(I1), + // make_pass_through_transform(I1), + // make_pass_through_transform(Number{})), + // make_tuple(Sequence<0, 3>{}, + // Sequence<1>{}, + // Sequence<2>{}, + // Sequence<4>{}, + // Sequence<5>{}), + // make_tuple( + // Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, + // Sequence<4>{})); + + // Workaround, Freeze transform + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return a_wave_desc; + } + + template + __host__ __device__ static constexpr auto MakeBWaveDescriptor(const BBlockDesc_&) + { + constexpr auto b_wave_desc = [&]() { + if 
constexpr(BEnableLds) + { + // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 + constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); + constexpr auto B_KRow = I1; + return transform_tensor_descriptor( + BBlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = BBlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3); + constexpr auto B_KRow = BBlockDesc_{}.GetLength(I4); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I6); + + // Workaround, Freeze transform + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return b_wave_desc; } __host__ __device__ static constexpr auto // *Caution Here repeat is shuffle repeat GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() { - constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma); - constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma); - constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = make_naive_tensor_descriptor_packed( make_tuple(I1, - Number{}, + Number{}, I1, - Number{})); + Number{})); return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; } - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_k0perblock_mperblock_k1 = - GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); - - constexpr auto b_block_desc_k0perblock_nperblock_k1 = - GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); - - constexpr auto max_lds_align = K1; - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size_aligned = math::integer_least_multiple( - b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align); - - return (a_block_space_size_aligned * sizeof(FloatA) + - b_block_space_size_aligned * sizeof(FloatB)); - } - // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template - __host__ __device__ static constexpr bool - CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const CGridDesc_M_N& c_grid_desc_m_n, - const Block2CTileMap& block_2_ctile_map) + __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc, + const BGridDesc& b_grid_desc, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) { static_assert(is_known_at_compile_time>::value, "wrong! 
K1 need to be known at compile-time"); @@ -237,23 +409,66 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma (NPerBlock % (NRepeat * NPerWmma)) == 0, "Invalid tuning param!"); - const auto M = a_grid_desc_k0_m_k1.GetLength(I1); - const auto N = b_grid_desc_k0_n_k1.GetLength(I1); - const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + const auto GetAProblemsizeMK = [&]() { + if constexpr(AEnableLds) + { + return make_tuple(a_grid_desc.GetLength(I1), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) * + a_grid_desc.GetLength(I5), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * + a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6)); + } + }; + + const auto GetBProblemsizeNK = [&]() { + if constexpr(BEnableLds) + { + return make_tuple(b_grid_desc.GetLength(I1), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) * + b_grid_desc.GetLength(I5), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) * + b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6)); + } + }; + + const auto M = GetAProblemsizeMK()[I0]; + const auto N = GetBProblemsizeNK()[I0]; + const auto K = GetAProblemsizeMK()[I1]; if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && - K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && - K1 == b_grid_desc_k0_n_k1.GetLength(I2))) + K == GetBProblemsizeNK()[I1])) + { + printf("A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d\n", + GetAProblemsizeMK()[I0], + GetAProblemsizeMK()[I1], + GetBProblemsizeNK()[I0], + GetBProblemsizeNK()[I1], + c_grid_desc_m_n.GetLength(I0), + c_grid_desc_m_n.GetLength(I1)); + printf("GridwiseOp err: ProblemSize check"); return false; + } - if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + printf("GridwiseOp err: ProblemSize division"); return false; + } // check gridwise gemm pipeline - const auto num_k_loop = K0 / K0PerBlock; + const auto num_k_loop = K / KPerBlock; if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { + printf("GridwiseOp err: Pipeline not support this k_loop"); return false; } @@ -265,8 +480,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) constexpr long_index_t TwoGB = (long_index_t{1} << 31); - if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(FloatA) <= TwoGB && - b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(FloatB) <= TwoGB)) + if(!(a_grid_desc.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB)) { return false; } @@ -275,7 +490,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { - const index_t num_loop = K / (K0PerBlock * K1); + const index_t num_loop = K / KPerBlock; return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } @@ -313,13 +528,44 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma using DefaultBlock2CTileMap = remove_cvref_t; + struct SharedMemTrait + { + // LDS allocation for A and B: be careful of alignment + + static constexpr auto max_lds_align = K1; + + static constexpr auto a_block_space_size_aligned = + AEnableLds ? 
math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(), + max_lds_align) + : 0; + static constexpr auto b_block_space_size_aligned = + BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(), + max_lds_align) + : 0; + + static constexpr auto a_block_space_offset = 0; + static constexpr auto b_block_space_offset = a_block_space_size_aligned; + + // LDS allocation for C shuffle in LDS + static constexpr auto c_shuffle_block_space_size = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + .GetElementSpaceSize(); + + static constexpr auto c_shuffle_block_space_offset = 0; + + static constexpr auto lds_size = + math::max(c_shuffle_block_space_size * sizeof(CShuffleDataType), + a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)); + }; + template - __device__ static void Run(const FloatA* __restrict__ p_a_grid, - const FloatB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + CDataType* __restrict__ p_c_grid, void* __restrict__ p_shared, - const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const AGridDesc& a_grid_desc, + const BGridDesc& b_grid_desc, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& c_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation& a_element_op, @@ -331,9 +577,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma /*******************************************************************************/ // Memory buffer zone. const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); + p_a_grid, a_grid_desc.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); + p_b_grid, b_grid_desc.GetElementSpaceSize()); auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); @@ -351,24 +597,41 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); /*******************************************************************************/ -// BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy - const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); - constexpr auto max_lds_align = K1; - constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); - constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); - // A matrix blockwise copy - auto a_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock, +// BlockLevel, A/B Matrix ThreadMapping in WMMA Source buffer, As Destinaion of BlockWise_Copy + const auto K = [&](){ + if constexpr(AEnableLds){ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2); + } + else{ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) + * a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6); + } + }(); + + constexpr auto a_block_desc = MakeABlockDescriptor(); + constexpr auto b_block_desc = MakeBBlockDescriptor(); + + auto a_block_trait = [&](){ + // A matrix blockwise copy + if constexpr(AEnableLds) + { + constexpr auto K0PerBlock = KPerBlock/ K1; + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), + SharedMemTrait::a_block_space_size_aligned); + + auto 
a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, /* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1, /* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder, -/* typename SrcData, */ FloatA, -/* typename DstData, */ FloatA, -/* typename SrcDesc, */ decltype(a_grid_desc_k0_m_k1), -/* typename DstDesc, */ decltype(a_block_desc_k0perblock_mperblock_k1), +/* typename SrcData, */ ADataType, +/* typename DstData, */ ADataType, +/* typename SrcDesc, */ decltype(a_grid_desc), +/* typename DstDesc, */ decltype(a_block_desc), /* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder, /* typename DstDimAccessOrder, */ Sequence<0, 1, 2>, /* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim, @@ -378,99 +641,197 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma /* index_t SrcScalarStrideInVector, */ 1, /* index_t DstScalarStrideInVector, */ 1, /* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun, -/* bool ThreadTransferDstResetCoordinateAfterRun, */ true>( - a_grid_desc_k0_m_k1, +/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, + NumGemmKPrefetchStage>( + a_grid_desc, make_multi_index(0, m_block_data_idx_on_grid, 0), a_element_op, - a_block_desc_k0perblock_mperblock_k1, + a_block_desc, make_multi_index(0, 0, 0), ck::tensor_operation::element_wise::PassThrough{}); - // B matrix blockwise copy - auto b_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - FloatB, - FloatB, - decltype(b_grid_desc_k0_n_k1), - decltype(b_block_desc_k0perblock_nperblock_k1), - BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>( - b_grid_desc_k0_n_k1, - make_multi_index(0, n_block_data_idx_on_grid, 0), - b_element_op, - b_block_desc_k0perblock_nperblock_k1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); + return make_tuple(a_block_buf, a_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> MRepeat -> MWaves -> K0PerWmma -> KRow -> MPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/K1Value; + auto a_block_buf = make_static_buffer( + a_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto a_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + ABlockTransferSrcScalarPerVector, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc, + make_multi_index(0, + m_block_data_idx_on_grid/(MWaves * MPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(a_block_buf, a_blockwise_copy); + } + }; + auto b_block_trait = [&](){ + if constexpr(BEnableLds) + { + constexpr auto K0PerBlock = KPerBlock/ K1; + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b_block_space_offset, + SharedMemTrait::b_block_space_size_aligned); + + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc), + decltype(b_block_desc), + 
BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + return make_tuple(b_block_buf, b_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/K1Value; + auto b_block_buf = make_static_buffer( + b_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto b_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc, + make_multi_index(0, + n_block_data_idx_on_grid/(NWaves * NPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(b_block_buf, b_blockwise_copy); + } + }; + + auto a_block_buf = a_block_trait()[I0]; + auto a_blockwise_copy = a_block_trait()[I1]; + + auto b_block_buf = b_block_trait()[I0]; + auto b_blockwise_copy = b_block_trait()[I1]; /*******************************************************************************/ // GEMM - constexpr auto WmmaK = 16; constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); auto blockwise_gemm = - BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle{}; + BlockwiseGemmWMMA{}; // Prepare Register for C matrix auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); -/*******************************************************************************/ - constexpr auto a_block_space_size_aligned = math::integer_least_multiple(a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align); - // LDS allocation for A and B: be careful of alignment - auto a_block_buf = make_dynamic_buffer(static_cast(p_shared), a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize()); - auto b_block_buf = make_dynamic_buffer(static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize()); - +/*******************************************************************************/ // Shift Per SUB_K - constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep(); + constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep(); // gridwise GEMM pipeline - const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); - GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, - a_block_desc_k0perblock_mperblock_k1, + const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock); + GridwiseGemmPipe::template Run(a_grid_desc, + a_block_desc, a_blockwise_copy, a_grid_buf, a_block_buf, a_block_slice_copy_step, - b_grid_desc_k0_n_k1, - b_block_desc_k0perblock_nperblock_k1, + b_grid_desc, + b_block_desc, b_blockwise_copy, b_grid_buf, b_block_buf, b_block_slice_copy_step, blockwise_gemm, c_thread_buf, - K0BlockMainLoop); + KBlockMainLoop); 
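        // Note: the epilogue below stages the WMMA accumulators through the same LDS block
        // that held the A/B tiles during the main loop (SharedMemTrait places the C-shuffle
        // buffer at offset 0 and sizes the allocation as max(c_shuffle bytes, a bytes +
        // b bytes)), then streams the shuffled tile to global memory along a
        // space-filling curve.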
/*******************************************************************************/ // write out to C, implement shuffle { + // C mapping in single thread. constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); - // This API Provide All dimension (size) you need + // C mapping in single block constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp = blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); @@ -485,8 +846,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize()); + static_cast(p_shared) + SharedMemTrait::c_shuffle_block_space_offset, + SharedMemTrait::c_shuffle_block_space_size); constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor( c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, @@ -532,8 +893,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma // shuffle: threadwise copy C from VGPR to LDS auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, // BlockSliceLengths, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - FloatCShuffle, // typename SrcData, - FloatC, // typename DstData, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), Sequence<0, 1, 2, 3>, // typename DimAccessOrder, @@ -636,6 +997,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma if constexpr(access_id < num_access - 1) { constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + // move on C c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 608679a4fa..3fdf686523 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -1333,4 +1333,139 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic ElementwiseOperation element_op_; }; +// Specilized for WMMA +// A single Wave32 is composed by double row +// Data exchange allowed between these two rows +// This RowLane Dst buf will be filled from two Src buf +// SrcA: From specific thread buffer hold by This RowLane on This Row +// SrcB: From specific thread buffer hold by This RowLane on The other Row +template ::type = false> +struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow(const Index& src_idx) + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! Desc need to known at compile-time"); + + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! 
Not divisible"); + ignore = src_idx; + } + + template + __device__ void Run(const SrcDesc&, + const SrcSliceOriginIdx&, + const SrcBuffer& src_buf, + const DstDesc&, + const DstSliceOriginIdx&, + DstBuffer& dst_buf) const + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! Desc need to known at compile-time"); + + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! SliceOrigin need to known at compile-time"); + + static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(), + "wrong! Buffer need to be StaticBuffer"); + + // SrcDesc and src_slice_origin_idx are known at compile-time + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto dst_desc = remove_cvref_t{}; + constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); + constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{}); + + // scalar per access on each dim + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, + "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + static_for<0, num_access, 1>{}([&](auto idx_1d) { + constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); + + // copy data from src_buf into dst_vector + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + // src_desc error, non constexpr, caused by merge transform + constexpr index_t src_offset = src_desc.CalculateOffset( + src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + SrcData v_this_row, v_theother_row; + // int type temp value due to intrinsic requirement + int temp = 0; + + // apply element-wise operation + element_op_(v_this_row, src_buf[Number{}]); + + // apply intra-row permute. + if constexpr(IntraRowSwizzlePerm) + { + temp = __builtin_amdgcn_permlane16( + temp, type_convert_sp(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0); + v_this_row = type_convert_sp(temp); + } + + // apply inter-row permute. + temp = __builtin_amdgcn_permlanex16(temp, + type_convert_sp(v_this_row), + LowEightRowlaneIdx, + HighEightRowLaneIdx, + 1, + 0); + v_theother_row = type_convert_sp(temp); + + if(get_thread_local_1d_id() % 32 < 16) + { + // apply type convert + dst_buf(Number{}) = type_convert_sp(v_this_row); + dst_buf(Number{}) = + type_convert_sp(v_theother_row); + } + else + { + // apply type convert + dst_buf(Number{}) = + type_convert_sp(v_this_row); + dst_buf(Number{}) = type_convert_sp(v_theother_row); + } + }); + }); + } + ElementwiseOperation element_op_{}; +}; + } // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp new file mode 100644 index 0000000000..174b82f870 --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp @@ -0,0 +1,1066 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
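// This new header adds ThreadwiseTensorSliceTransfer_v3r1_dequant, a thread-wise transfer
// intended for the weight_only path above (its RunScaleRead matches the hook the new
// pipeline calls on the B-side copy): RunRead fetches the quantized source tile,
// RunScaleRead fetches the per-slice scale once, and the scratch-to-scratch stage converts
// the data and multiplies it by the scale before RunWrite, so the destination buffer
// receives dequantized values.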
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor/static_tensor.hpp" + +namespace ck { + +namespace detail { +// TODO: How to fix this? It uses an struct instead of lambda because lambda +// doesn't have constructor +template +struct lambda_scalar_per_access_for_src_and_dst_idle +{ + __host__ __device__ constexpr auto operator()(index_t i) const + { + if(i == SrcVectorDim && i == DstVectorDim) + { + return math::lcm(SrcScalarPerVector, DstScalarPerVector); + } + else if(i == SrcVectorDim) + { + return SrcScalarPerVector; + } + else if(i == DstVectorDim) + { + return DstScalarPerVector; + } + else + { + return 1; + } + } +}; + +} // namespace detail + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +// 4. Use thread buffer +// 5. Dequantization happened between read and write. +template +struct ThreadwiseTensorSliceTransfer_v3r1_dequant +{ + static constexpr index_t nDim = SliceLengths::Size(); + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using ScaleCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1_dequant( + const SrcDesc& src_desc, + const Index& src_slice_origin, + const SrcElementwiseOperation& src_element_op, + const ScaleDesc& scale_desc, + const Index& scale_slice_origin, + const ScaleElementwiseOperation& scale_element_op, + const DstDesc& dst_desc, + const Index& dst_slice_origin, + const DstElementwiseOperation& dst_element_op) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), + scale_coord_(make_tensor_coordinate(scale_desc, scale_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + src_element_op_(src_element_op), + scale_element_op_(scale_element_op), + dst_element_op_(dst_element_op) + { + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void SetScaleSliceOrigin(const ScaleDesc& scale_desc, + const Index& scale_slice_origin_idx) + { + scale_coord_ = make_tensor_coordinate(scale_desc, scale_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + Number thread_scratch_id = Number{}) + { + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! 
SrcBuffer and SrcData data type are inconsistent"); + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // make forward steps + const auto src_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(src_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + const auto src_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(src_desc, backward_step_idx); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
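                    // zig-zag traversal: a dimension sweeps forward when the combined index
                    // of the more slowly varying dimensions is even and backward when it is
                    // odd, so successive accesses are one step apart and the coordinate can
                    // always be moved with the precomputed forward/backward steps; e.g.
                    // access lengths {2, 3} are visited as (0,0)(0,1)(0,2)(1,2)(1,1)(1,0)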
ordered_src_access_idx[i] + : ordered_src_access_lengths[i] - 1 - + ordered_src_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + constexpr auto src_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + using src_vector_type = vector_type_maker_t; + using src_vector_t = typename src_vector_type::type; + + // copy data from src_buf into src_vector_container + auto src_vector_container = src_vector_type{ + src_buf.template Get(src_coord_.GetOffset(), is_src_valid)}; + + // copy data from src_vector_container into src_thread_scratch_ + src_thread_scratch_tuple_(thread_scratch_id) + .template SetAsType( + src_data_idx_seq, src_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move src coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); + } + } + }); + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + } + + template + __device__ void RunScaleRead(const ScaleDesc& scale_desc, const ScaleBuffer& scale_buf) + { + static_assert(ScaleBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + ScaleBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! ScaleBuffer and ScaleData data type are inconsistent"); + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto scale_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto scale_access_lengths = SliceLengths{} / scale_scalar_per_access; + + constexpr auto scale_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_scale_access_lengths = + container_reorder_given_new2old(scale_access_lengths, scale_dim_access_order); + + // make forward steps + const auto scale_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? scale_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(scale_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + const auto scale_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? 
-scale_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(scale_desc, backward_step_idx); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_scale_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_scale_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_scale_access_lengths[j] + ordered_scale_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate scale data index + constexpr auto scale_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_scale_access_idx[i] + : ordered_scale_access_lengths[i] - 1 - + ordered_scale_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, scale_dim_access_order) * + scale_scalar_per_access; + }(); + + constexpr auto scale_data_idx_seq = + generate_sequence_v2([&](auto i) { return Number{}; }, + Number{}); + + const bool is_scale_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( + scale_desc, scale_coord_); + + using scale_vector_type = vector_type_maker_t; + using scale_vector_t = typename scale_vector_type::type; + + // copy data from scale_buf into scale_vector_container + auto scale_vector_container = scale_vector_type{ + scale_buf.template Get(scale_coord_.GetOffset(), is_scale_valid)}; + + // copy data from scale_vector_container into scale_thread_scratch_ + scale_thread_scratch_.template SetAsType( + scale_data_idx_seq, scale_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = + ordered_scale_access_idx[i] < ordered_scale_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_scale_access_idx[j] == ordered_scale_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move scale coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate(scale_desc, + scale_coord_, + scale_forward_steps[scale_dim_access_order[i]]); + } + else + { + move_tensor_coordinate(scale_desc, + scale_coord_, + scale_backward_steps[scale_dim_access_order[i]]); + } + } + }); + }); + + // don't need to move scale coordinate back to slice origin + /* + if constexpr(SrcResetCoordinateAfterRun) + { + const auto scale_reset_step = + make_tensor_coordinate_step(scale_desc, GetScaleCoordinateResetStep()); + + move_tensor_coordinate(scale_desc, scale_coord_, scale_reset_step); + } + */ + } + + template + __device__ void + TransferDataFromSrcThreadScratchToDstThreadScratch(Number thread_scratch_id) + { +#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE + static_ford{}([&](auto idx) { + // convert from SrcData to DstData here + dst_thread_scratch_(idx) = + type_convert(src_thread_scratch_tuple_[thread_scratch_id][idx]); + }); +#else + // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_ + // TODO make this logic more generic for more sub-dword datatype + if constexpr(SrcVectorDim != DstVectorDim && + ((is_same>::value && + is_same>::value && + SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) || + (is_same>::value && + is_same>::value && + 
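// Illustrative sketch (not part of the patch): the "sub-dword transpose" path selected by the
// condition above exchanges 16-bit halves between 32-bit registers, e.g. a 2x2 tile of fp16
// values stored as two packed rows becomes two packed columns. The device code does this with
// transpose_vectors<>; plain bit operations express the same data movement:
#include <cstdint>

// rows r0 = (a00, a01), r1 = (a10, a11) packed as two 16-bit lanes per 32-bit word
inline void transpose_2x2_packed_f16(uint32_t r0, uint32_t r1,
                                     uint32_t& c0, uint32_t& c1) // cols (a00, a10), (a01, a11)
{
    c0 = (r0 & 0xffffu) | (r1 << 16);         // low halves of both rows
    c1 = (r0 >> 16)     | (r1 & 0xffff0000u); // high halves of both rows
}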
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) + { + // each transpose does + // DstScalarPerVector # of src vectors in src_thread_scratch_ + // SrcScalarPerVector # of dst vectors in dst_thread_scratch_ + constexpr index_t num_src_vector = Number{}; + constexpr index_t num_dst_vector = Number{}; + + // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose + // TODO: make this logic generic for all scenario + static_assert(SrcVectorDim != DstVectorDim, "wrong"); + + constexpr auto src_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access_for_src_and_dst_idle{}, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + static_ford{}([&](auto access_idx) { + constexpr auto data_idx = access_idx * scalar_per_access; + + constexpr auto data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + using src_vector_t = vector_type_maker_t; + using dst_vector_t = vector_type_maker_t; + + // get DstScalarPerVector # of read-only references to src vectors from + // src_thread_scratch_ + const auto src_vector_refs = generate_tie( + [&](auto i) -> const src_vector_t& { + // i increment corresponds to movement in DstVectorDim + return src_thread_scratch_tuple_[thread_scratch_id].GetVectorTypeReference( + data_idx_seq + i * dst_scalar_step_in_vector); + }, + Number{}); + + // get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_ + auto dst_vector_refs = generate_tie( + [&](auto i) -> dst_vector_t& { + // i increment corresponds to movement in SrcVectorDim + return dst_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * src_scalar_step_in_vector); + }, + Number{}); + + // do data transpose + transpose_vectors{}( + src_vector_refs, dst_vector_refs); + }); + } + + // Do fast numeric convert + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access_for_src_and_dst_idle{}, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + using src_vector_type = vector_type_maker_t; + using src_vector_t = typename src_vector_type::type; + + using src_converted_vector_type = vector_type_maker_t; + using src_converted_vector_t = typename src_converted_vector_type::type; + // Vector-wise type convert + static_ford{}([&](auto access_idx) { + auto src_vector_container = src_vector_type{ + src_thread_scratch_tuple_[thread_scratch_id].template GetAsType( + access_idx)}; + + auto src_converted_vector_container = + src_converted_vector_type{fast_numeric_converter(src_vector_container)}; + + src_converted_thread_scratch_.template SetAsType( + access_idx, + src_converted_vector_container.template AsType()[I0]); + }); + + // Element-scale operation, expect packed multiplication + static_ford{}([&](auto idx) { + DstData dst_v; + constexpr auto scale_idx = Sequence{}; + // printf("Tid: %03d, scale: %04x\n", get_thread_local_1d_id(), + // *(reinterpret_cast(&scale_thread_scratch_[scale_idx]))); + src_element_op_(dst_v, + src_converted_thread_scratch_[idx] * scale_thread_scratch_[scale_idx]); + dst_thread_scratch_(idx) = dst_v; + }); +#endif + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, + DstBuffer& dst_buf, + Number thread_scratch_id = Number{}) + { + // if there is 
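// Illustrative sketch (not part of the patch): after the FastNumericArrayConverter turns the
// integer source vector into the floating-point compute type, the "element-scale operation"
// above multiplies each element by the scale held in scale_thread_scratch_ before the result
// reaches dst_thread_scratch_. A scalar version of that dequantization, using fp32 as a
// stand-in for the packed fp16 math on the device:
#include <cstdint>
#include <vector>

std::vector<float> dequantize(const std::vector<uint8_t>& q, float scale)
{
    std::vector<float> out(q.size());
    for(size_t i = 0; i < q.size(); ++i)
        out[i] = static_cast<float>(q[i]) * scale; // convert to float type, then apply scale
    return out;
}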
transpose, it's done here + // TODO move this elsewhere + TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id); + + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + // src scalar per access on each dim + // TODO: don't use this + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // make forward steps + const auto dst_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + const auto dst_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst_desc, backward_step_idx); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_dst_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
ordered_dst_access_idx[i] + : ordered_dst_access_lengths[i] - 1 - + ordered_dst_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + constexpr auto dst_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + // copy data from dst_thread_scratch_ into dst_vector_container + auto dst_vector_container = dst_vector_type{ + dst_thread_scratch_.template GetAsType(dst_data_idx_seq)}; + + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + DstData dst_v; + + // apply DstElementwiseOperation + dst_element_op_(dst_v, dst_vector_container.template AsType()[i]); + + dst_vector_container.template AsType()(i) = dst_v; + }); + + // copy data from dst_vector_container to dst_buf + dst_buf.template Set( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move dst coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]); + } + } + }); + }); + + // move dst coordinate back to slice origin (or not) + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index after last iteration in RunRead(), if it has not being reset by + // RunRead() + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
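// Illustrative sketch (not part of the patch): GetSrcCoordinateResetStep()/GetDstCoordinateResetStep()
// compute where the snake-order loop leaves the coordinate after its last access and return the
// negated index, so adding that step moves the coordinate back to the slice origin. Per
// dimension, the last access lands at (length - 1) * scalars_per_access when its final sweep was
// forward, and at 0 otherwise:
#include <array>

template <int nDim>
std::array<int, nDim> make_reset_step(const std::array<bool, nDim>& final_forward_sweep,
                                      const std::array<int, nDim>& access_lengths,
                                      const std::array<int, nDim>& scalar_per_access)
{
    std::array<int, nDim> reset_step{};
    for(int i = 0; i < nDim; ++i)
    {
        const int last_idx =
            final_forward_sweep[i] ? (access_lengths[i] - 1) * scalar_per_access[i] : 0;
        reset_step[i] = -last_idx; // step back to the slice origin
    }
    return reset_step;
}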
ordered_src_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + // + constexpr auto reset_src_data_step = [&]() { + Index reset_src_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); + + return reset_src_data_step_; + }(); + + return reset_src_data_step; + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index after last iteration in RunWrite(), if it has not being reset by + // RunWrite() + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + // + constexpr auto reset_dst_data_step = [&]() { + Index reset_dst_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); + + return reset_dst_data_step_; + }(); + + return reset_dst_data_step; + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by RunWrite(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRun ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? 
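// Illustrative sketch (not part of the patch): MoveSrcSliceWindow()/MoveDstSliceWindow() fold the
// pending reset step into the requested window step when the coordinate was not reset at the end
// of RunRead()/RunWrite(), so a single move lands on the next window's origin:
#include <array>

template <int nDim>
std::array<int, nDim> adjusted_window_step(bool coord_was_reset,
                                           const std::array<int, nDim>& window_step,
                                           const std::array<int, nDim>& reset_step)
{
    std::array<int, nDim> step = window_step;
    if(!coord_was_reset)
        for(int i = 0; i < nDim; ++i)
            step[i] += reset_step[i]; // absorb the pending move back to the slice origin
    return step;
}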
+ const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + __device__ static constexpr auto GetSrcThreadScratchDescriptor() + { + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(src_access_lengths), Number{}); + + // 1st stage of transforms + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(src_access_lengths_and_vector_length[i], + src_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(src_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + __device__ static constexpr auto GetScaleThreadScratchDescriptor() + { + + constexpr auto scale_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto scale_access_lengths = SliceLengths{} / scale_scalar_per_access; + + constexpr auto scale_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(scale_access_lengths), Number{}); + + // 1st stage of transforms + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(scale_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(scale_access_lengths_and_vector_length[i], + scale_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(scale_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + __device__ static constexpr auto GetDstThreadScratchDescriptor() + { + // 1st stage of transforms + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(dst_access_lengths), Number{}); + + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return make_merge_transform_v3_division_mod( + 
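// Illustrative sketch (not part of the patch): the thread-scratch descriptors above append the
// per-access vector length as an extra innermost dimension of a packed descriptor and then merge
// it back into the vector dimension. In flat-offset terms, an access index plus a lane inside
// the loaded vector maps to
//   offset = (linearized access index) * ScalarPerVector + lane
#include <array>

template <int nDim>
int scratch_offset(const std::array<int, nDim>& access_idx,
                   const std::array<int, nDim>& access_lengths,
                   int lane,
                   int scalar_per_vector)
{
    int linear = 0;
    for(int i = 0; i < nDim; ++i)
        linear = linear * access_lengths[i] + access_idx[i]; // packed (row-major) access index
    return linear * scalar_per_vector + lane;                // lane is innermost
}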
make_tuple(dst_access_lengths_and_vector_length[i], + dst_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + private: + static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; + static constexpr auto scale_thread_scratch_desc_ = + decltype(GetScaleThreadScratchDescriptor()){}; + static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; + + /* + template + struct ScaleThreadScratchDesc{}; + */ + + // Registers, contain raw data loaded from global buffer + using SrcThreadScratch = StaticTensorTupleOfVectorBuffer; + + // Registers, contain fast converted data + using SrcThreadConvertedScratch = + StaticTensorTupleOfVectorBuffer; + + // Registers, contain scale data + using ScaleThreadScratch = StaticTensorTupleOfVectorBuffer; + + // Registers, contain dequantized data + using DstThreadScratch = StaticTensorTupleOfVectorBuffer; + + using FastTypeConverter = tensor_operation::element_wise:: + FastNumericArrayConverter; + + StaticallyIndexedArray src_thread_scratch_tuple_; + SrcThreadConvertedScratch src_converted_thread_scratch_; + ScaleThreadScratch scale_thread_scratch_; + + DstThreadScratch dst_thread_scratch_; + FastTypeConverter fast_numeric_converter; + + SrcCoord src_coord_; + ScaleCoord scale_coord_; + DstCoord dst_coord_; + const SrcElementwiseOperation src_element_op_; + const ScaleElementwiseOperation scale_element_op_; + const DstElementwiseOperation dst_element_op_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index 814b4167b8..70fbcec10f 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -89,6 +89,7 @@ struct wmma_type @@ -129,6 +130,7 @@ struct wmma_type @@ -153,7 +155,6 @@ struct wmma_type struct wmma_type + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { if constexpr(wave_size == 32) { - intrin_wmma_f16_16x16x16_f16_w32::Run(a, b, reg_c); + intrin_wmma_f16_16x16x16_f16_w32::Run(a, b, reg_c); } else if constexpr(wave_size == 64) { - intrin_wmma_f16_16x16x16_f16_w64::Run(a, b, reg_c); + intrin_wmma_f16_16x16x16_f16_w64::Run(a, b, reg_c); } } }; - template struct wmma_type::Run(a, b, reg_c); + intrin_wmma_bf16_16x16x16_bf16_w32::Run(a, b, reg_c); } else if constexpr(wave_size == 64) { - intrin_wmma_bf16_16x16x16_bf16_w64::Run(a, b, reg_c); + intrin_wmma_bf16_16x16x16_bf16_w64::Run(a, b, reg_c); } } }; -#endif - template struct wmma_type + bool TransposeC = false, + bool AssemblyBackend = false> struct WmmaGemm { static constexpr auto I0 = Number<0>{}; @@ -369,14 +366,14 @@ struct WmmaGemm static constexpr auto I5 = Number<5>{}; using CIndex = MultiIndex<2>; - using CIndex4D = MultiIndex<4>; + using CIndex3D = MultiIndex<3>; __host__ __device__ constexpr WmmaGemm() { static_assert(NPerWmma == 16 && MPerWmma == 16, "Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma"); - static_assert(KPack == wmma_instr.k_per_wmma, 
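// Illustrative sketch (not part of the patch): what one wmma_*_16x16x16_* instruction computes
// per wave — C(16x16) += A(16x16) * B(16x16) with fp32 accumulation — and how Run() covers KPack
// now that it only has to be a multiple of k_per_wmma: one instruction per k_per_wmma chunk,
// with the operands swapped when the transposed output layout is requested. Names are
// illustrative only.
#include <array>

using Tile16x16 = std::array<std::array<float, 16>, 16>;

void wmma_16x16x16_reference(const Tile16x16& a, const Tile16x16& b, Tile16x16& c)
{
    for(int m = 0; m < 16; ++m)
        for(int n = 0; n < 16; ++n)
            for(int k = 0; k < 16; ++k)
                c[m][n] += a[m][k] * b[k][n]; // C += A * B
}

void run_kpack(const Tile16x16* a, const Tile16x16* b, Tile16x16& c,
               int k_pack, int k_per_wmma, bool transpose_c)
{
    for(int k = 0; k < k_pack / k_per_wmma; ++k) // unroll over KPack chunks
    {
        if(!transpose_c)
            wmma_16x16x16_reference(a[k], b[k], c); // C  = A  * B
        else
            wmma_16x16x16_reference(b[k], a[k], c); // C' = B' * A'
    }
}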
"KPack should be k_per_wmma"); + static_assert(KPack % wmma_instr.k_per_wmma == 0, "KPack should be multiple of k_per_wmma"); } // WMMA output supporting C = A * B @@ -421,9 +418,49 @@ struct WmmaGemm Sequence<5>{})); } + // Transposed WMMA Output C' = B' * A' + template + __host__ __device__ static constexpr auto + MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs( + const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA& + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma) + { + const auto MBlockxRepeat = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0); + const auto NBlockxRepeat = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3); + const auto MWave = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1); + const auto NWave = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4); + + return transform_tensor_descriptor( + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma, + make_tuple( + make_pass_through_transform(MBlockxRepeat), + make_pass_through_transform(MWave), + make_pass_through_transform(Number{}), + make_pass_through_transform(NBlockxRepeat), + make_pass_through_transform(NWave), + make_unmerge_transform(make_tuple(Number{}, + Number{}))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6>{})); + } + __device__ static constexpr index_t GetRegSizePerWmma() { - return wmma_instr.num_acc_vgprs_per_wave; + return wmma_instr.num_acc_vgprs_per_wave * wmma_instr.acc_pack_number; } __device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; } @@ -449,14 +486,16 @@ struct WmmaGemm , "base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), " "(int8, int32) or (int4, int32)!"); - if constexpr(!TransposeC) - { - wmma_instr.template run(p_a_wave, p_b_wave, p_c_thread); - } - else - { - wmma_instr.template run(p_b_wave, p_a_wave, p_c_thread); - } + static_for<0, KPack / wmma_instr.k_per_wmma, 1>{}([&](auto k) { + if constexpr(!TransposeC) + { + wmma_instr.template run(p_a_wave[k], p_b_wave[k], p_c_thread); + } + else + { + wmma_instr.template run(p_b_wave[k], p_a_wave[k], p_c_thread); + } + }); } __device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; } @@ -477,12 +516,12 @@ struct WmmaGemm __host__ __device__ static auto CalculateAThreadOriginDataIndex() { - return GetSwizzledLaneIdLow(); + return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow(); } __host__ __device__ static auto CalculateBThreadOriginDataIndex() { - return GetLaneIdUnderSubGroup(); + return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup(); } __device__ static CIndex GetBeginOfThreadBlk() @@ -493,6 +532,14 @@ struct WmmaGemm return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset}; } + __device__ static CIndex3D GetBeginOfThreadBlk3D() + { + index_t n_offset = GetLaneIdUnderSubGroup(); + index_t m_offset = GetSubGroupId(); + + return TransposeC ? 
CIndex3D{n_offset, m_offset, I0} : CIndex3D{m_offset, n_offset, I0}; + } + static constexpr auto wmma = WmmaSelector{}; static constexpr auto wmma_instr = wmma.selected_wmma; @@ -500,7 +547,10 @@ struct WmmaGemm __host__ __device__ static constexpr auto GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths() { - return make_tuple(I1, I1, Number{}); + return make_tuple(I1, + I1, + Number{}, + Number{}); } }; diff --git a/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp b/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp new file mode 100644 index 0000000000..56181d38c8 --- /dev/null +++ b/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp @@ -0,0 +1,391 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" + +namespace ck { +namespace tensor_operation { + +// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] +template +__host__ __device__ static auto +MakeGridDescriptorPair(const std::array& gs_ms_ns_lengths_vec, + const std::array& gs_ms_ns_strides_vec) +{ + // if(!(gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN && + // gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN)) + // { + // throw std::runtime_error("wrong! dimension must match input lengths"); + // } + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto gs_ms_ns_lengths = + to_tuple(gs_ms_ns_lengths_vec, Number<0>{}, Number{}); + const auto gs_ms_ns_strides = + to_tuple(gs_ms_ns_strides_vec, Number<0>{}, Number{}); + + // dimension Ids for G0, G1, ... + constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{}; + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = + typename arithmetic_sequence_gen::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for G0, G1, ... + const auto gLengths = get_container_subset(gs_ms_ns_lengths, gDimIds); + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(gs_ms_ns_lengths, mDimIds); + + // lengths for N0, N1, ... + const auto nLengths = get_container_subset(gs_ms_ns_lengths, nDimIds); + + if constexpr(TensorSpec == device::TensorSpecialization::Packed) + { + auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{}); + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + const auto grid_desc_g_mraw_nraw = make_naive_tensor_descriptor( + make_tuple(G, M, N), + make_tuple(gs_ms_ns_strides[Number{}], + gs_ms_ns_strides[Number{}], + gs_ms_ns_strides[Number{}])); + + const auto grid_desc_mraw_nraw = make_naive_tensor_descriptor( + make_tuple(M, N), + make_tuple(gs_ms_ns_strides[Number{}], + gs_ms_ns_strides[Number{}])); + + return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw); + } + else + { + // naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] 
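// Illustrative sketch (not part of the patch): MakeGridDescriptorPair collapses the G0, G1, ... /
// M0, M1, ... / N0, N1, ... dimension groups of the contraction into single merged G, M, N
// extents by multiplying the lengths of each group; the Packed specialization then only needs
// the stride of the innermost dimension of each group to address the merged dimension:
#include <functional>
#include <numeric>
#include <vector>

struct MergedGemmDims { long G, M, N; };

MergedGemmDims merge_contraction_dims(const std::vector<long>& g_lengths,
                                      const std::vector<long>& m_lengths,
                                      const std::vector<long>& n_lengths)
{
    auto prod = [](const std::vector<long>& v) {
        return std::accumulate(v.begin(), v.end(), 1L, std::multiplies<long>{});
    };
    return {prod(g_lengths), prod(m_lengths), prod(n_lengths)}; // G, M, N of the batched GEMM
}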
+ const auto grid_desc_gs_ms_ns = + make_naive_tensor_descriptor(gs_ms_ns_lengths, gs_ms_ns_strides); + + // transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * + // N2 * ...] + // Note: This does not require padding as it only provides G offset calculation. Technically + // descriptor for only G is needed. Here we opt for backward compatibility purpose to return + // G_M_N + const auto grid_desc_g_mraw_nraw = + transform_tensor_descriptor(grid_desc_gs_ms_ns, + make_tuple(make_merge_transform(gLengths), + make_merge_transform(mLengths), + make_merge_transform(nLengths)), + make_tuple(gDimIds, mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto c_ms_ns_lengths = to_tuple( + gs_ms_ns_lengths_vec, Number{}, Number{}); + const auto c_ms_ns_strides = to_tuple( + gs_ms_ns_strides_vec, Number{}, Number{}); + + // transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * + // N2 * ...] + const auto grid_desc_ms_ns = make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides); + + const auto grid_desc_mraw_nraw = transform_tensor_descriptor( + grid_desc_ms_ns, + make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), + make_tuple(mDimIds - Number{}, nDimIds - Number{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw); + } +} + +template + typename PerBlock_M_N_K_O, // Sequence<> + device::GemmSpecialization GemmSpec, + device::TensorSpecialization ASpec, + device::TensorSpecialization B0Spec, + device::TensorSpecialization B1Spec, + device::TensorSpecialization CSpec> +struct TransformBatchedContractionContractionToBatchedGemmGemm_Wmma +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + + static constexpr index_t NumDimG = NumDims_G_M_N_K_O::At(I0); + static constexpr index_t NumDimM = NumDims_G_M_N_K_O::At(I1); + static constexpr index_t NumDimN = NumDims_G_M_N_K_O::At(I2); + static constexpr index_t NumDimK = NumDims_G_M_N_K_O::At(I3); + static constexpr index_t NumDimO = NumDims_G_M_N_K_O::At(I4); + + static constexpr index_t MPerBlock = PerBlock_M_N_K_O::At(I0); + static constexpr index_t NPerBlock = PerBlock_M_N_K_O::At(I1); + static constexpr index_t KPerBlock = PerBlock_M_N_K_O::At(I2); + static constexpr index_t OPerBlock = PerBlock_M_N_K_O::At(I3); + + static constexpr auto matrix_padder = + device::GemmGemmPadder{ + MPerBlock, NPerBlock, KPerBlock, OPerBlock}; + + // + // A + // + __host__ __device__ static auto MakeAGridDescriptorPair( + const std::array& a_gs_ms_ks_lengths_vec, + const std::array& a_gs_ms_ks_strides_vec) + { + return MakeGridDescriptorPair(a_gs_ms_ks_lengths_vec, + a_gs_ms_ks_strides_vec); + } + + // TODO: rename to G_MRaw_KRaw + __host__ __device__ static auto MakeAGridDescriptor_G_M_K( + const std::array& a_gs_ms_ks_lengths_vec, + const std::array& a_gs_ms_ks_strides_vec) + { + return MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).first; + } + __host__ __device__ static auto MakeAGridDescriptor_M_K( + const std::array& a_gs_ms_ks_lengths_vec, + const std::array& a_gs_ms_ks_strides_vec) + { + return matrix_padder.PadADescriptor_M_K( + MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).second); + } + + template + __host__ __device__ static constexpr auto + MakeAGridDescriptor_AK0_M_AK1(const 
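// Illustrative sketch (not part of the patch): the AK0/M/AK1 descriptor un-merges K into
// (AK0, AK1) with AK1 innermost, so a flat k index recovers as k = ak0 * AK1 + ak1; the
// WMMA-specific variant above splits K further into (AKWmma, AK0PerWmma, AKRow, AK1) with
// AKRow = 2. Index math for the basic split:
struct KSplit { int ak0, ak1; };

KSplit split_k(int k, int AK1)
{
    return {k / AK1, k % AK1}; // AK0 index and the position inside the contiguous K1 vector
}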
AGridDesc_M_K& a_grid_desc_m_k, const Number& AK1) + { + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + __host__ __device__ static constexpr auto + MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1( + const AGridDesc_M_K& a_grid_desc_m_k, + const WmmaK&, + const MRepeat&, + const MWaves&, + const MPerWmma&, + const AK1&) + { + const auto M0 = a_grid_desc_m_k.GetLength(I0) / MPerBlock; + const auto K = a_grid_desc_m_k.GetLength(I1); + const auto AKWmma = K / WmmaK{}; + constexpr auto AKRow = 2; + constexpr auto AK0PerWmma = WmmaK{} / AKRow / AK1{}; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform( + make_tuple(AKWmma, Number{}, Number{}, AK1{})), + make_unmerge_transform(make_tuple(M0 * MRepeat{}, MWaves{}, MPerWmma{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + + // + // B (alias of B0) + // + __host__ __device__ static auto MakeB0GridDescriptorPair( + const std::array& b0_gs_ns_ks_lengths_vec, + const std::array& b0_gs_ns_ks_strides_vec) + { + return MakeGridDescriptorPair(b0_gs_ns_ks_lengths_vec, + b0_gs_ns_ks_strides_vec); + } + + // TODO: rename to G_MRaw_NRaw + __host__ __device__ static auto MakeB0GridDescriptor_G_N_K( + const std::array& b0_gs_ns_ks_lengths_vec, + const std::array& b0_gs_ns_ks_strides_vec) + { + return MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).first; + } + __host__ __device__ static auto MakeB0GridDescriptor_N_K( + const std::array& b0_gs_ns_ks_lengths_vec, + const std::array& b0_gs_ns_ks_strides_vec) + { + // alias of matrix_padder.PadB0Descriptor_N_K + return matrix_padder.PadBDescriptor_N_K( + MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).second); + } + + template + __host__ __device__ static constexpr auto + MakeB0GridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k, const Number& BK1) + { + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + __host__ __device__ static constexpr auto + MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1( + const BGridDesc_L_K& b_grid_desc_l_k, + const WmmaK&, + const LRepeat&, + const LWaves&, + const LPerWmma&, + const BK1&) + { + const auto L0 = b_grid_desc_l_k.GetLength(I0) / NPerBlock; + const auto K = b_grid_desc_l_k.GetLength(I1); + const auto BKWmma = K / WmmaK{}; + constexpr auto BKRow = 2; + constexpr auto BK0PerWmma = WmmaK{} / BKRow / BK1{}; + + return transform_tensor_descriptor( + b_grid_desc_l_k, + make_tuple(make_unmerge_transform( + make_tuple(BKWmma, Number{}, Number{}, BK1{})), + make_unmerge_transform(make_tuple(L0 * LRepeat{}, LWaves{}, LPerWmma{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + + // + // B1 + // + __host__ __device__ static auto 
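// Illustrative sketch (not part of the patch): the GemmGemmPadder used by the descriptor makers
// rounds each raw M/N/K/O extent up to the next multiple of its per-block tile size, so an
// arbitrary problem size maps onto a whole number of block tiles (the padded tail is masked out
// by the validity checks on loads and stores):
int pad_to_block(int raw, int per_block)
{
    return ((raw + per_block - 1) / per_block) * per_block; // round up to a block multiple
}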
MakeB1GridDescriptorPair( + const std::array& b1_gs_os_ns_lengths_vec, + const std::array& b1_gs_os_ns_strides_vec) + { + return MakeGridDescriptorPair(b1_gs_os_ns_lengths_vec, + b1_gs_os_ns_strides_vec); + } + + // TODO: rename to G_NRaw_KRaw + __host__ __device__ static auto MakeB1GridDescriptor_G_N_K( + const std::array& b1_gs_os_ns_lengths_vec, + const std::array& b1_gs_os_ns_strides_vec) + { + return MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).first; + } + __host__ __device__ static auto MakeB1GridDescriptor_N_K( + const std::array& b1_gs_os_ns_lengths_vec, + const std::array& b1_gs_os_ns_strides_vec) + { + // alias of matrix_padder.PadB1Descriptor_O_N + return matrix_padder.PadB1Descriptor_N_K( + MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).second); + } + + template + __host__ __device__ static constexpr auto + MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K& b1_grid_desc_n_k, const Number& B1K1) + { + const auto N = b1_grid_desc_n_k.GetLength(I0); + const auto K = b1_grid_desc_n_k.GetLength(I1); + + const auto B1K0 = K / B1K1; + + return transform_tensor_descriptor( + b1_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + __host__ __device__ static constexpr auto + MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1( + const BGridDesc_N_L& b_grid_desc_n_l, + const WmmaL&, + const NRepeat&, + const NWaves&, + const NPerWmma&, + const BL1&) + { + const auto N0 = b_grid_desc_n_l.GetLength(I0) / OPerBlock; + const auto L = b_grid_desc_n_l.GetLength(I1); + const auto BLWmma = L / WmmaL{}; + constexpr auto BLRow = 2; + constexpr auto BL0PerWmma = WmmaL{} / BLRow / BL1{}; + + return transform_tensor_descriptor( + b_grid_desc_n_l, + make_tuple(make_unmerge_transform( + make_tuple(BLWmma, Number{}, Number{}, BL1{})), + make_unmerge_transform(make_tuple(N0 * NRepeat{}, NWaves{}, NPerWmma{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + + // + // C + // + __host__ __device__ static auto MakeCGridDescriptorPair( + const std::array& c_gs_ms_os_lengths_vec, + const std::array& c_gs_ms_os_strides_vec) + { + return MakeGridDescriptorPair(c_gs_ms_os_lengths_vec, + c_gs_ms_os_strides_vec); + } + + // TODO: rename to G_MRaw_NRaw + __host__ __device__ static auto MakeCGridDescriptor_G_M_N( + const std::array& c_gs_ms_os_lengths_vec, + const std::array& c_gs_ms_os_strides_vec) + { + return MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).first; + } + __host__ __device__ static auto MakeCGridDescriptor_M_N( + const std::array& c_gs_ms_os_lengths_vec, + const std::array& c_gs_ms_os_strides_vec) + { + return matrix_padder.PadCDescriptor_M_N( + MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).second); + } +}; + +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 2ea5419d09..678c55b95f 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -417,7 +417,8 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || 
(is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! not implemented"); using r_t = typename vector_type::type; diff --git a/include/ck/utility/amd_inline_asm.hpp b/include/ck/utility/amd_inline_asm.hpp index 43baa817d3..5dc67a5ade 100644 --- a/include/ck/utility/amd_inline_asm.hpp +++ b/include/ck/utility/amd_inline_asm.hpp @@ -220,8 +220,8 @@ amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0 "0"(c0), "1"(c1)); #else - c0 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b0), c0, false); - c1 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b1), c1, false); + c0 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b0), c0, false); + c1 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b1), c1, false); #endif } @@ -257,10 +257,10 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a, "2"(c2), "3"(c3)); #else - c0 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b0), c0, false); - c1 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b1), c1, false); - c2 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b2), c2, false); - c3 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b3), c3, false); + c0 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b0), c0, false); + c1 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b1), c1, false); + c2 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b2), c2, false); + c3 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b3), c3, false); #endif } @@ -355,17 +355,5 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a, c3); } -// Ranged input operand -__device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c) -{ -#if defined(__gfx11__) - asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(c) : "v"(a), "v"(b), "0"(c)); -#else - ignore = a; - ignore = b; - ignore = c; -#endif -} - } // namespace ck #endif diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 31ae71880a..4d6791b5a7 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -133,6 +133,13 @@ struct scalar_type static constexpr index_t vector_size = 1; }; +template <> +struct scalar_type +{ + using type = uint8_t; + static constexpr index_t vector_size = 1; +}; + #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 template <> struct scalar_type @@ -1037,6 +1044,14 @@ using bf8x8_t = typename vector_type::type; using bf8x16_t = typename vector_type::type; using bf8x32_t = typename vector_type::type; using bf8x64_t = typename vector_type::type; +// u8 +// i8 +using uint8x2_t = typename vector_type::type; +using uint8x4_t = typename vector_type::type; +using uint8x8_t = typename vector_type::type; +using uint8x16_t = typename vector_type::type; +using uint8x32_t = typename vector_type::type; +using uint8x64_t = typename vector_type::type; template struct NumericLimits diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index dbac1f0c85..be74b1fdc1 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -99,6 +99,63 @@ inline __host__ __device__ constexpr bhalf_t type_convert(int8_ return type_convert(x_fp32); } +// Convert X to Y +template +__host__ __device__ constexpr Y type_convert_sp(X x) +{ + static_assert(!std::is_reference_v && !std::is_reference_v); + + return 
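// Illustrative sketch (not part of the patch): what the __builtin_amdgcn_sdot4 calls above
// compute — a dot product of two packed int8x4 operands accumulated into a 32-bit integer:
#include <cstdint>

inline std::int32_t sdot4_reference(const std::int8_t a[4], const std::int8_t b[4], std::int32_t c)
{
    for(int i = 0; i < 4; ++i)
        c += static_cast<std::int32_t>(a[i]) * static_cast<std::int32_t>(b[i]);
    return c; // c += a . b
}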
static_cast(x); +} + +template <> +inline __host__ __device__ constexpr int type_convert_sp(float x) +{ + union + { + float fp32; + int int32; + } u = {x}; + + return u.int32; +} + +template <> +inline __host__ __device__ constexpr float type_convert_sp(int x) +{ + union + { + int int32; + float fp32; + } u = {x}; + + return u.fp32; +} + +template <> +inline __host__ __device__ constexpr int type_convert_sp(half_t x) +{ + union + { + half_t fp16; + int int32; + } u = {x}; + + return u.int32; +} + +template <> +inline __host__ __device__ constexpr half_t type_convert_sp(int x) +{ + union + { + int int32; + half_t fp16; + } u = {x}; + + return u.fp16; +} + // Declare a template function for fp8 conversion using SR template __host__ __device__ constexpr Y f8_convert_sr(X x); diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp index a1b1e0d91b..7a8e1d9a37 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp @@ -133,6 +133,252 @@ struct ReferenceBatchedGemm : public device::BaseOperator } }; +template +struct ReferenceBatchedGemm_MQA : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_g0_g1_m_k, + const Tensor& b_g0_1_k_n, + Tensor& c_g0_g1_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : a_g0_g1_m_k_{a_g0_g1_m_k}, + b_g0_1_k_n_{b_g0_1_k_n}, + c_g0_g1_m_n_{c_g0_g1_m_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& a_g0_g1_m_k_; + const Tensor& b_g0_1_k_n_; + Tensor& c_g0_g1_m_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceBatchedGemm_MQA::Argument; + + float Run(const Argument& arg) + { + auto f_g0g1mk_g01kn_g0g1mn = [&](auto g0, auto g1, auto m, auto n) { + const int K = arg.a_g0_g1_m_k_.mDesc.GetLengths()[3]; + + AccDataType v_acc = 0; + + for(int k = 0; k < K; ++k) + { + ADataType v_a; + BDataType v_b; + + arg.a_element_op_(v_a, arg.a_g0_g1_m_k_(g0, g1, m, k)); + arg.b_element_op_(v_b, arg.b_g0_1_k_n_(g0, 0, k, n)); + + v_acc += + ck::type_convert(v_a) * ck::type_convert(v_b); + } + + AccDataType v_c; + + arg.c_element_op_(v_c, v_acc); + + arg.c_g0_g1_m_n_(g0, g1, m, n) = ck::type_convert(v_c); + }; + + make_ParallelTensorFunctor(f_g0g1mk_g01kn_g0g1mn, + arg.c_g0_g1_m_n_.mDesc.GetLengths()[0], + arg.c_g0_g1_m_n_.mDesc.GetLengths()[1], + arg.c_g0_g1_m_n_.mDesc.GetLengths()[2], + arg.c_g0_g1_m_n_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_g0_g1_m_k, + const Tensor& b_g0_1_k_n, + Tensor& c_g0_g1_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation 
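// Illustrative sketch (not part of the patch): the type_convert_sp specializations reinterpret
// the bit pattern rather than converting the value (the patch uses a union, the usual idiom in
// HIP device code). A portable standard-C++ equivalent of the float <-> int pair via memcpy:
#include <cstdint>
#include <cstring>

inline std::int32_t bit_cast_to_int(float x)
{
    std::int32_t i;
    std::memcpy(&i, &x, sizeof(i)); // copy the raw bits, no numeric conversion
    return i;
}

inline float bit_cast_to_float(std::int32_t i)
{
    float x;
    std::memcpy(&x, &i, sizeof(x));
    return x;
}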
c_element_op) + { + return Argument{ + a_g0_g1_m_k, b_g0_1_k_n, c_g0_g1_m_n, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceBatchedGemm_MQA" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +template +struct ReferenceBatchedGemm_GQA : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_g0_g1_m_k, + const Tensor& b_g0_gq_k_n, + Tensor& c_g0_g1_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : a_g0_g1_m_k_{a_g0_g1_m_k}, + b_g0_gq_k_n_{b_g0_gq_k_n}, + c_g0_g1_m_n_{c_g0_g1_m_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& a_g0_g1_m_k_; + const Tensor& b_g0_gq_k_n_; + Tensor& c_g0_g1_m_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceBatchedGemm_GQA::Argument; + + float Run(const Argument& arg) + { + auto f_g0g1mk_g0gqkn_g0g1mn = [&](auto g0, auto g1, auto m, auto n) { + const int G1 = arg.a_g0_g1_m_k_.mDesc.GetLengths()[1]; + const int K = arg.a_g0_g1_m_k_.mDesc.GetLengths()[3]; + + AccDataType v_acc = 0; + + for(int k = 0; k < K; ++k) + { + ADataType v_a; + BDataType v_b; + + arg.a_element_op_(v_a, arg.a_g0_g1_m_k_(g0, g1, m, k)); + arg.b_element_op_(v_b, arg.b_g0_gq_k_n_(g0, g1 * QueryGroupNumber / G1, k, n)); + + v_acc += + ck::type_convert(v_a) * ck::type_convert(v_b); + } + + AccDataType v_c; + + arg.c_element_op_(v_c, v_acc); + + arg.c_g0_g1_m_n_(g0, g1, m, n) = ck::type_convert(v_c); + }; + + make_ParallelTensorFunctor(f_g0g1mk_g0gqkn_g0g1mn, + arg.c_g0_g1_m_n_.mDesc.GetLengths()[0], + arg.c_g0_g1_m_n_.mDesc.GetLengths()[1], + arg.c_g0_g1_m_n_.mDesc.GetLengths()[2], + arg.c_g0_g1_m_n_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_g0_g1_m_k, + const Tensor& b_g0_gq_k_n, + Tensor& c_g0_g1_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{ + a_g0_g1_m_k, b_g0_gq_k_n, c_g0_g1_m_n, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceBatchedGemm_GQA" + << std::endl; + // clang-format on + + return str.str(); + } +}; + } // namespace host } // namespace tensor_operation } // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_fpAintB_gemm.hpp 
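// Illustrative sketch (not part of the patch): grouped-query attention maps several query heads
// onto one shared key/value head; the GQA reference above indexes B with
// g1 * QueryGroupNumber / G1, i.e. consecutive blocks of G1 / QueryGroupNumber query heads read
// the same KV head (multi-query attention is the special case of a single KV head):
#include <cstdio>

int kv_head_for_query_head(int g1, int num_query_heads, int num_kv_heads)
{
    return g1 * num_kv_heads / num_query_heads;
}

int main()
{
    // e.g. 8 query heads sharing 2 KV heads -> 0,0,0,0,1,1,1,1
    for(int g1 = 0; g1 < 8; ++g1)
        std::printf("query head %d -> kv head %d\n", g1, kv_head_for_query_head(g1, 8, 2));
}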
b/library/include/ck/library/reference_tensor_operation/cpu/reference_fpAintB_gemm.hpp new file mode 100644 index 0000000000..ac392f0906 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_fpAintB_gemm.hpp @@ -0,0 +1,177 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferencefpAintBGemm : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_m_k, + const Tensor& b_k_n, + const Tensor& scale_k_n, + Tensor& c_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : a_m_k_{a_m_k}, + b_k_n_{b_k_n}, + scale_k_n_{scale_k_n}, + c_m_n_{c_m_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& a_m_k_; + const Tensor& b_k_n_; + const Tensor& scale_k_n_; + Tensor& c_m_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferencefpAintBGemm::Argument; + + float Run(const Argument& arg) + { + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = arg.a_m_k_.mDesc.GetLengths()[1]; + + AccDataType v_acc = 0; + + for(int k = 0; k < K; ++k) + { + ADataType v_a; + BDataType v_b; + ScaleDataType v_scale; + ADataType v_converted_b; + + // use PassThrough instead of ConvertBF16RTN for reference calculation + if constexpr(is_same_v) + { + ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k)); + } + else + { + arg.a_element_op_(v_a, arg.a_m_k_(m, k)); + } + + // same for B matrix + if constexpr(is_same_v) + { + ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n)); + } + else + { + arg.b_element_op_(v_b, arg.b_k_n_(k, n)); + } + + // same for scale matrix + if constexpr(is_same_v) + { + ck::tensor_operation::element_wise::PassThrough{}(v_scale, + arg.scale_k_n_(k, n)); + } + else + { + arg.b_element_op_(v_scale, arg.scale_k_n_(k, n)); + } + + v_converted_b = type_convert(v_b) * v_scale; + v_acc += ck::type_convert(v_a) * + ck::type_convert(v_converted_b); + } + + AccDataType v_c; + + arg.c_element_op_(v_c, v_acc); + + arg.c_m_n_(m, n) = ck::type_convert(v_c); + }; + + make_ParallelTensorFunctor( + f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_m_k, + const Tensor& b_k_n, + const Tensor& scale_k_n, + Tensor& c_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{a_m_k, b_k_n, scale_k_n, c_m_n, a_element_op, b_element_op, 
c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceGemm" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp index 31e5b72ea1..ee9d977096 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp @@ -384,6 +384,26 @@ void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances( instances); #endif +void add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + template && is_same_v && is_same_v) @@ -493,6 +514,7 @@ struct DeviceOperationInstanceFactory< add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances( op_ptrs); + add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) @@ -505,6 +527,7 @@ struct DeviceOperationInstanceFactory< add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs); #endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs); + add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) @@ -517,6 +540,7 @@ struct DeviceOperationInstanceFactory< add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs); #endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs); + add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances(op_ptrs); } } #endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp index f925397832..4ea23ea1f9 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp @@ -54,36 +54,36 @@ template using device_grouped_conv_fwd_wmma_f16_instances = std::tuple< // clang-format off - //########################################| NumDim| A| B| Ds| E| AData| BData| Ds| EData| AccData| CShuffle| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| DataType| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| 
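// Illustrative sketch (not part of the patch): the instance-factory pattern the new
// add_device_gemm_wmma_f16_* hooks plug into — each add_* function appends concrete device-op
// instances (one per tuning configuration) to a shared vector, and the factory returns every
// instance matching the requested data types and layouts. Names below are hypothetical
// stand-ins, not the library's actual types:
#include <memory>
#include <vector>

struct DeviceOp { virtual ~DeviceOp() = default; };
struct WmmaGemmF16 : DeviceOp {};

using OpPtrs = std::vector<std::unique_ptr<DeviceOp>>;

void add_device_gemm_wmma_f16_instances(OpPtrs& op_ptrs) // hypothetical stand-in
{
    op_ptrs.push_back(std::make_unique<WmmaGemmF16>()); // one entry per tuning config
}

OpPtrs get_instances()
{
    OpPtrs op_ptrs;
    add_device_gemm_wmma_f16_instances(op_ptrs); // registered alongside the existing XDL/DPP adders
    return op_ptrs;
}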
Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Prefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, // blocksize=256 - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 
0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, // blocksize=128 - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, // blocksize=64 - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, 
S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, // blocksize=32 - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8> + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8> // clang-format on >; @@ -97,36 +97,36 @@ template using device_grouped_conv_fwd_wmma_i8_instances = std::tuple< // clang-format off - //########################################| NumDim| A| B| Ds| E| AData| BData| Ds| EData| AccData| CShuffle| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| DataType| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Prefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| 
NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //generic instance - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 16, 1, 1, 1, S<1, 32, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 16, 1, 1, 1, S<1, 32, 1, 4>, 1>, // blocksize=256 - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, // blocksize=128 - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, 
- DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, // blocksize=64 - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, // blocksize=32 - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 
2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8> + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt index 3d243e3d56..e9cc1e854f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt @@ -111,6 +111,12 @@ list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_kn_mn_instance.cpp device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp) +list(APPEND GEMM_INSTANCES + device_gemm_wmma_f16_f16_f16_mk_kn_mn_instance.cpp + device_gemm_wmma_f16_f16_f16_mk_nk_mn_instance.cpp + device_gemm_wmma_f16_f16_f16_km_kn_mn_instance.cpp + device_gemm_wmma_f16_f16_f16_km_nk_mn_instance.cpp) + add_instance_library(device_gemm_instance ${GEMM_INSTANCES}) set(ENABLE_PIPELINE_V2_OPT) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..f3665eb8d8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
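The four WMMA instance translation units registered in the CMake list above are reached through the existing DeviceGemm instance factory extended in gemm.hpp earlier in this patch. Below is a minimal sketch of that client-side flow, assuming the standard CK factory/client-example pattern; the device pointers, problem sizes and strides (p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC) are placeholders, not part of the patch.

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm.hpp"

void run_first_supported_wmma_gemm(const void* p_a, const void* p_b, void* p_c,
                                   ck::index_t M, ck::index_t N, ck::index_t K,
                                   ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC)
{
    using Row         = ck::tensor_layout::gemm::RowMajor;
    using Col         = ck::tensor_layout::gemm::ColumnMajor;
    using F16         = ck::half_t;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    // DeviceGemm interface matching the mk_nk_mn (row-major A x column-major B) f16 instances
    using DeviceOp = ck::tensor_operation::device::DeviceGemm<Row, Col, Row, F16, F16, F16,
                                                              PassThrough, PassThrough, PassThrough>;

    // After this patch the factory also hands back the DeviceGemmWmma_CShuffle instances
    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    for(auto& op_ptr : op_ptrs)
    {
        auto argument_ptr = op_ptr->MakeArgumentPointer(p_a, p_b, p_c, M, N, K,
                                                        StrideA, StrideB, StrideC,
                                                        PassThrough{}, PassThrough{}, PassThrough{});

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            auto invoker_ptr = op_ptr->MakeInvokerPointer();
            invoker_ptr->Run(argument_ptr.get(), StreamConfig{});
            return; // run the first instance that supports this problem shape
        }
    }
}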
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_wmma_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off + //######################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumPrefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| M| N| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| + //######################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise|Specialization| | Size| Block| Block| Block| | WMMA| WMMA| Repeat| Repeat| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| MRepeat| ClusterLengths| ScalarPerVector| + //######################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerStore| PerStore| MBlock_MPerBlock| | + //######################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + /* Prefetch 2, consume enormous vgpr resource*/ + // 8 Waves + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 256, 128, 128, 32, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + // 4 Waves + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 128, 128, 64, 64, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + // 2 Waves + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 64, 64, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + // 1 Wave + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, + /* Prefetch 1, prefer 
larger KPerBlock value for better latency hiding*/ + // 8 Waves + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 160, 64, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 8>, + // 4 Waves + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 256, 64, 64, 8, 16, 16, 8, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 256, 64, 8, 16, 16, 2, 8, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 80, 64, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 64, 1, 2>, 8>, + // 2 Waves + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 64, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + // 1 Wave + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 32, 64, 8, 16, 16, 1, 2, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, 
S<1, 16, 1, 2>, 8> + // clang-format on + >; + +void add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_wmma_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..6726727e67 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_wmma_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + //######################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumPrefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| M| N| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| + //######################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise|Specialization| | Size| Block| Block| Block| | WMMA| WMMA| Repeat| Repeat| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| MRepeat| ClusterLengths| ScalarPerVector| + //######################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerStore| PerStore| MBlock_MPerBlock| | + //######################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + /* Prefetch 2, consume enormous vgpr resource*/ + // 8 Waves + DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 256, 128, 128, 32, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + // 4 Waves + DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, 
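As a reading aid for the instance tables in these files, one entry can be written out with the parameter names from the column header spelled out. The annotation below is illustrative only; the names restate the header comment that precedes each table, and the long block-transfer descriptor lists are abbreviated with "...".

// One "4 Waves, Prefetch 2" entry from the km_nk_mn table below, annotated column by column:
//
// DeviceGemmWmma_CShuffle<
//     Col, Col, Row,                          // ALayout, BLayout, CLayout: a[k, m] * b[n, k] = c[m, n]
//     F16, F16, F16, F32, F16,                // AData, BData, CData, AccData, CShuffle data types
//     PassThrough, PassThrough, PassThrough,  // A, B, C element-wise operations
//     GemmMNKPadding,                         // GEMM specialization
//     2,                                      // NumPrefetch
//     128, 128, 64, 64, 8,                    // BlockSize, MPerBlock, NPerBlock, KPerBlock, K1
//     16, 16,                                 // MPerWMMA, NPerWMMA
//     4, 2,                                   // MRepeat, NRepeat
//     S<4, 32, 1>, ..., true,                 // A block-transfer cluster lengths, access orders, vector widths, AddExtraM
//     S<4, 32, 1>, ..., true,                 // B block-transfer cluster lengths, access orders, vector widths, AddExtraN
//     1, 1, S<1, 32, 1, 4>, 8>                // CShuffle repeats per store, cluster lengths, scalar per vector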
PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 128, 128, 64, 64, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + // 2 Waves + DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 64, 64, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + // 1 Wave + DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, + /* Prefetch 1, prefer larger KPerBlock value for better latency hiding*/ + // 8 Waves + DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 160, 64, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 8>, + // 4 Waves + DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 256, 64, 64, 8, 16, 16, 8, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 256, 64, 8, 16, 16, 2, 8, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 80, 64, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 64, 1, 2>, 8>, + // 2 Waves + DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 64, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + 
DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + // 1 Wave + DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 32, 64, 8, 16, 16, 1, 2, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8> + // clang-format on + >; + +void add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_wmma_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..d526f17b56 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,158 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //######################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumPrefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| M| N| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| + //######################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise|Specialization| | Size| Block| Block| Block| | WMMA| WMMA| Repeat| Repeat| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| MRepeat| 
ClusterLengths| ScalarPerVector| + //######################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerStore| PerStore| MBlock_MPerBlock| | + //######################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + /* Prefetch 2, consume enormous vgpr resource*/ + // 8 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 256, 128, 128, 32, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + // 4 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 128, 128, 64, 64, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + // 2 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 64, 64, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + // 1 Wave + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, + /* Prefetch 1, prefer larger KPerBlock value for better latency hiding*/ + // 8 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 160, 64, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 8>, + // 4 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 256, 64, 64, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 256, 64, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, 
Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 80, 64, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 64, 1, 2>, 8>, + // 2 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 64, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + // 1 Wave + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 32, 64, 8, 16, 16, 1, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 2>, 8> +#if 0 + /* Prefetch 2, consume enormous vgpr resource*/ + // 8 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 256, 128, 128, 32, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + // 4 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 128, 128, 64, 64, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + // 2 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 64, 64, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + // 1 Wave + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, + /* Prefetch 1, prefer larger KPerBlock value for better latency hiding*/ + // 8 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, 
S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 160, 64, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 64, 1, 4>, 8>, + // 4 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 256, 64, 64, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 256, 64, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 80, 64, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 64, 1, 2>, 8>, + // 2 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 64, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + // 1 Wave + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 32, 64, 8, 16, 16, 1, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, + /* Prefetch 2, consume enormous vgpr resource*/ + // 8 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 256, 128, 128, 32, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + // 4 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 128, 128, 64, 64, 8, 16, 16, 
4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + // 2 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 64, 64, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + // 1 Wave + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, + /* Prefetch 1, prefer larger KPerBlock value for better latency hiding*/ + // 8 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 160, 64, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 64, 1, 4>, 8>, + // 4 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 256, 64, 64, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 256, 64, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 80, 64, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 64, 1, 2>, 8>, + // 2 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 64, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, 
PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + // 1 Wave + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 32, 64, 8, 16, 16, 1, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, + /* Prefetch 2, consume enormous vgpr resource*/ + // 8 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 256, 128, 128, 32, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + // 4 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 128, 128, 64, 64, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + // 2 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 64, 64, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + // 1 Wave + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, + /* Prefetch 1, prefer larger KPerBlock value for better latency hiding*/ + // 8 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 160, 64, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 8>, + // 4 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 256, 64, 64, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 
1, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 256, 64, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 80, 64, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 64, 1, 2>, 8>, + // 2 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 64, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + // 1 Wave + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 32, 64, 8, 16, 16, 1, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8> +#endif + // clang-format on + >; + +void add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..eed856b6ca --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
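For host-side verification, the reference operator added earlier in this patch (reference_fpAintB_gemm.hpp) follows the usual CK reference pattern of MakeArgument/MakeInvoker over host Tensor objects. A minimal sketch under stated assumptions: the template parameter order and the int8 weight type shown here are assumed for illustration, and the a_m_k, b_k_n, scale_k_n, c_m_n_host tensors are placeholders; the authoritative parameter list is the one in that header.

#include "ck/library/reference_tensor_operation/cpu/reference_fpAintB_gemm.hpp"
#include "ck/library/utility/check_err.hpp"

// Assumed template parameter order: ADataType, BDataType, ScaleDataType,
// AccDataType, CDataType, then the A/B/C element-wise operations.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReferenceOp = ck::tensor_operation::host::
    ReferencefpAintBGemm<ck::half_t, int8_t, ck::half_t, float, ck::half_t,
                         PassThrough, PassThrough, PassThrough>;

void check_against_reference(const Tensor<ck::half_t>& a_m_k,
                             const Tensor<int8_t>& b_k_n,
                             const Tensor<ck::half_t>& scale_k_n,
                             Tensor<ck::half_t>& c_m_n_host)
{
    auto ref_op       = ReferenceOp{};
    auto ref_invoker  = ref_op.MakeInvoker();
    auto ref_argument = ref_op.MakeArgument(a_m_k, b_k_n, scale_k_n, c_m_n_host,
                                            PassThrough{}, PassThrough{}, PassThrough{});
    ref_invoker.Run(ref_argument);
    // The device output can then be compared with c_m_n_host, e.g. via ck::utils::check_err.
}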
+
+#include
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template
+using S = ck::Sequence;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
+using device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances = std::tuple<
+    // clang-format off
+    //######################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumPrefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| M| N| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer|
+    //######################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| | Size| Block| Block| Block| | WMMA| WMMA| Repeat| Repeat| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| MRepeat| ClusterLengths| ScalarPerVector|
+    //######################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerStore| PerStore| MBlock_MPerBlock| |
+    //######################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
+    /* Prefetch 2, consume enormous vgpr resource*/
+    // 8 Waves
+    DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 256, 128, 128, 32, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
+    // 4 Waves
+    DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 128, 128, 64, 64, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
+    // 2 Waves
+    DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 64, 64, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
+    // 1 Wave
+    DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8>,
+    /* Prefetch 1, prefer
larger KPerBlock value for better latency hiding*/ + // 8 Waves + DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 160, 64, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 8>, + // 4 Waves + DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 256, 64, 64, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 256, 64, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 80, 64, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 64, 1, 2>, 8>, + // 2 Waves + DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 64, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + // 1 Wave + DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 32, 64, 8, 16, 16, 1, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, 
S<1, 16, 1, 2>, 8> + // clang-format on + >; + +void add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_kn_mn_mn_instance.cpp index 73ea9cac07..dd055fabb8 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_kn_mn_mn_instance.cpp @@ -36,32 +36,32 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e[m, n] = bilinear(a[m, k] * b[k, n], d[m, n]) using device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_kn_mn_mn_instances = std::tuple< // clang-format off - //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, 
Bilinear, GemmDefault, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 64, 64, 16, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 64, 32, 32, 64, 16, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 32, 16, 16, 64, 16, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, // M/N/K padding - //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| 
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 8>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 2>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 8, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 4, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 8>, 4>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 4, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 
S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 4>, 4>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 4, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 2>, 4>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 4, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 2>, 4> + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 64, 16, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 64, 16, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 64, 16, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 64, 16, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 
8, 1, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 2>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 4, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 8>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 32, 4, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 4>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 32, 4, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 2>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 32, 4, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 2>, 4> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instance.cpp index 1f36113e62..f607484363 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instance.cpp @@ -36,32 +36,32 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e[m, n] = bilinear(a[m, k] * b[k, n], d[m, n]) using device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instances = std::tuple< // clang-format off - //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| 
Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 
128, 64, 16, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 64, 64, 16, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 64, 32, 32, 64, 16, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 32, 16, 16, 64, 16, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, // M/N/K padding - //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<4, 8, 1>, S<0, 
2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 4>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 2>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 8, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 4, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 8>, 4>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 4, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 4>, 4>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 4, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 2>, 4>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 4, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 2>, 4> + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| 
Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 64, 16, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 64, 16, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 64, 16, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 64, 16, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 2>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 4, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 8>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 32, 4, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 4>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, 
PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 32, 4, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 2>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 32, 4, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 2>, 4> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instance.cpp index 688c463369..accb2f80b6 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instance.cpp @@ -36,32 +36,32 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e[m, n] = bilinear(a[m, k] * b[k, n], d[m, n]) using device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances = std::tuple< // clang-format off - //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, 
S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 64, 64, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 64, 32, 32, 64, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 32, 16, 16, 64, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, // M/N/K padding - //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| 
Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 8>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 2>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 4, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 8>, 4>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, 
Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 4, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 4>, 4>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 4, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 2>, 4>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 4, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 2>, 4> + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 64, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 64, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 64, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 64, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, 
I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 2>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 4, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 8>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 32, 4, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 4>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 32, 4, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 2>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 32, 4, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 2>, 4> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instance.cpp index 5319bd8605..6a23b70321 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instance.cpp @@ -38,56 +38,56 @@ using device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instances = st // clang-format off // no padding // N % 16 == 0 && K % 16 == 0 - //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| 
CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, 
Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 64, 64, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 64, 32, 32, 64, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 32, 16, 16, 64, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, // M/N/K padding // N % 16 == 0 && K % 16 == 0 - //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, 
Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 64, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 64, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 64, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 64, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, // M/N/K padding // N % 8 == 0 && K % 8 == 0 - //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| 
Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 4>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 2>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, 
S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 2>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, // M/N/K padding // N % 8 == 0 && K % 8 == 0 - //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 4, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 8>, 4>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 4, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 4>, 4>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 4, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 2>, 4>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 4, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 2>, 4>, + 
//################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 4, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 8>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 32, 4, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 4>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 32, 4, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 2>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 32, 4, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 2>, 4>, // M/N/K padding // N % 1 == 0 && K % 8 == 0 - //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - 
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 4>, 1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 2>, 1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 1> + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 
0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 4>, 1>,
+ DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 2>, 1>,
+ DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 1>
 // clang-format on
 >;
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt
index d8bd0de692..93d5bd7422 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt
@@ -1,16 +1,18 @@
-add_instance_library(device_grouped_conv2d_bwd_data_instance
- xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp
- xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp
- xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp
- xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
- xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
- xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
+add_instance_library(
+ device_grouped_conv2d_bwd_data_instance
+ xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp
+ xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp
+ xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp
+ xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
+ xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
+ xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp

- wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp
- wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp
- wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instance.cpp
- wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instance.cpp
- wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp
- wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp
- wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp
- wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_i8_instance.cpp)
+ wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp
+ wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp
+ wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instance.cpp
+ wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instance.cpp
+ wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp
+ wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp
+ wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp
+ wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_i8_instance.cpp
+)
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt
index 1542d611f7..2715a8cf21 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt
@@ -17,21 +17,21 @@ add_instance_library(device_grouped_conv2d_fwd_instance
 dl/device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
 # WMMA
 # GNHWC, GKYXC, GNHWK
- wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp
- wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp
- wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instance.cpp
- wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instance.cpp
- wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp
- wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instance.cpp
- wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp
- wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instance.cpp
- # NHWGC, GKYXC, NHWGK
- wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp
- wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instance.cpp
- wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1p0_instance.cpp
- wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instance.cpp
- wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp
- wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instance.cpp
- wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp
- wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp
+ wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp
+ wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp
+ wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instance.cpp
+ wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instance.cpp
+ wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp
+ wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instance.cpp
+ wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp
+ wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instance.cpp
+ ## NHWGC, GKYXC, NHWGK
+ wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp
+ wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instance.cpp
+ wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1p0_instance.cpp
+ wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instance.cpp
+ wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp
+ wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instance.cpp
+ wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp
+ wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp
 )
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt
index bada661028..540ce3410b 100644
--- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt
+++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt
@@ -22,7 +22,8 @@ set(GROUPED_CONV3D_FWD
 wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp
 wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp
 wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp
- wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp)
+ wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp
+)

 if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
 list(APPEND GROUPED_CONV3D_FWD
diff --git a/test/grouped_convnd_bwd_data/CMakeLists.txt b/test/grouped_convnd_bwd_data/CMakeLists.txt
index 9773e5a9c6..305c568ee9 100644
--- a/test/grouped_convnd_bwd_data/CMakeLists.txt
+++ b/test/grouped_convnd_bwd_data/CMakeLists.txt
@@ -1,5 +1,5 @@
 list(APPEND gpu_list_xdl gfx908 gfx90a gfx940)
-list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
+list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102 gfx1103)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
 if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
diff --git a/test/grouped_convnd_bwd_weight/CMakeLists.txt b/test/grouped_convnd_bwd_weight/CMakeLists.txt
index b167943c97..d7d6f8a3d6 100644
--- a/test/grouped_convnd_bwd_weight/CMakeLists.txt
+++ b/test/grouped_convnd_bwd_weight/CMakeLists.txt
@@ -1,5 +1,5 @@
 list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942)
-list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
+list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102 gfx1103)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)