From f69356b1d73ad616bee0f03eeefd7233b5b35ebb Mon Sep 17 00:00:00 2001 From: carlushuang Date: Wed, 28 Feb 2024 22:57:19 +0000 Subject: [PATCH] add code --- .gitignore | 2 + example/ck_tile/01_fmha/CMakeLists.txt | 41 + example/ck_tile/01_fmha/README.md | 94 + example/ck_tile/01_fmha/fmha_fwd.cpp | 521 +++++ example/ck_tile/01_fmha/fmha_fwd.hpp | 336 +++ example/ck_tile/01_fmha/generate.py | 500 ++++ example/ck_tile/01_fmha/mask.hpp | 108 + example/ck_tile/01_fmha/misc/gamc.png | Bin 0 -> 30073 bytes example/ck_tile/01_fmha/script/benchmark.sh | 21 + example/ck_tile/01_fmha/script/smoke_test.sh | 34 + example/ck_tile/01_fmha/utils.hpp | 91 + include/ck_tile/README.md | 40 + include/ck_tile/core.hpp | 56 + include/ck_tile/core/README.md | 18 + .../core/algorithm/cluster_descriptor.hpp | 38 + .../core/algorithm/coordinate_transform.hpp | 1664 +++++++++++++ .../core/algorithm/space_filling_curve.hpp | 165 ++ .../ck_tile/core/arch/amd_address_space.hpp | 20 + .../core/arch/amd_buffer_addressing.hpp | 2050 +++++++++++++++++ include/ck_tile/core/config.hpp | 56 + include/ck_tile/core/container/array.hpp | 201 ++ .../core/container/container_helper.hpp | 483 ++++ include/ck_tile/core/container/map.hpp | 164 ++ .../core/container/meta_data_buffer.hpp | 99 + .../ck_tile/core/container/multi_index.hpp | 99 + include/ck_tile/core/container/sequence.hpp | 1114 +++++++++ include/ck_tile/core/container/span.hpp | 78 + .../container/statically_indexed_array.hpp | 70 + include/ck_tile/core/container/tuple.hpp | 483 ++++ include/ck_tile/core/numeric/arithmetic.hpp | 116 + include/ck_tile/core/numeric/bfloat16.hpp | 263 +++ include/ck_tile/core/numeric/float8.hpp | 735 ++++++ include/ck_tile/core/numeric/half.hpp | 278 +++ include/ck_tile/core/numeric/integer.hpp | 13 + .../core/numeric/integral_constant.hpp | 82 + include/ck_tile/core/numeric/math.hpp | 309 +++ include/ck_tile/core/numeric/type_convert.hpp | 45 + include/ck_tile/core/numeric/vector_type.hpp | 304 +++ include/ck_tile/core/tensor/buffer_view.hpp | 1041 +++++++++ include/ck_tile/core/tensor/load_tile.hpp | 78 + include/ck_tile/core/tensor/null_tensor.hpp | 12 + .../ck_tile/core/tensor/null_tile_window.hpp | 87 + include/ck_tile/core/tensor/shuffle_tile.hpp | 171 ++ include/ck_tile/core/tensor/slice_tile.hpp | 94 + .../core/tensor/static_distributed_tensor.hpp | 180 ++ include/ck_tile/core/tensor/store_tile.hpp | 93 + include/ck_tile/core/tensor/sweep_tile.hpp | 30 + .../ck_tile/core/tensor/tensor_adaptor.hpp | 942 ++++++++ .../core/tensor/tensor_adaptor_coordinate.hpp | 257 +++ .../ck_tile/core/tensor/tensor_coordinate.hpp | 92 + .../ck_tile/core/tensor/tensor_descriptor.hpp | 472 ++++ include/ck_tile/core/tensor/tensor_view.hpp | 273 +++ .../ck_tile/core/tensor/tile_distribution.hpp | 754 ++++++ .../tensor/tile_distribution_encoding.hpp | 761 ++++++ .../ck_tile/core/tensor/tile_elementwise.hpp | 191 ++ include/ck_tile/core/tensor/tile_window.hpp | 735 ++++++ include/ck_tile/core/utility/bit_cast.hpp | 19 + include/ck_tile/core/utility/functional.hpp | 194 ++ include/ck_tile/core/utility/limits.hpp | 75 + include/ck_tile/core/utility/magic_div.hpp | 261 +++ include/ck_tile/core/utility/random.hpp | 64 + include/ck_tile/core/utility/to_sequence.hpp | 72 + include/ck_tile/core/utility/type_convert.hpp | 57 + include/ck_tile/core/utility/type_traits.hpp | 46 + include/ck_tile/host.hpp | 23 + include/ck_tile/host/arg_parser.hpp | 184 ++ include/ck_tile/host/check_err.hpp | 375 +++ include/ck_tile/host/device_memory.hpp | 112 + include/ck_tile/host/fill.hpp | 232 ++ include/ck_tile/host/hip_check_error.hpp | 35 + include/ck_tile/host/host_tensor.hpp | 495 ++++ include/ck_tile/host/kernel_launch.hpp | 166 ++ include/ck_tile/host/ranges.hpp | 71 + .../reference_batched_elementwise.hpp | 64 + .../host/reference/reference_batched_gemm.hpp | 50 + .../reference/reference_batched_masking.hpp | 32 + .../reference/reference_batched_softmax.hpp | 67 + .../ck_tile/host/reference/reference_gemm.hpp | 50 + .../host/reference/reference_im2col.hpp | 61 + .../host/reference/reference_reduce.hpp | 32 + .../host/reference/reference_softmax.hpp | 50 + include/ck_tile/host/stream_config.hpp | 17 + include/ck_tile/ops/common.hpp | 7 + include/ck_tile/ops/common/tensor_layout.hpp | 412 ++++ include/ck_tile/ops/epilogue.hpp | 7 + .../ops/epilogue/default_2d_epilogue.hpp | 50 + include/ck_tile/ops/fmha.hpp | 20 + .../ck_tile/ops/fmha/block/block_masking.hpp | 227 ++ .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 698 ++++++ .../fmha/kernel/fmha_fwd_tile_partitioner.hpp | 54 + .../pipeline/block_fmha_pipeline_problem.hpp | 60 + .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 581 +++++ .../block_fmha_pipeline_qr_ks_vs_async.hpp | 676 ++++++ ...pipeline_qr_ks_vs_async_default_policy.hpp | 18 + ..._fmha_pipeline_qr_ks_vs_default_policy.hpp | 18 + .../block_fmha_pipeline_qr_ks_vs_fp8.hpp | 507 ++++ .../pipeline/block_fmha_pipeline_qs_ks_vs.hpp | 573 +++++ ..._fmha_pipeline_qs_ks_vs_default_policy.hpp | 18 + ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 953 ++++++++ .../ops/fmha/pipeline/tile_fmha_shape.hpp | 47 + .../ops/fmha/pipeline/tile_fmha_traits.hpp | 28 + include/ck_tile/ops/gemm.hpp | 30 + .../block_gemm_areg_bgmem_creg_problem.hpp | 25 + .../block/block_gemm_areg_bgmem_creg_v1.hpp | 133 ++ ...gemm_areg_bgmem_creg_v1_default_policy.hpp | 110 + .../block_gemm_areg_bsmem_creg_problem.hpp | 26 + .../block/block_gemm_areg_bsmem_creg_v1.hpp | 337 +++ ..._gemm_areg_bsmem_creg_v1_custom_policy.hpp | 36 + ...gemm_areg_bsmem_creg_v1_default_policy.hpp | 55 + .../block/block_gemm_areg_bsmem_creg_v2.hpp | 225 ++ ..._gemm_areg_bsmem_creg_v2_custom_policy.hpp | 36 + ...gemm_areg_bsmem_creg_v2_default_policy.hpp | 45 + .../block_gemm_asmem_bsmem_creg_problem.hpp | 26 + .../block/block_gemm_asmem_bsmem_creg_v1.hpp | 212 ++ ...gemm_asmem_bsmem_creg_v1_custom_policy.hpp | 38 + ...emm_asmem_bsmem_creg_v1_default_policy.hpp | 54 + ...lock_gemm_pipeline_agmem_bgmem_creg_v1.hpp | 201 ++ ...ine_agmem_bgmem_creg_v1_default_policy.hpp | 251 ++ ...lock_gemm_pipeline_agmem_bgmem_creg_v2.hpp | 219 ++ ...ine_agmem_bgmem_creg_v2_default_policy.hpp | 18 + .../pipeline/block_gemm_pipeline_problem.hpp | 25 + .../ops/gemm/pipeline/tile_gemm_shape.hpp | 18 + include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 103 + .../gemm/warp/warp_gemm_attribute_mfma.hpp | 455 ++++ .../warp/warp_gemm_attribute_mfma_impl.hpp | 247 ++ .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 64 + .../ck_tile/ops/gemm/warp/warp_gemm_impl.hpp | 74 + include/ck_tile/ops/reduce.hpp | 6 + .../ck_tile/ops/reduce/block/block_reduce.hpp | 211 ++ include/ck_tile/remod.py | 76 + 130 files changed, 28268 insertions(+) create mode 100644 example/ck_tile/01_fmha/CMakeLists.txt create mode 100644 example/ck_tile/01_fmha/README.md create mode 100644 example/ck_tile/01_fmha/fmha_fwd.cpp create mode 100644 example/ck_tile/01_fmha/fmha_fwd.hpp create mode 100644 example/ck_tile/01_fmha/generate.py create mode 100644 example/ck_tile/01_fmha/mask.hpp create mode 100644 example/ck_tile/01_fmha/misc/gamc.png create mode 100644 example/ck_tile/01_fmha/script/benchmark.sh create mode 100644 example/ck_tile/01_fmha/script/smoke_test.sh create mode 100644 example/ck_tile/01_fmha/utils.hpp create mode 100644 include/ck_tile/README.md create mode 100644 include/ck_tile/core.hpp create mode 100644 include/ck_tile/core/README.md create mode 100644 include/ck_tile/core/algorithm/cluster_descriptor.hpp create mode 100644 include/ck_tile/core/algorithm/coordinate_transform.hpp create mode 100644 include/ck_tile/core/algorithm/space_filling_curve.hpp create mode 100644 include/ck_tile/core/arch/amd_address_space.hpp create mode 100644 include/ck_tile/core/arch/amd_buffer_addressing.hpp create mode 100644 include/ck_tile/core/config.hpp create mode 100644 include/ck_tile/core/container/array.hpp create mode 100644 include/ck_tile/core/container/container_helper.hpp create mode 100644 include/ck_tile/core/container/map.hpp create mode 100644 include/ck_tile/core/container/meta_data_buffer.hpp create mode 100644 include/ck_tile/core/container/multi_index.hpp create mode 100644 include/ck_tile/core/container/sequence.hpp create mode 100644 include/ck_tile/core/container/span.hpp create mode 100644 include/ck_tile/core/container/statically_indexed_array.hpp create mode 100644 include/ck_tile/core/container/tuple.hpp create mode 100644 include/ck_tile/core/numeric/arithmetic.hpp create mode 100644 include/ck_tile/core/numeric/bfloat16.hpp create mode 100644 include/ck_tile/core/numeric/float8.hpp create mode 100644 include/ck_tile/core/numeric/half.hpp create mode 100644 include/ck_tile/core/numeric/integer.hpp create mode 100644 include/ck_tile/core/numeric/integral_constant.hpp create mode 100644 include/ck_tile/core/numeric/math.hpp create mode 100644 include/ck_tile/core/numeric/type_convert.hpp create mode 100644 include/ck_tile/core/numeric/vector_type.hpp create mode 100644 include/ck_tile/core/tensor/buffer_view.hpp create mode 100644 include/ck_tile/core/tensor/load_tile.hpp create mode 100644 include/ck_tile/core/tensor/null_tensor.hpp create mode 100644 include/ck_tile/core/tensor/null_tile_window.hpp create mode 100644 include/ck_tile/core/tensor/shuffle_tile.hpp create mode 100644 include/ck_tile/core/tensor/slice_tile.hpp create mode 100644 include/ck_tile/core/tensor/static_distributed_tensor.hpp create mode 100644 include/ck_tile/core/tensor/store_tile.hpp create mode 100644 include/ck_tile/core/tensor/sweep_tile.hpp create mode 100644 include/ck_tile/core/tensor/tensor_adaptor.hpp create mode 100644 include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp create mode 100644 include/ck_tile/core/tensor/tensor_coordinate.hpp create mode 100644 include/ck_tile/core/tensor/tensor_descriptor.hpp create mode 100644 include/ck_tile/core/tensor/tensor_view.hpp create mode 100644 include/ck_tile/core/tensor/tile_distribution.hpp create mode 100644 include/ck_tile/core/tensor/tile_distribution_encoding.hpp create mode 100644 include/ck_tile/core/tensor/tile_elementwise.hpp create mode 100644 include/ck_tile/core/tensor/tile_window.hpp create mode 100644 include/ck_tile/core/utility/bit_cast.hpp create mode 100644 include/ck_tile/core/utility/functional.hpp create mode 100644 include/ck_tile/core/utility/limits.hpp create mode 100644 include/ck_tile/core/utility/magic_div.hpp create mode 100644 include/ck_tile/core/utility/random.hpp create mode 100644 include/ck_tile/core/utility/to_sequence.hpp create mode 100644 include/ck_tile/core/utility/type_convert.hpp create mode 100644 include/ck_tile/core/utility/type_traits.hpp create mode 100644 include/ck_tile/host.hpp create mode 100644 include/ck_tile/host/arg_parser.hpp create mode 100644 include/ck_tile/host/check_err.hpp create mode 100644 include/ck_tile/host/device_memory.hpp create mode 100644 include/ck_tile/host/fill.hpp create mode 100644 include/ck_tile/host/hip_check_error.hpp create mode 100644 include/ck_tile/host/host_tensor.hpp create mode 100644 include/ck_tile/host/kernel_launch.hpp create mode 100644 include/ck_tile/host/ranges.hpp create mode 100644 include/ck_tile/host/reference/reference_batched_elementwise.hpp create mode 100644 include/ck_tile/host/reference/reference_batched_gemm.hpp create mode 100644 include/ck_tile/host/reference/reference_batched_masking.hpp create mode 100644 include/ck_tile/host/reference/reference_batched_softmax.hpp create mode 100644 include/ck_tile/host/reference/reference_gemm.hpp create mode 100644 include/ck_tile/host/reference/reference_im2col.hpp create mode 100644 include/ck_tile/host/reference/reference_reduce.hpp create mode 100644 include/ck_tile/host/reference/reference_softmax.hpp create mode 100644 include/ck_tile/host/stream_config.hpp create mode 100644 include/ck_tile/ops/common.hpp create mode 100644 include/ck_tile/ops/common/tensor_layout.hpp create mode 100644 include/ck_tile/ops/epilogue.hpp create mode 100644 include/ck_tile/ops/epilogue/default_2d_epilogue.hpp create mode 100644 include/ck_tile/ops/fmha.hpp create mode 100644 include/ck_tile/ops/fmha/block/block_masking.hpp create mode 100644 include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp create mode 100644 include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp create mode 100644 include/ck_tile/ops/gemm.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_problem.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp create mode 100644 include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp create mode 100644 include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp create mode 100644 include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp create mode 100644 include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp create mode 100644 include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp create mode 100644 include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp create mode 100644 include/ck_tile/ops/gemm/warp/warp_gemm.hpp create mode 100644 include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp create mode 100644 include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp create mode 100644 include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp create mode 100644 include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp create mode 100644 include/ck_tile/ops/reduce.hpp create mode 100644 include/ck_tile/ops/reduce/block/block_reduce.hpp create mode 100644 include/ck_tile/remod.py diff --git a/.gitignore b/.gitignore index 090594a8df..f4d5ff7abd 100644 --- a/.gitignore +++ b/.gitignore @@ -64,3 +64,5 @@ build*/ # Python virtualenv .venv/ +# Python cache +__pycache__/ diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt new file mode 100644 index 0000000000..d3b229daae --- /dev/null +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -0,0 +1,41 @@ +# generate a list of kernels, but not actually emit files at config stage +execute_process( + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/blob_list.txt +) + +# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS files must be in the same directory +# as current cmake list, otherwise will not figure out the dependency properly +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/blob_list.txt FMHA_FWD_GEN_BLOBS) + +add_custom_command( + OUTPUT ${FMHA_FWD_GEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --output_dir ${CMAKE_CURRENT_BINARY_DIR} +) + +set(EXAMPLE_FMHA_FWD "example_fmha_fwd") +add_example_executable(${EXAMPLE_FMHA_FWD} fmha_fwd.cpp) +target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS}) + +# NOTE: this is dangerous since will change the whole kernel to flush denormals +# WIP with compiler team for an exp2 intrinsic..., then remove this +if(NOT DEFINED FMHA_FWD_FAST_EXP2) + set(FMHA_FWD_FAST_EXP2 true) +endif() + +set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS) + +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +# ... because they are auto-generated +if(FMHA_FWD_FAST_EXP2) + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero -v --save-temps -Wno-gnu-line-marker) +else() + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_FMHA_FWD_FAST_EXP2=0 -v --save-temps -Wno-gnu-line-marker) +endif() + +# Allow comparing floating points directly in order to check sentinel values +list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal) + +target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS}) diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md new file mode 100644 index 0000000000..da3d412a09 --- /dev/null +++ b/example/ck_tile/01_fmha/README.md @@ -0,0 +1,94 @@ +# fused multi-head attention + +This folder contains example for fmha(fused multi-head attention) using ck_tile tile-programming implementation. It is a good example to demonstrate the usage of tile-programming API, as well as illustrate the new approach to construct a kernel template and instantiate it(them) while keeping compile time fast. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +sh ../script/cmake-ck_tile-dev.sh ../ # you can replace this to gfx90a, gfx942... +make example_fmha_fwd -j +``` +This will result in an executable `build/bin/example_fmha_fwd` + +## kernel +The kernel template is `fmha_fwd_kernel.hpp`, this is the grid-wise op in old ck_tile's terminology. We put it here purposely, to demonstrate one can construct a kernel by using various internal component from ck_tile. We may still have an implementation under ck_tile's include path (in the future) for the kernel template. + +There are 3 template parameters for this kernel template. +* `TilePartitioner` is used to map the workgroup to corresponding tile, `fmha_fwd_tile_partitioner.hpp` in this folder served as this purpose. +* `FmhaPipeline` is one of the block_tile_pipeline(under `include/ck_tile/tile_program/block_tile_pipeline`) which is a performance critical component. Indeed, we did a lot of optimization and trials to optimize the pipeline and may still workout more performance pipeline and update into that folder. People only need to replace this pipeline type and would be able to enjoy the benefit of different performant implementations (stay tuned for updated pipeline(s)). +* `EpiloguePipeline` will modify and store out the result in the last phase. People usually will do lot of post-fusion at this stage, so we also abstract this concept. Currently we didn't do much thing at the epilogue stage but leave the room for future possible support. + +## codegen +To speed up compile time, we instantiate the kernels into separate file. In this way we can benefit from parallel building from CMake/Make system. This is achieved by `generate.py` script. Besides, you can look into this script to learn how to instantiate a kernel instance step by step, which is described in `FMHA_FWD_KERNEL_BODY` variable. + +## executable +`example_fmha_fwd` is the example executable, implemented in `fmha_fwd.cpp`. You can type `./bin/example_fmha_fwd -?` to list all supported args. Below is an example of the output (may subject to change) +``` +args: + -v weather do CPU validation or not (default:1) + -mode kernel mode. 0:batch, 1:group (default:0) + -b batch size (default:2) + -h num of head, for q (default:8) + -h_k num of head, for k/v, 0 means equal to h (default:0) + if not equal to h, then this is GQA/MQA case + -s seqlen_q (default:3328) + -s_k seqlen_k, 0 means equal to s (default:0) + -d head dim for q, k (default:128) + -d_v head dim for v, 0 means equal to d (default:0) + -scale scale factor. 0 means equal to 1/sqrt(hdim) (default:0) + -descale_q scale factor for fp8 quantization (default:1) + -descale_k scale factor for fp8 quantization (default:1) + -descale_v scale factor for fp8 quantization (default:1) + -iperm permute input (default:1) + if true, will be b*h*s*d, else b*s*h*d + -operm permute output (default:1) + -bias add bias or not (default:0) + -prec data type. fp16/bf16/fp8/bf8 (default:fp16) + -mask 0: no mask, 1: top-left, 2:bottom-right (default:0) + 't:l,r', top-left local-attn with left right size + 'b:l,r', bottom-r local-attn with left right size + 'g:y,x', generic attention mask coordinate with y/x size + + -vlayout r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r) + -lse 0 not store lse, 1 store lse (default:0) + -kname if set to 1 will print kernel name (default:0) +``` +Example: `./bin/example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case. + +## support features +Currently we are still in rapid development stage, so more features/optimizations will be coming soon. + +### hdim +Currently we support `32/64/128/256` hdim for `fp16`/`bf16`, within which `64`/`128` is better optimized. hdim should be multiple of 8, while seqlen_s can be arbitrary. For hdim be arbitrary number, it can be support through padding kernel of `qr` pipeline (we didn't generate this in generate.py by default) + +### group/batch mode +Currently we support both batch and group mode, by setting `-mode` = `0` or `1`, where in group mode we support each batch can have different seqlen + +### MQA/GQA +By setting `-h`(nhead for q) and `-h_k`(nhead for k/v) with different number, you can achieve MQA/GQA. Please pay attention that `h % h_K == 0` when you set different numbers. + +### input/output permute, and `b*s*3*h*d` +If you look at the kernel argument inside `fmha_fwd_kernel.hpp`, we support providing arbitrary stride for seqlen(stride_q/k/v), nhead, batch of q/k/v matrix, hence it is very flexible to support `b*h*s*d` or `b*s*h*d` input/output permute. The `-iperm=0/1`, `-operm=0/1` is a convenient way to achieve this through the executable. We didn't provide a command-line arg to test `b*s*3*h*d` layout which is by default used by torch/FA, but it's trivial to achieve this if one set the proper `stride_q/k/v` value as `3*h*d`. + +### attention bias +Attention bias is supported with the layout of `1*1*s*s`(similiar to input/output, different layout can be supported by changing the stride value for bias, or even extend to `b*h*s*s`) and bias value in float number. + +### lse +For training kernels, "log sum exp" need to store out in forward and used in backward. We support this by setting `-lse=1` + +### vlayout +We support v matrix in both row-major(`seqlen*hdim`) and col-major(`hdim*seqlen`). Since the accumulate(reduce) dimension for V is along `seqlen`, for current AMD's mfma layout which expect each thread to have contiguous register holding pixels along reduce dimension, it's easier to support col-major V layout. However, the performance of col-major is not necessarily faster than row-major, there are many factors that may affect the overall performance. We still provide the `-vlayout=r/c` here to switch/test between different layouts. + +### generic attention mask coordinate +We unify the mask expression into generic attention mask coordinate, providing an uniformed approach to describe causal top-left, causal bottom-right, local attention. +![](misc/gamc.png) + +(more description to be added) + +### dropout +TBD + +## FP8 experimental support +As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `example_fmha_fwd`, on a gfx940/941/942 machine and ROCm 6.0+. Currently if you not explicitly setting `-v=0`(which will disable CPU verification), it will printout an error as much as `0.05`. We are still WIP to tune the kernel performance as well as the precision, so stay tuned for the updated performance(pipeline) +Currently we only support `-vlayout=c` for fp8, which is `hdim*seqlen` for V matrix. row major for V matrix support will come later. diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp new file mode 100644 index 0000000000..1b2183960c --- /dev/null +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -0,0 +1,521 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "fmha_fwd.hpp" +#include "ck_tile/host.hpp" +#include "mask.hpp" +#include "utils.hpp" + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "weather do CPU validation or not") + .insert("mode", "0", "kernel mode. 0:batch, 1:group") + .insert("b", "2", "batch size") + .insert("h", "8", "num of head, for q") + .insert("h_k", + "0", + "num of head, for k/v, 0 means equal to h\n" + "if not equal to h, then this is GQA/MQA case") + .insert("s", "3328", "seqlen_q") + .insert("s_k", "0", "seqlen_k, 0 means equal to s") + .insert("d", "128", "head dim for q, k") + .insert("d_v", "0", "head dim for v, 0 means equal to d") + .insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)") + .insert("descale_q", "1", "scale factor for fp8 quantization") + .insert("descale_k", "1", "scale factor for fp8 quantization") + .insert("descale_v", "1", "scale factor for fp8 quantization") + .insert("iperm", + "1", + "permute input\n" + "if true, will be b*h*s*d, else b*s*h*d") + .insert("operm", "1", "permute output") + .insert("bias", "0", "add bias or not") + .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") + .insert("mask", + "0", + "0: no mask, 1: top-left, 2:bottom-right\n" + "'t:l,r', top-left local-attn with left right size\n" + "'b:l,r', bottom-r local-attn with left right size\n" + "'g:y,x', generic attention mask coordinate with y/x size\n") + .insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)") + .insert("lse", "0", "0 not store lse, 1 store lse") + .insert("kname", "0", "if set to 1 will print kernel name") + .insert("init", "1", "init method. 0:random int, 1:random float, 2:trig float") + .insert("seed", + "11939", + "random seed used for initializing input tensors. 0 to use " + "non-deterministic random number as seed") + .insert("warmup", "5", "number of iterations before benchmark the kernel") + .insert("repeat", "20", "number of iterations to benchmark the kernel"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// different threshold for different dtype +template +auto get_elimit(int /*init_method*/) +{ + double rtol = 1e-3; + double atol = 1e-3; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(int init_method) +{ + if(init_method == 0) + { + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); + } + else + { + double rtol = 3e-3; + double atol = 3e-3; + return ck_tile::make_tuple(rtol, atol); + } +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + auto mode = static_cast(arg_parser.get_uint32("mode")); + ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + if(nhead_k == 0) + nhead_k = nhead; + + if(nhead % nhead_k != 0) + { + std::cerr << "nhead:" << nhead << " must be multiple of nhead_k:" << nhead_k << std::endl; + return false; + } + + ck_tile::index_t seqlen_q = arg_parser.get_int("s"); + ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); + if(seqlen_k == 0) + seqlen_k = seqlen_q; + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + if(hdim_v == 0) + hdim_v = hdim_q; + + bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim + bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim + + float scale = arg_parser.get_float("scale"); + if(scale == .0f) + scale = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); // TODO: q ? v ? + + float descale_q = arg_parser.get_float("descale_q"); + float descale_k = arg_parser.get_float("descale_k"); + float descale_v = arg_parser.get_float("descale_v"); + + std::string vlayout = arg_parser.get_str("vlayout"); + bool use_bias = arg_parser.get_bool("bias"); + bool lse = arg_parser.get_bool("lse"); + + mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k); + + int init_method = arg_parser.get_int("init"); + std::optional seed = arg_parser.get_uint32("seed"); + if(*seed == 0) + { + seed.reset(); + } + + int stream_warmup = arg_parser.get_int("warmup"); + int stream_repeat = arg_parser.get_int("repeat"); + bool kname = arg_parser.get_bool("kname"); + + stream_config stream_config{ + nullptr, true, /* log_level = */ (kname ? 1 : 0), stream_warmup, stream_repeat}; + + const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q); + const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k); + + using TypeConfig = FmhaFwdTypeConfig; + + using QDataType = typename TypeConfig::QDataType; + using KDataType = typename TypeConfig::KDataType; + using VDataType = typename TypeConfig::VDataType; + using BiasDataType = typename TypeConfig::BiasDataType; + using LSEDataType = typename TypeConfig::LSEDataType; + using SaccDataType = typename TypeConfig::SaccDataType; + using SMPLComputeDataType = typename TypeConfig::SMPLComputeDataType; + using PDataType = typename TypeConfig::PDataType; + using OaccDataType = typename TypeConfig::OaccDataType; + using ODataType = typename TypeConfig::ODataType; + + // accumulation numbers for performance evaluation + std::size_t flop = 0, num_byte = 0; + auto max_seqlen_q = + std::numeric_limits::min(); // we will use max seqlen to decide grid size + { + for(ck_tile::index_t wb = 0; wb < batch; ++wb) + { + const int32_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const int32_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + + if(max_seqlen_q < real_seqlen_q) + { + max_seqlen_q = real_seqlen_q; + } + + flop += nhead * (static_cast(2) * real_seqlen_q * real_seqlen_k * hdim_q + + static_cast(2) * real_seqlen_q * hdim_v * real_seqlen_k); + + num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q + + sizeof(KDataType) * real_seqlen_k * hdim_q + + sizeof(VDataType) * hdim_v * real_seqlen_k + + sizeof(ODataType) * real_seqlen_q * hdim_v); + } + } + + auto get_lengths = [&](bool permute, + ck_tile::index_t b /*batch*/, + ck_tile::index_t h /*nhead*/, + ck_tile::index_t s /*seqlen*/, + ck_tile::index_t d /*hdim*/) { + if(permute) + return std::array{b, h, s, d}; + else + return std::array{b, s, h, d}; + }; + + bool is_v_rowmajor = vlayout == std::string("r"); + + // host memory for storing all the tensor elements + const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); + const ck_tile::index_t shape_seqlen_q = + (mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back()); + const ck_tile::index_t shape_seqlen_k = + (mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back()); + + HostTensor q_host(get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); + HostTensor k_host(get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); + HostTensor v_host( + is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v) + : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k)); + // use bias shape = [1, 1, shape_seqlen_q, shape_seqlen_k]. if use_bias=false, the bias_host + // will not be used for verification at all (but will be copied to device anyway). + HostTensor bias_host( + use_bias ? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k) + : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + // self define lse data layout as [shape_batch, nhead, shape_seqlen_q] + HostTensor lse_host( + lse ? std::array{shape_batch, nhead, shape_seqlen_q} + : std::array{1, 1, 1} /* dummy shape for simplifying code */); + + HostTensor o_host(get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); + + if(init_method == 0) + { + ck_tile::utils::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(q_host); + ck_tile::utils::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(k_host); + ck_tile::utils::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(v_host); + ck_tile::utils::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(bias_host); + } + else if(init_method == 1) + { + ck_tile::utils::FillUniformDistribution{0.f, 1.f, seed}(q_host); + ck_tile::utils::FillUniformDistribution{0.f, 1.f, seed}(k_host); + ck_tile::utils::FillUniformDistribution{0.f, 1.f, seed}(v_host); + ck_tile::utils::FillUniformDistribution{0.f, 1.f, seed}(bias_host); + } + else if(init_method == 2) + { + ck_tile::utils::FillTrigValue{}(q_host); + ck_tile::utils::FillTrigValue{}(k_host); + ck_tile::utils::FillTrigValue{}(v_host); + ck_tile::utils::FillTrigValue{}(bias_host); + } + + DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); + DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); + DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); + DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes()); + DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); + DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); + DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); + DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + + q_buf.ToDevice(q_host.data()); + k_buf.ToDevice(k_host.data()); + v_buf.ToDevice(v_host.data()); + bias_buf.ToDevice(bias_host.data()); + seqstart_q.ToDevice(seqstart_q_host.data()); + seqstart_k.ToDevice(seqstart_k_host.data()); + + // clang-format off + auto layout_str = [&](bool permute){ + if (permute) return std::string("bhsd"); + else return std::string("bshd"); + }; + auto io_layout = [&](bool iperm_, bool operm_) { + if (iperm_ == operm_) return layout_str(iperm_); + else return layout_str(iperm_) + std::string("-") + layout_str(operm_); + }; + // clang-format on + const std::string prec = arg_parser.get_str("prec"); + + std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch + << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k + << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << use_bias + << ", lse:" << lse << ", mask:" << mask << ", v:" << vlayout << std::flush; + + auto fmha_traits = fmha_fwd_traits{hdim_q, + hdim_v, + data_type, + mode == mode_enum::group, + is_v_rowmajor, + mask.type, + use_bias, + lse}; + auto fmha_args = fmha_fwd_args{q_buf.GetDeviceBuffer(), + k_buf.GetDeviceBuffer(), + v_buf.GetDeviceBuffer(), + bias_buf.GetDeviceBuffer(), + lse_buf.GetDeviceBuffer(), + o_buf.GetDeviceBuffer(), + seqstart_q.GetDeviceBuffer(), + seqstart_k.GetDeviceBuffer(), + nullptr, + batch, + nhead, + nhead_k, + shape_seqlen_q, + shape_seqlen_k, + hdim_q, + hdim_v, + max_seqlen_q, + scale, + descale_q * descale_k, + descale_v, + i_perm, + o_perm, + mask.y, + mask.x}; + + float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config); + + if(ave_time < 0) + { + std::cout << ", not supported yet" << std::flush << std::endl; + return false; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, " + << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec + << " GB/s" << std::flush; + + if(!do_validation) + { + std::cout << std::flush << std::endl; + return true; + } + + o_buf.FromDevice(o_host.data()); + lse_buf.FromDevice(lse_host.data()); + + bool pass = true; + + for(ck_tile::index_t wb = 0; wb < batch; ++wb) + { + const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + + // adjust matrix index according to the mode + const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0); + const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); + const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]); + + const auto v_host_ref_lengths = std::array{nhead, hdim_v, real_seqlen_k}; + const auto v_host_ref_strides = + is_v_rowmajor ? std::array{hdim_v * real_seqlen_k, 1, hdim_v} + : std::array{hdim_v * real_seqlen_k, real_seqlen_k, 1}; + + HostTensor q_host_ref({nhead, real_seqlen_q, hdim_q}); + HostTensor k_host_ref({nhead, real_seqlen_k, hdim_q}); + HostTensor v_host_ref(v_host_ref_lengths, v_host_ref_strides); + HostTensor o_host_ref({nhead, real_seqlen_q, hdim_v}); + + HostTensor s_host_ref({nhead, real_seqlen_q, real_seqlen_k}); + HostTensor p_host_ref({nhead, real_seqlen_q, real_seqlen_k}); + HostTensor lse_host_ref({nhead, real_seqlen_q}); + + ck_tile::index_t nr = nhead / nhead_k; + + // clang-format off + // permute + if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[0], i[1] + query_offset, i[2]); }); + else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); }); + + if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); }); + else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); }); + + if (is_v_rowmajor) { + // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d] + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); }); + // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d] + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); }); + } + else { + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[1], i[2] + key_offset); }); + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[1], i[0] / nr, i[2] + key_offset); }); + } + // clang-format on + + // reference + reference_batched_gemm( + q_host_ref, + k_host_ref, + s_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + [&](SaccDataType x) { return scale * x; }); + + if(use_bias) + { + HostTensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); + // clang-format off + if(i_perm) + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2] + key_offset); }); + else + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2] + key_offset); }); + // clang-format on + + // broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q, + // real_seqlen_k] + reference_batched_elementwise( + s_host_ref, bias_host_ref, s_host_ref); + } + + if(mask.type == mask_enum::no_mask) + { + reference_batched_masking( + s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k}); + } + else if(mask.type == mask_enum::window_generic) + { + reference_batched_masking( + s_host_ref, FmhaMasks::GenericMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k}); + } + else + { + reference_batched_masking( + s_host_ref, FmhaMasks::CausalMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k}); + } + if(lse) + { + reference_batched_softmax( + s_host_ref, p_host_ref, lse_host_ref); + } + else + { + reference_batched_softmax( + s_host_ref, p_host_ref); + } + + reference_batched_gemm( + p_host_ref, v_host_ref, o_host_ref); + + HostTensor o_host_result({nhead, real_seqlen_q, hdim_v}); + // clang-format off + // permute + if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[0], idx[1] + query_offset, idx[2]); }); + else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[1] + query_offset, idx[0], idx[2]); }); + // clang-format on + + auto [rtol, atol] = get_elimit(init_method); + bool cur_pass = ck_tile::utils::check_err( + o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); + pass &= cur_pass; + if(!cur_pass) + { + std::cerr << "OUT mismatch found at batch: " << wb << std::endl + << "\tseqlen_q: " << real_seqlen_q << std::endl + << "\tseqlen_k: " << real_seqlen_k << std::endl + << "\tseqstart_q: " << seqstart_q_host << std::endl + << "\tseqstart_k: " << seqstart_k_host << std::endl; + + break; + } + + if(lse) + { + HostTensor lse_host_result({nhead, real_seqlen_q}); + lse_host_result.ForEach([&](auto& self, auto idx) { + self(idx) = lse_host(b, idx[0], idx[1] + query_offset); + }); + + bool lse_pass = ck_tile::utils::check_err(lse_host_result, + lse_host_ref, + "LSE Error: Incorrect results!", + rtol, + atol, + /* allow_infinity_ref = */ true); + + pass &= lse_pass; + if(!cur_pass) + { + std::cerr << "LSE mismatch found at batch: " << wb << std::endl + << "\tseqlen_q: " << real_seqlen_q << std::endl + << "\tseqlen_k: " << real_seqlen_k << std::endl + << "\tseqstart_q: " << seqstart_q_host << std::endl + << "\tseqstart_k: " << seqstart_k_host << std::endl; + + break; + } + } + } + + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "bf16") + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "fp8") + { + return run(arg_parser) ? 0 : -2; + } + + return -3; +} diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp new file mode 100644 index 0000000000..eb11efb2e2 --- /dev/null +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -0,0 +1,336 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/fmha.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/common.hpp" +#include "mask.hpp" + +template +struct FmhaFwdTypeConfig; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::half_t; + using KDataType = ck_tile::half_t; + using VDataType = ck_tile::half_t; + using BiasDataType = ck_tile::half_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::half_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::half_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::bhalf_t; + using KDataType = ck_tile::bhalf_t; + using VDataType = ck_tile::bhalf_t; + using BiasDataType = ck_tile::bhalf_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::bhalf_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bhalf_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::fp8_t; + using KDataType = ck_tile::fp8_t; + using VDataType = ck_tile::fp8_t; + using BiasDataType = float; // TODO: fix me + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::fp8_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::bf8_t; + using KDataType = ck_tile::bf8_t; + using VDataType = ck_tile::bf8_t; + using BiasDataType = ck_tile::bf8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf8_t; +}; + +struct FmhaMasks +{ + using NoMask = ck_tile::GenericAttentionMask; + using GenericMask = ck_tile::GenericAttentionMask; + using CausalMask = ck_tile::GenericAttentionMask; +}; + +// internal API, don't use this directly +template +auto fmha_fwd_create_kargs_and_grids(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* lse_ptr, + void* o_ptr, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, + ck_tile::index_t batch, + ck_tile::index_t nhead, + ck_tile::index_t nhead_k, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t max_seqlen_q, + float scale, + float descale_qk, + float descale_sv, + bool i_perm, + bool o_perm, + ck_tile::index_t mask_y, + ck_tile::index_t mask_x) +{ + constexpr bool is_v_rowmajor = + ck_tile::is_same_v; + + assert(nhead % nhead_k == 0); + /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, + /// seqlen_k] in this example, hence both the 'batch_stride_bias' & 'nhead_stride_bias' + /// are 0. + // setup stride_* arguments + const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); + const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck_tile::index_t stride_v = [&]() { + if constexpr(is_v_rowmajor) + return i_perm ? hdim_v : nhead_k * hdim_v; + else + return i_perm ? seqlen_k : nhead_k * seqlen_k; + }(); + const ck_tile::index_t stride_bias = (i_perm ? seqlen_k : 1 * seqlen_k); + const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); + // setup nhead_stride_* arguments + const ck_tile::index_t nhead_stride_q = (i_perm ? seqlen_q * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_k = (i_perm ? seqlen_k * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_v = [&]() { + if constexpr(is_v_rowmajor) + return i_perm ? seqlen_k * hdim_v : hdim_v; + else + return i_perm ? hdim_v * seqlen_k : seqlen_k; + }(); + const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * seqlen_q * seqlen_k : 0 * seqlen_k); + const ck_tile::index_t nhead_stride_lse = (seqlen_q * 1); + const ck_tile::index_t nhead_stride_o = (o_perm ? seqlen_q * hdim_v : hdim_v); + // setup batch_stride_* arguments + const ck_tile::index_t batch_stride_q = (nhead * seqlen_q * hdim_q); + const ck_tile::index_t batch_stride_k = (nhead_k * seqlen_k * hdim_q); + const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * seqlen_k); + const ck_tile::index_t batch_stride_bias = (0 * nhead * seqlen_q * seqlen_k); + const ck_tile::index_t batch_stride_lse = (nhead * seqlen_q * 1); + const ck_tile::index_t batch_stride_o = (nhead * seqlen_q * hdim_v); + + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaKernel::kIsGroupMode) + { + return FmhaKernel::MakeKargs(q_ptr, + k_ptr, + v_ptr, + bias_ptr, + lse_ptr, + o_ptr, + seqstart_q_ptr, + seqstart_k_ptr, + seqlen_k_ptr, + hdim_q, + hdim_v, + nhead / nhead_k, + scale, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_lse, + nhead_stride_o, + mask_y, + mask_x, + descale_qk, + descale_sv); + } + else + { // create batch mode kernel arguments + return FmhaKernel::MakeKargs(q_ptr, + k_ptr, + v_ptr, + bias_ptr, + lse_ptr, + o_ptr, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + nhead / nhead_k, + scale, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_lse, + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_bias, + batch_stride_lse, + batch_stride_o, + mask_y, + mask_x, + descale_qk, + descale_sv); + } + }(); + + dim3 grids = FmhaKernel::GridSize(batch, nhead, max_seqlen_q, hdim_v); + return ck_tile::make_tuple(kargs, grids); +} + +// This is the args from caller to underneath API, different from the kernel +struct fmha_fwd_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; + void* lse_ptr; + void* o_ptr; + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* seqlen_k_ptr; + ck_tile::index_t batch; + ck_tile::index_t nhead; + ck_tile::index_t nhead_k; + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t max_seqlen_q; + float scale; + float descale_qk; + float descale_sv; + bool i_perm; + bool o_perm; + ck_tile::index_t mask_y; + ck_tile::index_t mask_x; +}; + +template +auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) +{ + return fmha_fwd_create_kargs_and_grids(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.o_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.batch, + args.nhead, + args.nhead_k, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.max_seqlen_q, + args.scale, + args.descale_qk, + args.descale_sv, + args.i_perm, + args.o_perm, + args.mask_y, + args.mask_x); +} + +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct fmha_fwd_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN0 = kN0_; + static constexpr ck_tile::index_t kK0 = kK0_; + static constexpr ck_tile::index_t kN1 = kN1_; + static constexpr ck_tile::index_t kK1 = kK1_; + static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_; + static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr bool kHasBias = kHasBias_; + static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; +}; + +template +float fmha_fwd_(const stream_config&, fmha_fwd_args); + +// This is the public API, will be generated by script +struct fmha_fwd_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_group_mode; + bool is_v_rowmajor; + mask_enum mask_type; + bool has_bias; + bool has_lse; + // TODO: padding check is inside this api +}; +float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const stream_config&); diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py new file mode 100644 index 0000000000..34d7421bbd --- /dev/null +++ b/example/ck_tile/01_fmha/generate.py @@ -0,0 +1,500 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import argparse +import itertools +from pathlib import Path +from typing import List, Optional, tuple +from dataclasses import dataclass +import copy + +DTYPE_MAP = { + "fp16": "ck_tile::half_t", + "bf16": "ck_tile::bhalf_t", + "fp8" : "ck_tile::fp8_t" +} + +DTYPE_BITS = { + "fp32": 32, + "fp16": 16, + "bf16": 16, + "fp8" : 8, + "bf8" : 8 +} + +MASK_MAP = { + "no" : "FmhaMasks::NoMask", + "causal" : "FmhaMasks::CausalMask", + "generic" : "FmhaMasks::GenericMask" +} + +MODE_MAP = { + "batch" : "false", + "group" : "true" +} + +LAYOUT_MAP = { + "row" : "true", + "col" : "false" +} + +PIPELINE_MAP = { + "qr" : "ck_tile::BlockFmhaPipelineQRKSVS", + "qr_fp8" : "ck_tile::BlockFmhaPipelineQRKSVSFp8", + "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync", +} + +BOOL_MAP = { + "t" : "true", + "f" : "false" +} + +MASKS = ["no", "causal", "generic"] +DIRECTIONS = ["fwd"] +GEN_DIR = "" # in Cmake, have to generate files in same folder + +FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n +// auto generated by generate.py +#include "fmha_fwd.hpp" +""" + +FMHA_FWD_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>; +using fmha_block_warps_{F_idx} = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>; +using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; + +using fmha_shape_{F_idx} = ck_tile::TileFmhaShape; + +using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_bias}, + {F_lse}, + {F_occupancy}>; +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_{F_idx}, + {F_mode}, + fmha_mask_{F_idx}, + fmha_trait_{F_idx}>; + +using fmha_pipeline_{F_idx} = {F_pipeline}< + fmha_pipeline_problem_{F_idx}>; + +using fmha_epilogue_{F_idx} = + ck_tile::FmhaFwdEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel_{F_idx} = + ck_tile::FmhaFwdKernel, + fmha_pipeline_{F_idx}, + fmha_epilogue_{F_idx}>; + +using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + +#include + +template<> +float fmha_fwd_(const stream_config& s, fmha_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, k_{{}}, grids, blocks, 0, kargs); +}} +""" + +FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp" +FMHA_FWD_API=""" +float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const stream_config& s){{ + float r = -1; +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{ +{F_inner_dispatch} + }} +""" +MASK_CHECK_MAP = { + "no" : "t.mask_type == mask_enum::no_mask", + "causal" : "t.mask_type == mask_enum::causal_top_left || t.mask_type == mask_enum::causal_bottom_right", + "generic" : "t.mask_type == mask_enum::window_generic", +} + +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.has_bias == {F_bias}) && (t.has_lse == {F_lse}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_mask}, {F_bias}, {F_lse}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + return fmha_fwd_(s, a); + }} +""" + +@dataclass +class FmhaFwdApiTrait: + pipeline_tag : str + # sync with fmha_fwd_traits<>, to generate fallback calls + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along qk seqlen + bk0 : int # tile size along qk gemm unroll + bn1 : int # tile size along v head_dim + bk1 : int # tile size along kv gemm unroll + bk0blen : int + vlayout : str + mask : str + bias : str # true/false + lse : str # + spad : str + skpad : str + dpad : str + dvpad : str + + @property + def name(self) -> str: + return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-'+\ + f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' + + @property + def scheck(self) -> str: + if self.pipeline_tag == 'qr_async': + if self.spad == 't' : return 'true' # always support + else : return 'true' + elif self.pipeline_tag in ['qr', 'qr_fp8']: + if self.spad == 't' : return f'a.seqlen_q % {self.bm0} != 0' + else : return f'a.seqlen_q % {self.bm0} == 0' + else: assert False + + @property + def skcheck(self) -> str: + if self.skpad == 't' : return f'a.seqlen_k % {self.bn0} != 0' + else : return f'a.seqlen_k % {self.bn0} == 0' + + @property + def dcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == 't': return f'a.hdim_q % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr', 'qr_fp8']: + if self.dpad == 't': return f'a.hdim_q % {self.bk0blen} != 0' + else : return f'a.hdim_q % {self.bk0blen} == 0' + else: assert False + + @property + def dvcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr', 'qr_fp8']: + if self.dvpad == 't': return f'a.hdim_v % {self.bk0blen} != 0' + else : return f'a.hdim_v % {self.bk0blen} == 0' + else: assert False + +@dataclass +class FmhaFwdPipeline: + tag : str + + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_bias : str # true/false + F_lse : str # + F_mask : str # value from MASK_MAP + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'{self.tag}_v{self.F_vlayout[0]}' + if pn != '' : n += f'_{pn}' + if self.F_bias == 't' : n += '_bias' + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + if self.F_lse == 't' : n += '_lse' + return n + +class FmhaFwdApiPool: + def __init__(self): + self.pool = dict() + + def register_traits(self, trait : FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + if trait.hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][trait.hdim] = list() + + self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], F_mask=MASK_MAP[trait.mask], + F_mask_check=MASK_CHECK_MAP[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], + F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, + F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) + +@dataclass +class FmhaFwdTileSize: + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along qk seqlen + F_bk0 : int # tile size along qk gemm unroll + F_bn1 : int # tile size along v head_dim + F_bk1 : int # tile size along kv gemm unroll + F_bk0blen : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm : int # number of warps along q seqlen (block warps) + F_rn : int # number of warps along k seqlen(not used) + F_rk : int # number of warps along gemm-k(not used) + F_wm : int # warp size along m (warp size) + F_wn : int # warp size along n + F_wk : int # warp size along k + F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + @property + def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0blen}" +\ + f"_r{self.F_rm}x{self.F_rn}x{self.F_rk}_w{self.F_wm}x{self.F_wn}x{self.F_wk}" +\ + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + +@dataclass +class FmhaFwdKernel: + direction : str + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdTileSize + F_pipeline : FmhaFwdPipeline + + @property + def template(self) -> str: + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bn1 = self.F_tile.F_bn1, + F_bk1 = self.F_tile.F_bk1, + F_bk0blen = self.F_tile.F_bk0blen, + F_rm = self.F_tile.F_rm, + F_rn = self.F_tile.F_rn, + F_rk = self.F_tile.F_rk, + F_wm = self.F_tile.F_wm, + F_wn = self.F_tile.F_wn, + F_wk = self.F_tile.F_wk, + F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_bias = BOOL_MAP[self.F_pipeline.F_bias], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_occupancy = self.F_tile.F_occupancy , + F_mask = MASK_MAP[self.F_pipeline.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline = PIPELINE_MAP[self.F_pipeline.tag]) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" +\ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0blen=self.F_tile.F_bk0blen, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad) + +# TODO: design a more practical way to do it +# this is current supported tile size per hdim +def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[dict]: + if direction == 'fwd': + if dtype == 'fp16' or dtype == 'bf16': + return { + '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, -1), + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, -1), + } + elif dtype == 'fp8' or dtype == 'bf8': + return { + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1) + } + else: + return None + else: + return None + +def get_blobs() -> tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: + # this function will populate a list possible pipelines + pipelines = [] + if dtype in ['fp16', 'bf16']: + for mask, bias, lse in itertools.product(MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + #if hdim == 256: + if True: + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, mask)) + + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, mask)) + #else: + # pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, mask)) + # pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, mask)) + # pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, mask)) + # pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, mask)) + elif dtype in ['fp8', 'bf8']: + # no need lse kernels + for mask, bias in itertools.product(MASK_MAP.keys(), ["t", "f"]): + pipelines.append(FmhaFwdPipeline('qr_fp8', 'col', 'f', 'f', 'f', 'f', bias, 'f', mask)) + else: + assert Fasle + return pipelines + + gen = list() + api_pool = FmhaFwdApiPool() + + for direction, dtype in itertools.product(DIRECTIONS, DTYPE_MAP.keys()): + d = get_fmha_fwd_tile_dict_from_dtype(direction, dtype) + if d == None: + continue + #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): + tile = d[hdim_str] + hdim = int(hdim_str) + for pipeline in get_pipelines(dtype, hdim): + k = FmhaFwdKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_mode=mode, F_tile=tile, F_pipeline=pipeline) + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +def write_single_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: + (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) + +def write_blobs(output_dir: Optional[str]) -> None: + if output_dir is None: + output_dir = Path(__file__).parent + else: + output_dir = Path(output_dir) / GEN_DIR + + output_dir.mkdir(parents=True, exist_ok=True) + api_pool, kernels = get_blobs() + for kernel in kernels: + write_single_kernel(kernel, output_dir) + write_api(api_pool, output_dir) + +# list all the files that will be generated +def list_blobs(output_file: Optional[str]) -> None: + assert output_file is not None + file_path = Path(output_file) + with file_path.open('a') as f: + _, kernels = get_blobs() + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="generate", + description="gen api for CK fmha kernel", + ) + parser.add_argument( + "-o", + "--output_dir", + required=False, + help="write all the blobs into a directory" + ) + parser.add_argument( + "-l", + "--list_blobs", + required=False, + help="list all the kernels to a file" + ) + args = parser.parse_args() + if args.list_blobs is not None: + list_blobs(args.list_blobs) + else: + write_blobs(args.output_dir) diff --git a/example/ck_tile/01_fmha/mask.hpp b/example/ck_tile/01_fmha/mask.hpp new file mode 100644 index 0000000000..2e26fcb897 --- /dev/null +++ b/example/ck_tile/01_fmha/mask.hpp @@ -0,0 +1,108 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/fmha.hpp" + +enum class mask_enum +{ + no_mask = 0, + causal_top_left, + causal_bottom_right, + window_generic, +}; + +struct mask_info +{ + mask_enum type; + ck_tile::index_t y, x; + + void serialize(std::ostream& os) const + { + if(type == mask_enum::no_mask) + os << "n"; + else if(type == mask_enum::causal_top_left) + os << "tl"; + else if(type == mask_enum::causal_bottom_right) + os << "br"; + else + { + os << "g(" << y << "/" << x << ")"; + } + } + static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k) + { + ck_tile::index_t x_total = seqlen_k; + ck_tile::index_t y_total = seqlen_q; + mask_info tmp; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string t = str.substr(0, found_0); + std::string v = str.substr(found_0 + 1); + auto found_1 = v.find(","); + if(found_1 == std::string::npos) + { + printf("not supported value %s, %s\n", v.c_str(), str.c_str()); + assert(0); + } + tmp.type = mask_enum::window_generic; + ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str()); + ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str()); + // TODO: some validation + if(t == "t") + { + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + v0, v1, y_total, x_total, true); + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + } + else if(t == "b") + { + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + v0, v1, y_total, x_total, false); + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + } + else if(t == "g") + { + tmp.y = v0; + tmp.x = v1; + } + else + { + printf("not supported type %s, %s\n", t.c_str(), str.c_str()); + assert(0); + } + } + else + { + // should be 0, 1, 2 + tmp.type = static_cast(atoi(str.c_str())); + if(tmp.type == mask_enum::causal_top_left) + { + tmp.y = seqlen_q; + tmp.x = 1; + } + else if(tmp.type == mask_enum::causal_bottom_right) + { + tmp.y = seqlen_q; + tmp.x = seqlen_k - seqlen_q + 1; + } + } + return tmp; + } + + friend std::ostream& operator<<(std::ostream& os, const mask_info& mi); +}; + +inline std::ostream& operator<<(std::ostream& os, const mask_info& mi) +{ + mi.serialize(os); + return os; +} diff --git a/example/ck_tile/01_fmha/misc/gamc.png b/example/ck_tile/01_fmha/misc/gamc.png new file mode 100644 index 0000000000000000000000000000000000000000..2c96951f30f99ff0345706c4f3031ab2a623ce1a GIT binary patch literal 30073 zcmeFZ2UOI_w=QY|6&g@UO;$vb0u2ZV2q;L--6TPB5+o=R8w3?;6c8jQ0hQQ9$r4%y zn;;-4&}3;uk~B#|gGhSSfHTfH^Z(y-*E;XMyVg6i)~s1W*RRsvwX5p;zFkjrw3MkR zm?;h%I6$SUat(gqz~Srz2M!4xI|Tl6lpw`@;DE;!)oWMvya#7fv~P+I-P?{}aWh#u zbqu4KOqcZeq0M8_Wbp~S5~_y3tU|(t!i=nzuWR1k{-t70JnBMOWE{i58yi}M`sBXB zBNPw%9L(U=oni%jHgR9nPd}647PuY#tFjrs zD)@qX2C~m_7US8WxArRkrU)Tt?4_ezPI7vFO#RDh1c6|{7Ubwg>2Dyx*y37#VsesK z$l+;jH-d25?Bk2{)EENZW^QQuikQ$?BOT&p-8LFIfbhZ_2sxy0Xx=3D6F%!TR6=c(YwEcNzo zes6J;@!t-;_lErmx-x-6fta(dB*dl4lY3-H&}(AIY2?!1hW{-KuXru_!%;-G<6r{K zhGx$n9MtYO{XP?mT&XF-7lXMvy*{t`sY;1thPQA6eruBu>lm)*%5bk*y&$rVJ+q?_ z*Qj=yf@ew53?l1!`HwoPrLDv#kgYf8gS3bBwLzwQTaxQJiVh}FNG>(|BtUvwv zXV6q=qu||*>#rwh#0z`B6qRPOKR!f^%z4~*;Svpb znhUUxF^-Wq>oo6Msm0HWYmC#_A)ys9V>zRr+S7T2rqGp4E7)=U6N_6EeR^GzD-@~f z4Hd3&PuXt~C-KMfHj>41Wn1As@Q*eFZxfp%sDAgin^|UZ@JB^1+RB89QmLwAhYT)V zuT}`8mH0!T@VbzoOoX`Ltfkn`{6x*-TG_XCo@xp!bCYWC+Ia3@ClpV;j+bUhkI@(S zBJ|0a`tDELA8FUJWWsRkcepxGzGo9gKN*gx3yu4Xl|<&k`Hz1*|M-J&g0)Aw&@F8T zBA2>rL&c!t&sx0ZJ?MO*Qruj!$D6(vJ2rcZgRB+V+KcXE%USMfQkTmk#@=Owu)eKU zZ9gCQl+yQ@a+=t{`5Vb;4x6Lc_n#cp2*Wp2D{w65U9Vqh!VMYPdlBI1mRLtYyezpj zCgrZ{1D0i~7bVG838W%T@Q^mzg-L`Xf?!(udgtmYjz#SOC6|4BgI-L*Y`!%H zEXo4ore4L{H)@TteB_4KpRU@K|13S^_mdNE%P}*YfDC-f-ldNhU}JuYGswGuc^S8r zCPbt@<;&f+TCfrT*=zIx=Yv-PG{KGI#?5y|sDrPv8;6Ola^T*=*ZK%6`fiiSpKXl# z6;9myId7Vi^@NW8x|^>4o7bnw%W;7L(xu;q&{_k<`ax}!MFtqX=J?1L&tK;%2BR2` z{SgA|Tbr(0>7D3yGb2C*Gtpup=8%=FAhIMVm^oGGI ztc^Exr19garuB)Xk_|!zncQUo^3`6US*w$&wHc>gUW-nwoBSTr{{4!@cgk#^&taGz ze(86n2s2;*vaZ2~^ql{*Wj%M4f*l(Q{rZScMbXurmfiSFMiE1#8`b6_kG1%uOTeWZxP7m-cIPmR zAjhmMqxLc&&|sA#E*UV5Eh z^zcs#-jk+MtT2M;)oI(sO!3} zy{A@z_+ZnsNrlc#8;LNv-XVAdzO=$*F4*u<0yErw4K^TBB-|eSO1;GeDu-F+kzl0S zc@qOgPN{onb7Pig0*AgM$J&~#gwU1lt!vx!BLO?zJ3c?k=k)wOmJK<+VO{VFo|gZv zJgb7qLgb2Tz|J>m!nQPNeQj5b+WC5z?J7R6cMs^VWuS566YGl;gLls*${5^LoA8Ck ziY=?n6(=G1#Xnq_5oe4jf){KDO3++pFvUFY_CL0ja~e96@?5rgO|AHZ20v}ZgpV@E9_gr{dsIUe2Bgr#(|6q%bK(BA;gHIVVgs z&b2pl?T#=L^<)mtz@yPV5SpDAYufJShHp_1$gkY+!{1<=q0+sTf(%$1Z7!1i@olbh zggW3ysrja3`;dg(PMw4+zSet@)A>OUO~B7ZA?p$U<%j@#&%uoDIK`+Gr?*AewVU)(QD>AGXXoR9NX)Z=0D_rU_aIp(PFOVn?H;0{TK~i zs9n|?Je0@bKeN$Rnpotw*!V%v+tI1{PNoxq5oHNBec~|?WZCA&4^?^*Vq;>JFqk`DdJc8fmL)a>zA&=66|mXZ=U>ym zs*aBs$|#98)nGzJ)^aq%OyD2HAQcU;a(!W5q80R(9={r9WOKfIZF}j(f1$Jrh72o9 z?ED;daaMoAB^g5{p z>itiXv6w~g{LCbpk@Ze|WnJpv<>otf0UqalmM5R~F0D3(I_l-#QGwV(l^FX)(kM_T zb2ta$pjj3C>rPKBbq16gzZB#B&?8yi^(q?ZJ*R16riUOLEpy)M=|Y92A^>Nkw(Es`D|NN8;utYQ)=dOQG;h4L(F@tWu z%WA(`IjiPJpZrx(jK2qj4awh)Lj`QAdIVLhl<8@Wspc;-uj@6&zs0pZDgU+RTI|&! zfg!y6)^)i*?tS$!UpcIWKg#)urDL8vq1+D9dtK2Ing>=O%V|lo6IN;h8YTNt<{7eJ zAjZRB&utdSRl*q`CBP}8^Y>jY!|ay0ZHhh}s@OOPS936GjOdeKhWEY5utQ4;TnF zI0Hq0LL)5arYCCWz zw&|&9V@45~4bd6H6#-|tnkmmKgiV84CM_zZ3^Oo&3}jfNEO=wlz$PNI?Dt|KpZB`k z_!yT5gYYNOL9Cf*bBz zavlDyfN0o@VzXo=T;iw>f1@n6-)7IfKqoj6icXd}{o>gAEHiG&a+hTU)mdQ=jPlV*0GU#pQ&emXES z{h}L@Ey1WvURlRQ4`0}1L!HccXj#EaYMg0)?4#D_`wHjizO z_b6gNuf}J{6R7YOP*r{?cK~TP^}s#R772rSpkumch~%!YFfMvXbRDT*7Erx1J($KE zSW=le&yax|SccxAmg6}V3C zidht>tBUGmSlkWUfDS z>i5Sr9ir7__Z(*YzGk~J8TQ4g^=G*ASA%eQWG%^E+`EU8k`8`J3&b+$K+ZK;E~1*w z8KUFWrK~d5-BU?^sTy@Z&QA{*Ad*HEa|+C>@v~Y!XK)??$C8=+sKer79AJ!U;|lPi zBcMVCE?2c=P^#CdRVwel+Ux~xaK8js(_te?d&A+-^ro++8#Rlw3ILy5M8968)J2z@{hHIKA$EimYAH)qB4jy8OMR zUHALsJft~4aR0mLVp3=W2Hkg&%*Vy}rmx2X3Oq+@q`vsv^mlw@Esh-owQCqeiu8g; zR1tqR@+}+;GhOw^&}64uVVgTnQjR@05wYrfNlCqN?~{Q>BcZ*jU%bkEXh|r9Hxvzg zj?UgH^mVqd3~yPw<}jH)DSX$e7mLRwAAwju4v8@@wcK4fdrL%%KPba&Ve?z3?{Jcd zwSWDUa$^Bv&LwT*(4| zR(P$XHCe*7EUrj|sCHWg5+#wTxp(yXEhxpx>!~r#i189x3ar|VipI5+6q7>w-|0l6 z|B72xqO1_dx=zrVo@35mf8LLUCSLS#A%3ylpedsNXGD`H6x2|l8Dp7uL*H&$#Vk_I zXNbL~LAX#gs=vza#p7qlnL@A|WTHqhcCYMSFH|rdhG⪼0mN|Vz2CejK)`kn5K(< z`5$`>cO3=+EeEIbS3KYC8U?y?{{OyVRAic+rnT}CzQPTId^&_WX#5B43gseBnmxkH zG|G+DzVF<9Jr!U;sFKKU1%$DWp)&j@*vhy1I%i10xU{SrE!)OVo;CO z@ZdR*o@(`+2C@p?p0GbFOE-!}dLT5!ZH#VRbQ!;ZeWq3M_*LagEA@EmigGe1=0q)Z zj7h9xQ`VzBUZzC%f-0t4cWTr7N55lvN@;hY+-os5DrBFR^~41^H-#(N+10~a_;uoJ za_LZVXt}k;;)gGV%iAuq@S@KiDEE}W`Suo@a>z>=1Fw)e+YOqd+NsKgF0PZ**UL3O zn6^L38ZfIi_pw84TZyiHzL6jC@}MEH!#>p9G(~nqZo%uUb^C*php);JwY1WFM_0MF z%-`b$_aI6&2Re1Q>uqntt>8Rb*J~dr#d5c94>@Zg+b4P~7vRN?DY_r+)06hxf>Pjj)A;#-DiuEVek6nbYfpKyS>njiygMqZ?`vzhzd6 z{xq|S8to%n(~zl}UaR%-60aoOSjZjnn->#mE`z3}qDCIJb(n+^k+&SmhwQVndD_#J&Z@@a zZc&^YGjg?HX)nefDUC$#kAMwvK^?Qz((Y)*-PWE|;tQ zK_OIxYt0D5H%e9h8Rby%#M$6pDSZMmTBcT(=*ZHQ@i%81hTZnQ?58)^__FSQ?j=|#qG#G*L)5#=@MOgM)3--Y z)R^wER57e{;7r#&N_gjkYF@hP1OOVixEr&pREjwB_MCSzOX@%c{ofYujCI^u14fsN zxy&aShOEW1V!JSucIV$4w_l1+|IRkywc8FI&Gpo47FSGO=y_-?rFjk4N$K~s@n|IU zG4RG_;=df9*809}F~?Lvs_H##p7%rdX1n*B>3T5|(%l6~|4CEHqbW9>!a>8RDt@NE z$tN$;n}}b^Fr(#7zdB`>g7@ZmKmY0TmcN9p_c`y=4CqRH%q(ir#d30_*lA0!F@uCN z--IeHJdn1#$Ppr_h=r*?Vt*-!24i_)BcXAz7KJZp=v%jKwYYpq{`s`b!qKU*#j`Bv zzGG^)S4D@uI2PNUo%-Y`Ir-|Yo=64>{gy&4#&LLbC`7O`^f2(!>y)rGam+ZQo7V$h zo1e6DROuYX+O*X+ky<-fW4XC9nD+U`#mzCe)wdL(9+||Ijv^Wt?95XkmQYLvzKd^U zqb6WyOM7Q(w66lODAVndH|ka^x0>rRTV`wE4qrt}2RzM`kX`gtaF5WRgH;dX}bu)v3)=gPr>1qwJxDW*rW4kNovc zgpuOlm~W=hF>ry>TA%b9E=`oXi{I`p+Mb+}+nD~$zsOV*Ml@DK55krkl>-Ml-i`iD zzZ1j8)P&7)5Z>(ofvR-XP=$O2G!x6L)SgYO#xEXz-JIaQT1fM*iaViWI=uIS{iP2@ zR}=%Rq(g42V(2BCnm0DmHkRdn_%V!drSJ?(7Br789aWycC-pAoODfy`K-DBoW>`sB zX=?{0T?L~A&lUf$^mTR8dT)GiY=b9XC~el{Y-4v%YoVQ_pskGyV=)~iRl4XW6HO9m{0Y__ zQ;aua^=y8!W!})PeCr`@rkyeTlA!E!#0+kuXopxAutBr486@O)W^>dY{}Y!~DdfM^ zC%gF_>xK8>AtC8?#I1>e4~B1lSlPu3)ormDJ3rBQGU$y5)WTD%bj!~}4rUzzi4C=Q zD?V}NCYjIA&8cOYohrgU%yGtn-%C2a734=dzv{o4@SdG}GjKqqY_ZupaQNAp<)s<& zHN%yJ^L^=#DagE+WjjC9?6NgQi~PoD0&cBV*8XRFnHD0dueNM|MD_gjluBj5ZFKF_ zE1M%{0^TnAPyc7UJByhV84Ov{&M+Ilz3Uyo?2F_g@iDVp)8^beMprE*PnUvi5kTEA zNX!x`6p5i)9vEliF>1DS2+a66xCSNDTmi`mwH90eOx$%E>!$#( zC@0|+yb$YhlHr~q1wIaNBh7GS9L3t%>~*h64B6@e6E@qf}2U>&rIQ(tJrIm*(ywYpG7?PYgF8^A@es@ zu~R$#0XugO?KN9)MhPH?zmV=<^(ke(!kyo%$%xgeSE0lbL``%-$W?S@*krm@@E&~8 zIvf!V5MGu2h5#~Q_2A$CQY=6`lz5^h^q|Oty-kXA$$y-UU-Yh>iHaG!R3msWL;j(H z)gQ2^-wQA)eKIQ2JWKd+n>QA4;!Y)iM+NTnCpzPlJZgauBK75e6Kw_{JfKqq{;^UO zj2j6H28?@{CMn82hLh0nee}ez5JWV+kb9`&mN;Mo-M^5U0aXGUmDJkCaB_OnEYbbb zG`c%E^laeJE;PGMjXefXKa<~~*_vG#kvciP1Ir^(sIHWP8M_m(M0jQCvY`%e&i`JJ z8tw|PGW-`RR3yOBml%J8d@&)bMFR|iiyxNz-wCUNLhI6yvnpg*C*9u=@Ol10FF_V6 z>IfkYdm-bw3Q<J=(1g)$1c>$keN&R5)g*@`<~e@1xUQTT6C^3&6yED7kS z8W_a-QLFeel2T#*L5|nRJwrpqNv01oeMRMQ5)S#5M=YQuN<32>0M2J$M?1q?4nnNK ziDWlI!HWbtnX#m^3Saqpjs^5Vic_>2wFQ(!j*}R=XGo7Hy-+0*ZH^{E;m2lelJ(pU zV4>jDvKzy2dFo|NFcQJ1<}3>|0~S0(8BvF!O{ZgDAxOY|K}S>oL)FU?De0`Y07E`z z5Sf4k%n4M<7!-VH>14=vk%*-Efp}6`*QRqlg$>aSfL5El$Lb`=8gjM(^tT=+ro#s# zR!2xORs5te)L#2&T0nb_710JtX5<~)EeRNvJ;EU9H*h<9IP~mCmQ?dWnzZNm?8{$W zt!+I#!~Pdg_NWxL=EC<&%Wdttguns`0|WbJqhxR#Y~hz?nj!Io=zYeeT|WdmqmS21 zSj9S5qH!604Z?uHvknZ2A%Wd$Nx#5uGexg_FDYitcO&8ve%Q+>Ij2!5-1Pv&`gN;d z6KO60MYB%{#lBnUYwI}(!l>|ATEN*yTS@IlQCG>ZDl=f#VPO;dyN(K`TA*Za(p#N= zO*)9WhNA;cB%>KcB9{;Rzhl&0y?`+}YqPGt?VFOJnMy^&iS}d($w4dSh=~Izw9>gW z;3{O!S2G{JgK5aTSR zHy%hNA>R8EqXxu+_hr0?YU`cnrtt^hXF${fA?Wu4kh50#jQi?O01n5%`7+G*wXfi> zs|Xd$zH}6*hi9{&?E{J=DRIB{8KV3?O#b3Zahg0KXkYbz<$;wP$OPC|96vbJ=nu$? zKide|qJsOH6x3{q{7h8$`Uf~8n=EB*r!9$tv07>o9QH0rGyF}s`8@FwL@djM$ zJnL<`_6tUea$LI|4WaZ(gE{>Pm8X;_oE#y6duAiT&OF`$x0$8*6T1OU#$+e;Z6$Tm zwFx_#Zd?<6l7uhLS=O%P{sY9(qtOButJOb;ns`KA48{9q9a_5R_%ycdZ2R=SIn$K)wn9!6&G`Pc^YW zr0O)7!j39G7R5|$eet{OZ*vw5!HVPDCDYWoGdL^56HM;9omZvSLb?`poFSs9Ul?97 zJ`U;UE_bV(!|Z<-;R7CSqx*KvYiD~SZTlm($T|D#g3X|_*bTGOZp|JFH!sR>L?}zt zuSjUl#ts0CSnpQCfZLQdnDyu^cy#P@8OT(#WRSDo49p-J2?74{kHhsgm$KjxcapDiN6*c?O_LTYHpl zhS^yFZrb_W@m;=g`@bYmRRsaUrcqmV^GDB74$Ri-&?rs7YN&aE&&K4`&e~*qGNX`L zhQRBequ?%NJ3HjMy}0n9qts)teKnr$xXL!l6IwgAR`M^w1<+bZHf~u5UE~Ug^hB2P;%k;&(c&9 z;S#C1t4-}FwU>@}-5hmo-;OghqJBRMq{V;-fE>=e+#3FOU@z#9BE!GYrk=D(==}2C z%-Iih^X7WT=$6CmbR;QA+vwWa_w0zQH&amjUCbKkv=^|V=z%oVc;=|S)tKNpatPus z(4_9leMCP1Z3fsD4t%TM-xzMkKm1eh17xd!#%@2DXdR+s1K*^b;n?}nY|JlOyD?u= z&0jr6v8y-R#;oDn`WTA&t(2{7n7PeONLMQK3jq;*62zzXi6InReRR0s+mj1BV?u9N zhaB3}@r~f_%Mc8`OLbP#%mO;haNT z=;Fh2xX~q~ss>8B>ZcEJ-U?UUJ14)j0qvsDW|{ps79Y4N`so8JiN?>QrmsFGHQCIy zcVIU!0b3YWujo}GQaP{#BZaB@&i?f>^9oQIwPia$Vy~MOl-O|u!PKt<=xGQ!o&6bb zt&Nw^YyYm3Kuv*cF{r{;IVRq@_lO*sJ-Wb&~ zOLD93-3_9^#6q)>^jl2S{`1b2J-4A9wUIB9pCehiJHOa2e|=bN#+&@_zKp%RjKpYc zK-tQkc6`%R#gK`UhYZu*^%5p2`IIUZds(ANvfh|MzZO97Ll79Xx5aqrY~m$9Cm|Fv zyLf<_4=r(d$g8dCN^!E)yr2>ZocUJqKFKh-eX`n!&%bgQ5^ITwd2}5nEy6!M=4pxpMr=_a94lR$ukFNMlDY(`v^$I_dGD zw{#${5^6#tNw5z%fKMiHk;JgLNp&D)wMro9^LvZ4EF)juM0^+~x!fa?tKHRzA_}*Eb!fpI=V0gU7IgIiM2+m<&Vz$BCcR^H@l(pz7shkjbM0B# z&>UHSYk<`i?%Qm%3xz;_c|!sJPohFGPxLM6vK8^W1z9)1H%B&&tRK8XrGp z^jt}J2lFJ2bOfmJ1zo~pwtZ&9ym_~EW|p28V;0X~mC z93@kH+yU}(lu}`_4*DjKBvoKI6@uF){^iMlC&Sm8TjVSHu4e3GzgC;ZYn6E!gC z`*B-p^XV5aKY7>SBxa|E`1$HuF}?sysp8a1!)l?sn+;`$cUZKP zSBe z-d^4!PQX5MS4wo?q5}k?P36W%%)}R4t5+rvItjgNR`&>| zvWEuPD%^$M^6+PA=RWQ>jB`IKCp3H#-poWpOKz;+giFT%p(6m|+)GpgF7MP3`}H00 zHyl3JY)PWDrlXgWsxON}MnRQzz)1xoJ6U?1yJ*$-cu2CpQALLG(N|N}NfMoxZhSIw^$S;UBWM<$ zsjX-6r`lXWpeo0i6lT|{(Wo$f<(gB8v#>oVLI&U<8p>(ASlpccl+p&2Azzg3+8dZTGX)k*p7*bQ`WX#(y=fex)-)n@kog%ybus`zw$< zLtd@7U_2{s*57GOf$g{5-6{%bgoE~31V%c@zyA0iL;Mk*MsW7EF*ouyam$&tiHA#I zV}@JwSD~oRb0K13Z2V9DfHhhL^H*DA4mjPn<>{5w!{1vQo~4|Up18O5(?iU;@{3{Jtp!Z5Rsj6(qSMk#{fFb)X{AAbbcUm(=2F z+!Xr1!c5Bf?b|b7HsE2e#6b`>6JS6UPbOR=w0CYwu_$GM7+T@f+67-gO{F*ozE8Pz zx7d*VVDY&uk~FQDTOwljbQ@q4L~v+D~)HP2HB(mQjJwV`T{i$EXZEYkN_t zx!V$1^SlyN^2zR+#Wd^VUFtww4lWbXS9uIC;}Qh3YH< z>>wuMwpRg!%e+#!(4;15Ur$^2^}6HRM`OWp-C7;xhoml zqBjd-x2{mi%tKCzKF!PPp4tFcM|H2)ME`-XKo_0+O6B#mpN$Jvtp?gLFFDhKP=2O_ z-+_mSFtm-p`;UspNv6n$R;@Rxcla#z+Tp0!bho!nGbYHqb}51OIs)s3iiJr zV#Bw?)G!weqzqlR1ce?SjSz7mUD37sn(MpW6&~feELZi_hEYHMc_AX>##MTMPnXTG3FF+_~#{#&WBqP~W0^LZff3sf7M&4V;~v^#fai zd$Q2SsfgR*<+gQagj1ygreR30Xfxow6g7K zMZ4f+HpkAA$w=Cxx}SN`LID$@i#3I>|K$i%M!$qFvmHGt_B`{I;S%Dj*|&BzD>J34 zbqc*p>>HtW2vwsID*;;9Pezr4E2=dun3D!Knx@}}4EKuBO2fqYIIHnJ2a?2HEuEdw z>bZ@HiY8Q8Rc}FpzJ#>to9l@SuO}ulNPO7Gf>s%x0{e!umY!FZm)>gY{0eg zL()gAJrLk{16G6v!Q&{o(C`_RW~QTecw;i-&OQq5Po&dByfW>@CF|OH+*7@l;%LZo zH-GM_BJHVYVu?rV#`jUrohnA=j~&mEh)IZjj=;T^%yB;>zf3vh6zte&y4-0q9ONjy zOr_T4JcC(M+e+iYEsyt1?Xqub70l>^E~8_bAYs#viY^f;n%mG`2Q^YOUu^ z|Mw8a3C+#p6#IKYdZgZG6pp2?Tt05FS3olUwsAyf3-5XBsrrnFiR_^r_dnz+wWR;n zVje8aKcYu%Z_V66dr3BWh^Ys0_B%yR=na&yGKI&?FXb&GY2mKou{{n4wjpL>i5hy{ zq4Cu{9rmU820MoQnWrkDap!ej_I6|w@jv9AExX1%OF$l;zIlL{OXna=9 zu)u<1VNxLWttJO<(d~z~%14Vlj&~+*MYFOBjA2_Qx!@FPRn}$x8fB-idObh zRDi^1yU_^v;n=``2)28aPsn}TjP{N1W)(jnE6g5$*4JGdaRKG+h!jcXZS;?1|JbHjk4Nd1M9OWBk4JYGg;E3KEPRrZNC|#%$J`3Q3 z#gBmc#Vz8FM(&lyuh#IcENshu%Hl707niiQRVe}#T&t!lo+t^dMm zQ5e0)F3h%BLxU{Goz>t%F$6v!6!fq2y*n=JYXff0xas!B&I+Kd;acnWbgqXb>RpiK zcF46z(BS8GYIOujHQe;H;b_OXxtz; zcaW)uIG{ehlVfcV#Xk}!0W$CK`s^t}!M^!^Itc|X)6~Z7L^eG2g%(o1x ze5t22>IToHKub)sK%#R&IR{EmXRg&?R=yG$HX}Eof=>+aEjq248rf%iZZAB&#O9II z;tM6YDI`7nUsoKX!9)TT=ZRcBj%%*2241ke_=f;q6IB6b?S)m-ZW}f_!q$jMI4Oyd zY>nx_IqV%)+de_=;8wtymI&RjjX3@}(*hdX2D5bqwsd3U_S_wUJ0)!mm{#FNSQt@$ zKt_>qHN|L3$7gNWt5oRuLc>quOrPW+e0ACvkf8ccEN#Nq#t)F3_$X8<(HaUoVHn{K zuDg*)kNcJ}lFb?K^Z^^n3rRl=N&g@WK{K|_^Hp0g0?DG(AF$TvBuAIL57tTNL6q(! z6osR-U86I79jcpQ#%2Mv?I5`+>mOhZd;^u-g#eJdgr#*?ed`nbA*D2LzERiutnEQh z;aqNUa?<*Q?RmE!{nYGXsmEun z^x8qN958E)v4)@FhuZl9lGL)yBE%9HCO5kw?eW-K4#!!W{>B_6NOD2k58-Rh-~=vp z5D=2soAtXj$EYUmrwRc!HNB&5x-_kv127r-#*^uaLcNojmQ3HGbc`{pbMsSXE$1;a^k>A%lk>_9dU zt7BL*>!&%7lZj6~0}z>waMh*&|Ll zADz;~s5YAhG9M-jLDMyp7?54^H4_f3{o>e*Uqo$8)Gobx1f((T=hw9Cvag4Df6=^| zHoN5P!9Ph^xCn{64LR^9XUmsZBEcT9Wmq~ke{{L;z2_&U>^hT7O4fHy%@fC6wz`f#AiN_35(gxOwQ9=Q8)`m<6@6fZL!mT9&#Gk#{=;>4St`80G*8<~Vrx?3>H#O00t-C~;p z9ho-EiAOCogKpBF<8&Zdz^~pMYcjNxW-7nrc}-~=r1AcHj2+Bqx+v1A*Br`oOIo5@tLyd_?O2a%j_xP zpCfO}QRKzolH;DAO+GjAQ(c~mYU0!SY>;<`X-`N#4__3|BO=62dU&rd?b&_{3)@?3 za$9M^=0jSxviD!0+Rle`PhZqrL|uPQJGheP!jK6{`g zPC1GpTxat>8An^+6xBR$LgW9!WVXBQk+RAl)!iUUoUiB!8C}el^4F@V}s%U2>4s#sfV+4yQTFQvTtS z;GXs|InNq%CRY{hbtf=L@j}W}OYjiwgxM|MmH9WH{Hmfvwj?=2b$oqArmGabsb(MX ze@*S5KHI@y@RD9JlVSZLvx2E`>?^V3J)&vKtaQR^Klt-*SjP{Y3aO`qyC%kv>~Ri7Gh|0r|TMtSTv$cjOCi@5HEiytlmr4>W)@Exli#yiXFlvfEU5x1pYDyGW? zit-b=!#;H!tzxWWzFYTs!G}ZesPa@@Re}d@J+tnO#M!U9=Z0&X-i9q145f{nJG0p+ zA{cC)LlK&JQ}Y$^hcLcQ+Biecw2I9|otGRabZp_<#JcS@*PYR{y6-_kkK{H#n52Or zG#2d)26OoT=#XC6W#B8{*0Bn$0~PWIYL#D)Gr3ne*a%kJ6N2YCOOr{jK>aO;B^DyS zQ|~msDjeEU;3!MY+)-V>MK3eiO5`ntyJj96Df@8j{buDO&Lnw)SX$TiOu$c&-6JN= zeD6Qyv?N$0^c<^jkk{{0GjQosCzKx(*|B`_7%b!V?3IA;o$>geX|A>A$iI?Q_TOXb zH8&3t9FeS*dnzLng}X&#vyHuSQu7QkNqZf%MHGq-$ba#g^%+Sb0pvWrQ0}8DnAIzj zauB-CUE5*?I;d0{d89fF!Rreo<&!X$gx2o#%)Z&E-xNFWLhbvAA^-l~Z|&33yr&zs zLII^*{HUerd)707?~IDr=mNMVNiQ?Sg$Uw7nu{O*!rP5P7Zdm38j$*8stF+8$86Pk zrIt%XBQXFiS6{mYK|H2F8y>N-BLkZ)f;V7!ZD*0ud}FR`f`rxze~6PYKmj>4Blg>B zf)RLl;24xn=tT@zxGu^d2Q2OFTKrLmHWFHDl_21o4*^eqkl+^KVtx-540tiGK=$2T zM*WKs%q_NmeK?v>#{LWSklzD&v+(G?d(^2?b7oFzrN-Kp7s_gvrM<|@?|RbHRn0eb zkT+8DW`gsGTCaw!buNItzB2u`hA|E4s#DV33ewYefnIBsT|F~lW%8Zx{PN9D zH6sgx?4OF0REIOCD1Tt9Bp7o+y3z?BiuXBpzsauqSQDO3eh}eT?KJH`KfYWdB^Jo? z=gj=W0GpoSSi=0O>$QXISTwQ`ncSGk3uIc=?E_X%mClNDSe}I1A20K|C+W#9Pg53} z(L&d7)DWalw_~->v;_8Ng!Uf^E$hFfQ+ppy0hGgRLDpSJfC%lH2i&TamuH-y2yhDz z2{54w_CK(~!+TBDc`bQ7e|oBoltO*UWLE!Oe2YvjNyZm7TvD-FyjU3DxmZ0t9tGh9pzjTTC+r{t zz(>0__VeJ)l;rev5AiTbbRWy8J9u@j*>i>9-_M)RsIKOF0 zB}(+J$L3qu7T!LucXi1b2MJ-+QG5BXd`F2ed}!dIGb+}m%F=Z`HowZWNOkd$8Xw5} z^`r||ow*A19ww_sOhAyk%S3?W>h5V3vCrsY@S;i8J)9msVGh%JJ9;?kJb*J6v=u7N zhANt07%7t0**|$x2TnTRP>`EFgHx)!2{(U!!wc_&nRGy;Elts!K|K9yHNsuPC;J9q zjTs>GV74xa%}Xw!j{Jb$@$HR@hJB)c{z}S#YS0&JICX4smI=Pr$6C{b;vkBT!9POZ z8t_XGfPj$h6&%Q##9%j1Q&7ZE(KUe;1eG}zt&o`{d`fzdLn1J2mrET$A{tFuSs0~+QG zI8k4^93W+a|MbVbrFr6yS};!f>y}eSk}}oPQHwxo@Sdp?xa0g8Z91)(NWkqL58yXZ>Zr#6xuf5(+kq9(nVS2*9`<1m&`iid@_Y8-@Q-( zZ(y^4gU`PjBu|l+*-wRvg5I(8WMs`p{Km~UR8Pm0zJ9+zTxAQmY$P=8((TW?d+qs0 z(_IiJnq8DHWix0(N@SUN#Vt6j*~mtEz5Osg`Iw1X$5G`sZ%%I%;oXxF03Pr=uIBj! zOJefIZQoaBsT@M&{(&EdXewR)pW?1O9?G_DSC(X}>_W1nkTtuaQi?{N#=b<1E&IMq zmO_?fuZ)mHG?o~$8?x_P3{A{f%FfuDknOt$PtWsy@B6&ppWpn&?>F~7=Y3u0bzbLn z9>;my8Fc3PmBW~7cy*0#R=y}=kmlXw-y)38W`cz zcJD$ui%L1F5r_`=&3sIAqd6M?UuQ&c0c+of{H$J1!9+>%^pAh?U4TNC{2xyCU)>vW z2dG$m36NGK_=^Y79eU5OwdN3h(2a_8Sv=XkSG$1w%x(m>~*f zEu~3dEQi@s67!8&@e-}OS&vOBc5gZMJMe>2X60galHZswp9CPkb@j4@5G=B_@3SS6 z1%oCs2T+XEKN{p3jXfAaj;phD0fA;RbxX<=ORCt^Ik%u9=3mfE=k*V+3Fy$JCGw~ zw~mEiskHVdk7U}S3{_nE0N(_|Kmkwc%R=2X7d~6mb)nif!!Vfx&sU9s*Qee`M2eq1 zpr3>w&kzB7W^?<)Zn8{(&5uc5kw}f-`bMkET5IW>l`h|lma@9AO=-G$PtfowX?Tav za?Wi}P%@6@;Lf;YG`}* zX;VC7j?>-gjzEqjQ!z2kA|2!&cSHWU>rS@hQZxSz5cXgFWv`$zQI!FqSSR=TNke=> z3SKMBWi&)&LIj)Yt|?ID?&UnEsf@R_DC0cO4YY_dnz5w8v8NaFc+%{n8@&!IDiRxE4ibU+Ic z5Kd_32u%PYBx%P<+9N3g9AfEZdB@N4fjv56r7t>QBX3c&&U72onozUSxK8ITCm={l z;kjTeWNxboItN zt&4f0-4bdb>wG2vId;HlFyoIT%1mleR$bLtBMVp#9onA46DR&S%l}E^sxBB0pqg70 zO@y!A&=)PMQmOO3BK@`|u1!qmmfJE;{o+9P*LxZjC?)L-6El!#o2Yw>3%IbUx4S4Bm9VOy zTPdtxjTJaPtq(!r??=_Mz|eTZalc!$FH}Kneoq`SUyx15pnD-a5P^(9b|#yj7reHq)e@ov{F z=W=w^YT}7(j4kqJvH!Y^nOU(_)Ti^D9xylhxP-1HvT|vvHY1^sua#^ZQ4kiBpY^X^ z5PaOi7zbwD^eDYxmQ1B!X}l%vd0R^xH6tUt#$Q)BdJx$h zhTbpX-53WKo##Cgb;ZD0#Jh6UHLT`vHFTCa2}&| z(fwO6g*jh>;YkCgI4!3io+lsV;i)B6o;@L*tU&Vh6zt|wX6hh?01%-$DT-#YWyYcK zjhsOjSf_{uoYj$d{f?XIy|;|N!~sHzVsV466Gb9V(>2ig2m$ruzg95-Qc)x{OpCUj zF}zlxg;s(Eq%oGzEkEvQh#z$5qStI@p8&{A2Q{!VrzmQ{vyFP zH}NYS!Yryk_q(`m7D5InGbCmiOSvl9q@ttThae;2z+A$w9&bI611 zqttCO?UTfKxV%8CHc!L!?;G+EzlP1cYHm6fv`PNJ<14x)ans#Mfkb?Kr$6s7#^&J5 zSzH;cZhWOk1}X*wpM`?i{;HT~m~ik&9h|Owb1d_Fnss-qqo#|cO57=MA6?EXh!AZD zRsYSbCQz!{S2r~AT?#$wB>H*f(;iF}`NGXpnp?$PxuS;SrV>rSY_bEVW53~-950jf z=44h=wc6MgDd7npWrs2KP200$6q={8&*9b)IBtGrM1#|V-LcVDyGG1HRRSxWk*+um zZt7yWZq9<=K=O4wWkhbl5mMk5!{sLz8Su23+7Db67Nup;MXzGl^z^d2*J8U&XV_~r z1+u!)&Sd8Atk{>%5@!jbasmdvkk3FiF%TdlC;bwsX~AhQvZ-v89d=Gk5uJ77+o#T> zZP8cb39X=_>E}Gxgd9~^CGi^ZRoebdE4s+9$4iSeA9l7AyX6K9Qo4>y3Su3Zqnr1svS3Q^e+SN8c3*b zE*a^KU@-37;3QMNosBI#M_96(&r>&}=(gXZ;>`Z?&SZfDNf$HJQX<^7Z_(%d+SCy} z+~jJNwLsd7N;<@vcW#OOu2_Z%cX(w4bEpfXa33S)2e7j_{2~Qg)AxQ33*^~h{s})U zeRI0F*HxV{W=D9-lD75O`P;=qwsOYC$PXo>tuw0|%a_V2eeCJId_8LPG)!?_g(DL% z_j1qK)3D;MGZ%kdE!bWXdVW#w`><@*+L%WO(?FMVm@WgHb%%GMm|Cx9OMKX%EJWZwU@95C(|9(T@7hM zPox&Vk{d8mt*xbc?>bmcRSXgnNdwVuoiceP`Eip49|#h0vp;I6nw0ry&XEs>7FAjf zn6j?}F|qdA4Y`~*{Yk?YGw=)`ujfaeFn>DJ$+QnE!CYt2e}Ka$N`5PKuN6gYPKY`$&Tg+I!z1_v&Cw){7}x~HoD02gO2j@l>I8Q zRW^2mX6((}Kmp7oN#4+XrAW~JH1vAbZO3O(Zf!@Z~(}gQzmRpEkE%0Fx`_e95_BAa3ObqysgF0 z5H{E)d|poJve%cVT_Jf@`2!>6!u8Q)eX;sZKl*y8T>cnx)phyqASo;|BS0XSFD#R@ z3@N2q*CA!aH+b!D+76LJ^9ctA0oe;a&=btmn}GZz-jsnqQ9`foVgtd@iCnsko|-Kw z6KWIk4*jf{t0nhzk^T2|86u$LI!Z?zvU9ovw826EFw-2KLh_H)C?LtfnUq z(=<>H@$<{)1$vzqNz3OqQ%He(*+)R$oTth#!Qm*H`2n8+=tiqX0kw{|Vp@;Qo~fbR zbLz)vu#NZSeL2u5MD#|A}RV zX#L~Z6C5qK&_}#jIpc-2cp5B1m-Y!E^kbyW2rBkouz^w%n*RwrW=5Mfl?0Z@cb|T- zcMl*idUi4z5mVVEDbyeI<%;Soa~XIV?L=xG0mWbv*+CrQNi_mC*ciA&@J~huVbduY zyQC<}_{m0I@n*bL2J7CH$w)eJzgIItu?vOu2L zgxJvImx+mm=X|-heq6);paK`2J%9x2u?Cg&lUOQo%SnnFQ5$99S zUfc8tR%6fi%~Gk%u}}+N;b7auA9b3iR)w7QS|N9SwI1&wSia5*%IVvlBFO?4=0+lA zW2YNUFac!>E?(zRrPbA5b4xjWS#t(w?gUW2BFzuZ)vI^r@{{UIUwjRdYM+rj1*y<* zpRFFOoBi2kOM+_-->pXVnMMV}M?TiOh(L@pESWW(<|N1d4Et=?rInL!WZ!$!PzcbR zspi$5B)v}yzR+c2LncB(wdfd$;2&>rF9%o`&@6gZ)Ve!v9Hk8X%%dtv3VJKiuUEJqrZlM?Qs zW*sD+)ZQvFx1aU6xed*vy}&!<-(5M3-`@hHuc_2cWSTf#%_6duSHD*Bdb3`*zX{Nu zOGHUR+l-Y5=DJ<H26-_*CX=& z?m{)er#GZCl)XkflE`{yHflJHg6q`8l?P~?R}~vnUEQ+E{o{Pb>)O!Vt>73CkzV|~ z)t1s&m*bR*NVR>QJA2NyA(1iQNQo=qq*SM{6~1G%>gU+cx+J$bUA%8OZ7yEv%0*&R zScFfH?AYVUMk8o`Ro!(cO9E`aG#W`c$bI8I)wT=6-(4HXgaiIMZ1JfdfOtddzwcVkhH&^{_#2CD%7usUhEoZr4q`2io@qB{+ z%^{8tS6w05kQ-KEwc_T#ofZ)bS7~vkIYDQG8@rDJg{arfi%PnJQcfL*;?9tQ3oRa*25yv(+&I*!q;LXO5w9CYXOUq7bxP6s z-{JP%Ba)8mzYq~jkdb^`t@^9WuT+Vc;#@Yq~v zZ4Vcax!IEzDI78_IvjH-@h%d2a1v<+xj9L_o1_7eVKRmo8S@(S(8hB9<5#dfzT!{Z zE$pH?-fIg2F$4j?uCZXg_FF>&nr|NiFRGRK5BFCDQq6DN7L`8^20Y3KQTbN;mA>9L zUVgH=;N3)Pbu||V{_X*?QNAoY{qly?}<$g_%UY+s@@T&!_ zXxmrYaHCat0zMg1btylfF@(xQ9WC8(|dQ?Tfc7;iI{cVuhik6_t80IVr`*t%_~Im4@#6(2At?^zl-=9xv#uG z9bM#n@=7b4$I6I~AsgBo!%3y4EtoYer8^LvS!Oe|{K=_|77>2c+t7$(>1UCak1DZW zv-|b0EPQ`(xwyhDSGT^ID2jd2?QnYg>r%lxlSr}Sju@Pev&-(-1tv6L6(P3n+Z%^oOR^`HQdZ(Q~ExLXPOo^6e+Nd221r;Yf8f zZ?5dQ@noj*Ck@~sCConZ7t=0%FT`#SVYu$zkWVy*WIasD|D1)HSXy|AnN|4p>(oXQ zL1wTwqVPgQ==jpgLR3%$8RdZCrtwEFvLE83gs#yCSZr=mjCDi@F_km#>S+^x$7;u_ zn|EvK`!qe2?ukT9 zH!P)xf*yg7&HD2{UdJ;4pPSo^wA>zOdr3S!=GNn%>(L6k3Y!qx2R1&qK2w-M}FF zry1F9aSagt@K}U|8NcAYxs>iHOwNSQ+SFum*bnv^KLa<^R{FKKR;QgNlf+|Tb+@1j zi~?tx|J#y?(!T;BP4s*=0X?1X(6Tq6*aH(9`FD++Ce?Pu* uCMMxplu>>d__WlV9uVdQpWT3+L#|9QWlatdsU#6!wHtS?f4Fu(@P7cv2xV;m literal 0 HcmV?d00001 diff --git a/example/ck_tile/01_fmha/script/benchmark.sh b/example/ck_tile/01_fmha/script/benchmark.sh new file mode 100644 index 0000000000..a8f3a8202c --- /dev/null +++ b/example/ck_tile/01_fmha/script/benchmark.sh @@ -0,0 +1,21 @@ +#!/bin/sh +# TODO: run this script from CK root +BUILD=build +EXE=$BUILD/bin/example_fmha_fwd +VALID=0 + +for prec in "fp16" "bf16" ; do +for perm in 0 1 ; do +for hdim in 64 128 256 ; do + +nhead=$((2048 / $hdim)) # follow fav2 setup +$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 + +done +done +done diff --git a/example/ck_tile/01_fmha/script/smoke_test.sh b/example/ck_tile/01_fmha/script/smoke_test.sh new file mode 100644 index 0000000000..7275c9d1b6 --- /dev/null +++ b/example/ck_tile/01_fmha/script/smoke_test.sh @@ -0,0 +1,34 @@ +#!/bin/sh +# TODO: run this script from CK root +BUILD=build +EXE=$BUILD/bin/example_fmha_fwd +KNAME=1 + +export CK_WARMUP=0 +export CK_REPEAT=1 + +COMMON_ARGS='-v=1 -warmup=0 -repeat=1' +mode=0 + +for prec in "fp16" "bf16" ; do +# for mode in 1 0 ; do +for perm in 0 1 ; do +for vlayout in "r" "c" ; do +for hdim in 32 64 128 256 ; do +for lse in 0 1 ; do +for bias in 0 1 ; do + +$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=256 -s_k=512 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=g:128,32 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS + +done +done +done +done +done +done +#done diff --git a/example/ck_tile/01_fmha/utils.hpp b/example/ck_tile/01_fmha/utils.hpp new file mode 100644 index 0000000000..14347a344c --- /dev/null +++ b/example/ck_tile/01_fmha/utils.hpp @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core/container/span.hpp" + +enum class mode_enum +{ + batch = 0, + group +}; + +std::ostream& operator<<(std::ostream& stream, mode_enum mode) +{ + return stream << (mode == mode_enum::batch ? "batch" : "group"); +} + +std::vector to_seqstarts(ck_tile::span seqlens) +{ + std::vector seqstarts = {0}; + for(int32_t seqlen : seqlens) + { + seqstarts.push_back(seqstarts.back() + seqlen); + } + assert(seqstarts.size() == seqlens.size() + 1); + return seqstarts; +} + +std::vector generate_seqlens(mode_enum mode, + unsigned count, + int32_t seqlens_sum, + std::optional seed = std::nullopt) +{ + assert(0 < count); + + std::vector seqlens(count, seqlens_sum); + + if(mode == mode_enum::group && 1 < count) + { + using size_type = std::vector::size_type; + + std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}()); + std::uniform_int_distribution idx_dist(0, count - 1); + auto next_idx = std::bind(idx_dist, std::ref(random_engine)); + + std::uniform_int_distribution step_dist(1, count - 1); + auto next_step = std::bind(step_dist, std::ref(random_engine)); + + for(unsigned repeat = seqlens_sum * (count / 2); 0 < repeat; --repeat) + { + const size_type to_decrease = next_idx(); + // make sure each elements of seqlens is always greater than 0 + if(seqlens[to_decrease] == 1) + { + continue; + } + + const size_type to_increase = (to_decrease + next_step()) % count; + + --seqlens[to_decrease]; + ++seqlens[to_increase]; + } + } + + return seqlens; +} + +std::vector generate_seqstarts(mode_enum mode, + unsigned count, + int32_t seqlens_sum, + std::optional seed = std::nullopt) +{ + return to_seqstarts(generate_seqlens(mode, count, seqlens_sum, seed)); +} + +int env_get_int(const char* var_name, int default_int) +{ + char* v = getenv(var_name); + int r = default_int; + if(v) + r = atoi(v); + return r; +} diff --git a/include/ck_tile/README.md b/include/ck_tile/README.md new file mode 100644 index 0000000000..a41b24c1af --- /dev/null +++ b/include/ck_tile/README.md @@ -0,0 +1,40 @@ +# ck_tile +`ck_tile` provides a programming model with templated abstractions to enable users to implement performance-critical kernels for machine learning workloads. introduces following basic concepts to help users building your own operator + - tensor coordinate transformation, the is the core concept of layout/index transform abstraction in both compiler time and run time. + - tile-based programming model, including tile-level api and the concept of distributed tensor. + +`ck_tile` is splitted into several componenets including `core`, `host`, `ops/gemm`, `ops/fmha`... each component you only need to include a single header (e.g `#include "ck_tile/core.hpp"`, `#include "ck_tile/ops/fmha.hpp"`) then you are able to use the function/structure inside (different from old `ck`) + +**[core]** +`ck_tile/core` contains all the basic data structure and function to build the kernel, you can only include this header and build your own operators that utilizing all the basic building blocks introduced in ck. + +`core/container` + - array, store runtime variables with fixed length (tensor index, register buffer, etc...) + - tuple, same as std::tuple, hold different type of data, and one of the solution to achieve multiple buffer. + - sequence, compile time integer sequence used to build various internal structures, or to describe tile size + - other convenient structure build on top of above 3 + +`core/numeric` + - gpu data type like `fp16_t`, `bf16_t`, `fp8_t`... and the conversion between each other + - constexpr integer similiar to std::integral_constant to be used as compile time integer. + +`core/algorithm` + - coordinate transformation system, used to build tensor transform and compile time indexing. This is the core idea introduced in old `ck` to describe how a tensor is build by several basic transform primitives like `merge`/`unmerge`/`embed` etc... and how we indexing into a ND tensor that finally mapped to 1D memory offset. + +`core/tensor` + - tensor descriptor, to describe how a ND tensor + - distributed tensor, describe the storage of this tensor, and the distribution of how a collection of threads collaborately work for this tensor. + - tile level API, including `load_tile`, `store_tile`, `shuffle_tile`, `slice_tile`, etc... + +**[host]** +`ck_tile/host` contains all the host side utilities to launch a kernel, create the device buffer, and some reference implementations. This can be used to create examples (like that under ck_tile example folder) and simple executable to invoke this kernel, so if you only need `ck_tile` to build your own device library then it's OK to not include this. Based on this, it is recommended to include the specific header you needed under this folder to avoid including unwanted headers (e.g, only include `ck_tile/host/kernel_launch.hpp`), unless you are writing a host executable. + +**[ops/gemm, ops/fmha, ops/reduce...]** +our implementation of different device operators. + - warp, warp tile level operator + - block, block tile level operator + - pipeline, pipeline that can achieve a customized tile level mainloop (or epilogue). By switching different pipeline to the kernel template you can have different kind of pipeline optimizations. + - kernel, template interface for users to instantiate a particular kernel + +**[ops/epilogue]** +epilogue part of our kernel. We may extend this epilogue part to let users to build their own cutomized epilogues. diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp new file mode 100644 index 0000000000..0123163a65 --- /dev/null +++ b/include/ck_tile/core.hpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/algorithm/cluster_descriptor.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/algorithm/space_filling_curve.hpp" +#include "ck_tile/core/arch/amd_address_space.hpp" +#include "ck_tile/core/arch/amd_buffer_addressing.hpp" +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/container/map.hpp" +#include "ck_tile/core/container/meta_data_buffer.hpp" +#include "ck_tile/core/container/multi_index.hpp" +#include "ck_tile/core/container/sequence.hpp" +#include "ck_tile/core/container/span.hpp" +#include "ck_tile/core/container/statically_indexed_array.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/core/numeric/arithmetic.hpp" +#include "ck_tile/core/numeric/bfloat16.hpp" +#include "ck_tile/core/numeric/float8.hpp" +#include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/numeric/type_convert.hpp" +#include "ck_tile/core/numeric/vector_type.hpp" +#include "ck_tile/core/tensor/buffer_view.hpp" +#include "ck_tile/core/tensor/load_tile.hpp" +#include "ck_tile/core/tensor/null_tensor.hpp" +#include "ck_tile/core/tensor/null_tile_window.hpp" +#include "ck_tile/core/tensor/shuffle_tile.hpp" +#include "ck_tile/core/tensor/slice_tile.hpp" +#include "ck_tile/core/tensor/static_distributed_tensor.hpp" +#include "ck_tile/core/tensor/store_tile.hpp" +#include "ck_tile/core/tensor/sweep_tile.hpp" +#include "ck_tile/core/tensor/tensor_adaptor.hpp" +#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp" +#include "ck_tile/core/tensor/tensor_coordinate.hpp" +#include "ck_tile/core/tensor/tensor_descriptor.hpp" +#include "ck_tile/core/tensor/tensor_view.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/core/tensor/tile_distribution_encoding.hpp" +#include "ck_tile/core/tensor/tile_elementwise.hpp" +#include "ck_tile/core/tensor/tile_window.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/limits.hpp" +#include "ck_tile/core/utility/magic_div.hpp" +#include "ck_tile/core/utility/random.hpp" +#include "ck_tile/core/utility/to_sequence.hpp" +#include "ck_tile/core/utility/type_convert.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + diff --git a/include/ck_tile/core/README.md b/include/ck_tile/core/README.md new file mode 100644 index 0000000000..d2ecfabae3 --- /dev/null +++ b/include/ck_tile/core/README.md @@ -0,0 +1,18 @@ +# ck_tile/core # + +`ck_tile/core` contains every basic functions and structures to create a GPU kernel using `ck_tile`. User should only include `ck_tile/core.hpp` this single header to use all the functionality. Everything is under `ck_tile` namespace. The coding style under this folder should be similar to `std` (`snake_case` for structure/function, Camel for template types...) + +``` +algorithm/ + coordinate transform and some other reusable algorithm +arch/ + contains some basic device building block like mma, buffer addressing, etc... +container/ + contains basic container data structure, array/sequence/tuple/... +numeric/ + data type, and data type related math +tensor/ + tensor descriptors and tile level API +utility/ + other utility function for both host/device +``` diff --git a/include/ck_tile/core/algorithm/cluster_descriptor.hpp b/include/ck_tile/core/algorithm/cluster_descriptor.hpp new file mode 100644 index 0000000000..c59a7c1fa1 --- /dev/null +++ b/include/ck_tile/core/algorithm/cluster_descriptor.hpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/tensor/tensor_adaptor.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template ::type> +CK_TILE_HOST_DEVICE constexpr auto make_cluster_descriptor( + const Lengths& lengths, + ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type{}) +{ + constexpr index_t ndim_low = Lengths::size(); + + const auto reordered_lengths = container_reorder_given_new2old(lengths, order); + + const auto low_lengths = generate_tuple( + [&](auto idim_low) { return reordered_lengths[idim_low]; }, number{}); + + const auto transform = make_merge_transform(low_lengths); + + constexpr auto low_dim_old_top_ids = ArrangeOrder{}; + + constexpr auto up_dim_new_top_ids = sequence<0>{}; + + return make_single_stage_tensor_adaptor( + make_tuple(transform), make_tuple(low_dim_old_top_ids), make_tuple(up_dim_new_top_ids)); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/algorithm/coordinate_transform.hpp b/include/ck_tile/core/algorithm/coordinate_transform.hpp new file mode 100644 index 0000000000..b8efe049c1 --- /dev/null +++ b/include/ck_tile/core/algorithm/coordinate_transform.hpp @@ -0,0 +1,1664 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/multi_index.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/core/utility/magic_div.hpp" + +namespace ck_tile { + +enum struct cood_transform_enum +{ + undefined, + pass_through, + pad, + embed, + merge, + unmerge, + replicate, + xor_t, + offset, +}; + +template +struct base_transform +{ + CK_TILE_HOST_DEVICE static constexpr auto get_type_enum() + { + return cood_transform_enum::undefined; + } + + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_lower_dimension() { return NDimLow; } + + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_upper_dimension() { return NDimUp; } + + // return safe value for vector length/stride, based on compile-time known only + // variables + // MUST be static function + template + CK_TILE_HOST_DEVICE static constexpr auto + calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths&, + const LowVectorStrides&) + { + if constexpr(NDimUp > 0) + { + array up_vector_lengths = make_array_with({-1}); + array up_vector_strides = make_array_with({-1}); + + return make_tuple(up_vector_lengths, up_vector_strides); + } + else + { + return make_tuple(array{}, array{}); + } + } +}; + +template +struct pass_through : public base_transform<1, 1> +{ + static constexpr auto type_enum = cood_transform_enum::pass_through; + + using LowerIndex = multi_index<1>; + using UpperIndex = multi_index<1>; + + using UpLengths = decltype(make_tuple(LowLength{})); + + UpLengths up_lengths_; + + CK_TILE_HOST_DEVICE constexpr pass_through() = default; + + CK_TILE_HOST_DEVICE constexpr pass_through(const LowLength& low_length) + : up_lengths_{make_tuple(low_length)} + { + } + + CK_TILE_HOST_DEVICE static constexpr auto get_type_enum() + { + return cood_transform_enum::pass_through; + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE static constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& idx_up) + { + static_assert(LowIdx::size() == 1 && UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(number<0>{}) = idx_up[number<0>{}]; + } + + template + CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&) + { + static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 && + UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = number<0>{}; + + idx_diff_low[I0] = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return true; + } + + template + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */) + { + return true; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value; + } + + // MUST be static function + template + CK_TILE_HOST_DEVICE static constexpr auto + calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths, + const LowVectorStrides& low_vector_strides) + { + return make_tuple(low_vector_lengths, low_vector_strides); + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("pass_through{"); + + // + printf("up_lengths_:"); + print(up_lengths_); + + // + printf("}"); + } +}; + +template +struct pad : public base_transform<1, 1> +{ + using LowerIndex = multi_index<1>; + using UpperIndex = multi_index<1>; + + using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{} + RightPadLength{})); + + UpLengths up_lengths_; + LeftPadLength left_pad_length_; + RightPadLength right_pad_length_; + + CK_TILE_HOST_DEVICE constexpr pad() : up_lengths_{}, left_pad_length_{}, right_pad_length_{} {} + + CK_TILE_HOST_DEVICE constexpr pad(const LowLength& low_length, + const LeftPadLength& left_pad_length, + const RightPadLength& right_pad_length) + : up_lengths_{make_tuple(low_length + left_pad_length + right_pad_length)}, + left_pad_length_{left_pad_length}, + right_pad_length_{right_pad_length} + { + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::size() == 1 && UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(number<0>{}) = idx_up[number<0>{}] - left_pad_length_; + } + + template + CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&) + { + static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 && + UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = number<0>{}; + + idx_diff_low[I0] = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return SkipIsValidCheck; + } + + template + CK_TILE_HOST_DEVICE constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& idx_up) const + { + return SkipIsValidCheck || + ((idx_up[number<0>{}] >= left_pad_length_) && + (idx_up[number<0>{}] < up_lengths_[number<0>{}] - right_pad_length_)); + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::ck_tile::is_known_at_compile_time::value && + ck_tile::is_known_at_compile_time::value && + ck_tile::is_known_at_compile_time::value; + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("pad{"); + + // + printf("up_lengths_: "); + print(up_lengths_); + printf(", "); + + // + printf("left_pad_length_: "); + print(left_pad_length_); + printf(", "); + + // + printf("right_pad_length_: "); + print(right_pad_length_); + + printf("}"); + } +}; + +template +struct left_pad +{ + using LowerIndex = multi_index<1>; + using UpperIndex = multi_index<1>; + + using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{})); + + UpLengths up_lengths_; + LeftPadLength left_pad_length_; + + CK_TILE_HOST_DEVICE constexpr left_pad() = default; + + CK_TILE_HOST_DEVICE constexpr left_pad(const LowLength& low_length, + const LeftPadLength& left_pad_length) + : up_lengths_{make_tuple(low_length + left_pad_length)}, left_pad_length_{left_pad_length} + { + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::size() == 1 && UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(number<0>{}) = idx_up[number<0>{}] - left_pad_length_; + } + + template + CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&) + { + static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 && + UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = number<0>{}; + + idx_diff_low[I0] = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return SkipIsValidCheck; + } + + template + CK_TILE_HOST_DEVICE constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& idx_up) const + { + return SkipIsValidCheck || (idx_up[number<0>{}] >= left_pad_length_); + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value && + ck_tile::is_known_at_compile_time::value; + } + + // MUST be static function + template + CK_TILE_HOST_DEVICE static constexpr auto + calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths, + const LowVectorStrides& low_vector_strides) + { + // TODO: we allow pass through this vector length. If one need per-pixel check, + // should change the guaranteed vector length while creating the tensor view. + // It's up to runtime to check the padding length should be multiple of vector length + return make_tuple(low_vector_lengths, low_vector_strides); + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("left_pad{"); + + // + printf("up_lengths_: "); + print(up_lengths_); + printf(", "); + + // + printf("left_pad_length_: "); + print(left_pad_length_); + + printf("}"); + } +}; + +template +struct right_pad : public base_transform<1, 1> +{ + using LowerIndex = multi_index<1>; + using UpperIndex = multi_index<1>; + + using UpLengths = decltype(make_tuple(LowLength{} + RightPadLength{})); + + UpLengths up_lengths_; + LowLength low_length_; + RightPadLength right_pad_length_; + + CK_TILE_HOST_DEVICE constexpr right_pad() = default; + + CK_TILE_HOST_DEVICE constexpr right_pad(const LowLength& low_length, + const RightPadLength& right_pad_length) + : up_lengths_{make_tuple(low_length + right_pad_length)}, + low_length_{low_length}, + right_pad_length_{right_pad_length} + { + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE static constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& idx_up) + { + static_assert(LowIdx::size() == 1 && UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(number<0>{}) = idx_up[number<0>{}]; + } + + template + CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&) + { + static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 && + UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = number<0>{}; + + idx_diff_low[I0] = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return SkipIsValidCheck; + } + + template + CK_TILE_HOST_DEVICE constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& idx_up) const + { + return SkipIsValidCheck || (idx_up[number<0>{}] < low_length_); + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value && + ck_tile::is_known_at_compile_time::value && + ck_tile::is_known_at_compile_time::value; + } + + // MUST be static function + template + CK_TILE_HOST_DEVICE static constexpr auto + calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths, + const LowVectorStrides& low_vector_strides) + { + // TODO: we allow pass through this vector length. If one need per-pixel check, + // should change the guaranteed vector length while creating the tensor view. + // It's up to runtime to check the padding length should be multiple of vector length + return make_tuple(low_vector_lengths, low_vector_strides); + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("right_pad{"); + + // + printf("up_lengths_: "); + print(up_lengths_); + printf(", "); + + // + printf("right_pad_length_: "); + print(right_pad_length_); + + printf("}"); + } +}; + +// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] +// UpLengths and Coefficients can be either of the followings: +// 1) Tuple of index_t, which is known at run-time, or +// 2) Tuple of number, which is known at compile-time, or +// 3) Tuple of mixture of index_t and number, which is known partially at run-time and partially +// at compile-time +template ::type = false> +struct embed : public base_transform<1, UpLengths::size()> +{ + static constexpr index_t NDimUp = UpLengths::size(); + + using LowerIndex = multi_index<1>; + using UpperIndex = multi_index; + + UpLengths up_lengths_; + Coefficients coefficients_; + + CK_TILE_HOST_DEVICE constexpr embed() = default; + + CK_TILE_HOST_DEVICE constexpr embed(const UpLengths& up_lengths, + const Coefficients& coefficients) + : up_lengths_{up_lengths}, coefficients_{coefficients} + { + } + + CK_TILE_HOST_DEVICE static constexpr auto get_type_enum() { return cood_transform_enum::embed; } + + CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::size() == 1 && UpIdx::size() == NDimUp, + "wrong! inconsistent # of dimension"); + + idx_low(number<0>{}) = 0; + + static_for<0, NDimUp, 1>{}([&idx_low, &idx_up, this](auto i) { + idx_low(number<0>{}) += idx_up[i] * this->coefficients_[i]; + }); + } + + template + CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&) const + { + static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == NDimUp && + LowIdx::size() == 1 && UpIdx::size() == NDimUp, + "wrong! inconsistent # of dimension"); + + idx_diff_low(number<0>{}) = 0; + + static_for<0, NDimUp, 1>{}( + [&](auto i) { idx_diff_low(number<0>{}) += idx_diff_up[i] * coefficients_[i]; }); + + idx_low += idx_diff_low; + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return true; + } + + template + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */) + { + return true; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value && + ck_tile::is_known_at_compile_time::value; + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("embed{"); + + // + printf("up_lengths_: "); + print(up_lengths_); + printf(", "); + + // + printf("coefficients_: "); + print(coefficients_); + + printf("}"); + } +}; + +template +struct lambda_merge_generate_MagicDivision_calculate_magic_divisor +{ + template + CK_TILE_HOST_DEVICE constexpr auto operator()(number i) const + { + return magic_division::calculate_magic_numbers(LowLengths{}[i]); + } +}; + +// Implementation of "merge" transformation primitive that uses magic-number-division to do lowering +// of both multi-index and delta of multi-index +// Caution: +// 1. The magic number division implementation being used would produce correct result if the +// dividended is uint32_t and its value is with in 31-bit value range of uint32_t. +// 2. The magic number division for int32_t dividened has not been implemented, the int32_t +// dividend would be bit-wise interpreted as uint32_t and magic number division implementation for +// uint32_t is then used. +// 3. For merge primitive, upper-index is the dividend. +// 4. When upper-index is uint32_t, its value need to be within 31-bit range. +// 5. When upper-index is int32_t type (when index_t is int32_t), its value need to be +// non-negative. +template +struct merge_v2_magic_division : public base_transform +{ + static constexpr index_t NDimLow = LowLengths::size(); + + using LowerIndex = multi_index; + using UpperIndex = multi_index<1>; + + using UpLengths = + decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, number<1>{}))); + + using LowLengthsMagicDivisor = decltype(generate_tuple( + lambda_merge_generate_MagicDivision_calculate_magic_divisor{}, + number{})); + + LowLengths low_lengths_; + LowLengthsMagicDivisor low_lengths_magic_divisor_; + UpLengths up_lengths_; + + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + + CK_TILE_HOST_DEVICE constexpr merge_v2_magic_division() = default; + + CK_TILE_HOST_DEVICE constexpr merge_v2_magic_division(const LowLengths& low_lengths) + : low_lengths_{low_lengths}, + low_lengths_magic_divisor_{generate_tuple( + [&](auto i) { return magic_division::calculate_magic_numbers(low_lengths[i]); }, + number{})}, + up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, I1))} + { + static_assert(LowerIndex::size() == NDimLow, "wrong!"); + } + + CK_TILE_HOST_DEVICE static constexpr auto get_type_enum() { return cood_transform_enum::merge; } + + CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::size() == NDimLow && UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + index_t tmp = idx_up[I0]; + + static_for{}([&, this](auto i) { + index_t tmp2 = + magic_division::do_magic_division(tmp, + this->low_lengths_magic_divisor_[i][I0], + this->low_lengths_magic_divisor_[i][I1]); + idx_low(i) = tmp - tmp2 * this->low_lengths_[i]; + tmp = tmp2; + }); + + idx_low(number<0>{}) = tmp; + } + + template + CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff&, + LowIdx& idx_low, + const UpIdx& idx_up_new) const + { + static_assert(LowIdxDiff::size() == NDimLow && UpIdxDiff::size() == 1 && + LowIdx::size() == NDimLow && UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + index_t tmp = idx_up_new[number<0>{}]; + + static_for{}([&, this](auto i) { + index_t tmp2 = + magic_division::do_magic_division(tmp, + this->low_lengths_magic_divisor_[i][I0], + this->low_lengths_magic_divisor_[i][I1]); + + index_t idx_low_old = idx_low[i]; + + idx_low(i) = tmp - tmp2 * this->low_lengths_[i]; + tmp = tmp2; + + idx_diff_low(i) = idx_low[i] - idx_low_old; + }); + + idx_diff_low(number<0>{}) = tmp - idx_low(number<0>{}); + + idx_low(number<0>{}) = tmp; + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return true; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value && + ck_tile::is_known_at_compile_time::value && + ck_tile::is_known_at_compile_time::value; + } + + template + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */) + { + return true; + } + + // MUST be static function + template + CK_TILE_HOST_DEVICE static constexpr auto + calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths, + const LowVectorStrides& low_vector_strides) + { + array up_vector_lengths = make_array_with({-1}); + array up_vector_strides = make_array_with({-1}); + + up_vector_lengths[0] = low_vector_lengths[number{}]; + up_vector_strides[0] = low_vector_strides[number{}]; + + return make_tuple(up_vector_lengths, up_vector_strides); + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("merge_v2_magic_division{"); + + // + printf("low_lengths_ "); + print(low_lengths_); + printf(", "); + + // + printf("up_lengths_ "); + print(up_lengths_); + + printf("}"); + } +}; + +// Implementation of "merge" transformation primitive that uses division and mod. It is supposed to +// be used for low_lengths that are known at compile time and are power of 2, otherwise performance +// will be very bad +template +struct merge_v3_division_mod : public base_transform +{ + static constexpr index_t NDimLow = LowLengths::size(); + + using LowerIndex = multi_index; + using UpperIndex = multi_index<1>; + + using LowLengthsScan = + decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, number<1>{})); + + using UpLengths = + decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, number<1>{}))); + + LowLengths low_lengths_; + LowLengthsScan low_lengths_scan_; + UpLengths up_lengths_; + + CK_TILE_HOST_DEVICE constexpr merge_v3_division_mod() = default; + + CK_TILE_HOST_DEVICE constexpr merge_v3_division_mod(const LowLengths& low_lengths) + : low_lengths_{low_lengths}, + low_lengths_scan_{ + container_reverse_exclusive_scan(low_lengths, math::multiplies{}, number<1>{})}, + up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, number<1>{}))} + { + static_assert(LowerIndex::size() == NDimLow, "wrong!"); + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::size() == NDimLow && UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + index_t tmp = idx_up[number<0>{}]; + + // division and mod + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_low(i) = tmp / this->low_lengths_scan_[i]; + tmp %= this->low_lengths_scan_[i]; + }); + + idx_low(number{}) = tmp; + } + + template + CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff&, + LowIdx& idx_low, + const UpIdx& idx_up_new) const + { + static_assert(LowIdxDiff::size() == NDimLow && UpIdxDiff::size() == 1 && + LowIdx::size() == NDimLow && UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = number<0>{}; + constexpr auto INm1 = number{}; + + index_t tmp = idx_up_new[I0]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + const index_t tmp2 = idx_low[i]; + idx_low(i) = tmp / this->low_lengths_scan_[i]; + idx_diff_low(i) = idx_low[i] - tmp2; + tmp %= this->low_lengths_scan_[i]; + }); + + const index_t tmp2 = idx_low[INm1]; + idx_low(INm1) = tmp; + idx_diff_low(INm1) = idx_low[INm1] - tmp2; + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return true; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value && + ck_tile::is_known_at_compile_time::value && + ck_tile::is_known_at_compile_time::value; + } + + template + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */) + { + return true; + } + + // MUST be static function + template + CK_TILE_HOST_DEVICE static constexpr auto + calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths, + const LowVectorStrides& low_vector_strides) + { + array up_vector_lengths = make_array_with({-1}); + array up_vector_strides = make_array_with({-1}); + + up_vector_lengths[0] = low_vector_lengths[number{}]; + up_vector_strides[0] = low_vector_strides[number{}]; + + return make_tuple(up_vector_lengths, up_vector_strides); + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("Merge_v3_direct_division_mod{"); + + // + printf("low_lengths_ "); + print(low_lengths_); + printf(", "); + + // + printf("low_lengths_scan_ "); + print(low_lengths_scan_); + printf(", "); + + // + printf("up_lengths_ "); + print(up_lengths_); + + printf("}"); + } +}; + +template +struct unmerge : public base_transform<1, UpLengths::size()> +{ + static constexpr index_t NDimUp = UpLengths::size(); + + using LowerIndex = multi_index<1>; + using UpperIndex = multi_index; + + using UpLengthsScan = + decltype(container_reverse_exclusive_scan(UpLengths{}, math::multiplies{}, number<1>{})); + + UpLengths up_lengths_; + UpLengthsScan up_lengths_scan_; + + CK_TILE_HOST_DEVICE constexpr unmerge() = default; + + CK_TILE_HOST_DEVICE constexpr unmerge(const UpLengths& up_lengths) + : up_lengths_{up_lengths}, + up_lengths_scan_{ + container_reverse_exclusive_scan(up_lengths, math::multiplies{}, number<1>{})} + { + } + + CK_TILE_HOST_DEVICE static constexpr auto get_type_enum() + { + return cood_transform_enum::unmerge; + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& idx_up) const + { + if constexpr(!Use24BitIntegerCalculation) + { + idx_low(number<0>{}) = idx_up[number{}]; + + static_for<0, NDimUp - 1, 1>{}( + [&](auto i) { idx_low(number<0>{}) += idx_up[i] * up_lengths_scan_[i]; }); + } + else + { + idx_low(number<0>{}) = idx_up[number{}]; + + static_for<0, NDimUp - 1, 1>{}([&](auto i) { + idx_low(number<0>{}) = + (0x00ffffff & idx_low[number<0>{}]) + + (0x00ffffff & idx_up[i]) * (0x00ffffff & up_lengths_scan_[i]); + }); + } + } + + template + CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&) const + { + calculate_lower_index(idx_diff_low, idx_diff_up); + + idx_low += idx_diff_low; + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return true; + } + + template + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */) + { + return true; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value && + ck_tile::is_known_at_compile_time::value; + } + + // MUST be static function + template + CK_TILE_HOST_DEVICE static constexpr auto + calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths, + const LowVectorStrides& low_vector_strides) + { + array up_vector_lengths = make_array_with({-1}); + array up_vector_strides = make_array_with({-1}); + + constexpr auto up_length_last = UpLengths{}[number{}]; + + if constexpr(ck_tile::is_known_at_compile_time::value) + { + if(low_vector_lengths[0] != -1) + { + up_vector_lengths(NDimUp - 1) = math::gcd(low_vector_lengths[0], up_length_last); + } + } + + up_vector_strides(NDimUp - 1) = low_vector_strides[0]; + + return make_tuple(up_vector_lengths, up_vector_strides); + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("unmerge{"); + + // + printf("up_lengths_"); + print(up_lengths_); + printf(", "); + + // + printf("up_lengths_scan_"); + print(up_lengths_scan_); + + printf("}"); + } +}; + +template +struct freeze : public base_transform<1, 0> +{ + LowerIndex low_idx_; + + CK_TILE_HOST_DEVICE constexpr freeze() = default; + + CK_TILE_HOST_DEVICE constexpr freeze(const LowerIndex& low_idx) : low_idx_{low_idx} {} + + CK_TILE_HOST_DEVICE static constexpr auto get_upper_lengths() { return Tuple<>{}; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& /* idx_up */) const + { + static_assert(LowIdx::size() == 1 && UpIdx::size() == 0, + "wrong! inconsistent # of dimension"); + + idx_low(number<0>{}) = low_idx_; + } + + template + CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff& /* idx_diff_up */, + LowIdx& /* idx_low */, + const UpIdx& /* idx_up_new */) + { + idx_diff_low(number<0>{}) = 0; + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return true; + } + + template + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */) + { + return true; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value; + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("freeze{"); + + // + printf("low_idx_: "); + print(low_idx_); + + printf("}"); + } +}; + +// insert a dangling upper dimension without lower dimension +template +struct insert : public base_transform<0, 1> +{ + using UpLengths = decltype(make_tuple(UpperLength{})); + + UpLengths up_lengths_; + + CK_TILE_HOST_DEVICE constexpr insert() = default; + + CK_TILE_HOST_DEVICE constexpr insert(const UpperLength& up_length) + : up_lengths_{make_tuple(up_length)} + { + } + + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_lower_dimension() { return 0; } + + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_upper_dimension() { return 1; } + + CK_TILE_HOST_DEVICE constexpr auto get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx&, const UpIdx&) const + { + static_assert(LowIdx::size() == 0 && UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + } + + template + CK_TILE_HOST_DEVICE static void + update_lower_index(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&) + { + static_assert(LowIdxDiff::size() == 0 && UpIdxDiff::size() == 1 && LowIdx::size() == 0 && + UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + } + + CK_TILE_HOST_DEVICE static constexpr bool IsLinearTransform() { return true; } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return true; + } + + template + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */) + { + return true; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value; + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("insert{"); + + // + print(up_lengths_); + + printf("}"); + } +}; + +// replicate the original tensor and create a higher dimensional tensor +template +struct replicate : public base_transform<0, UpLengths::size()> +{ + static constexpr index_t NDimUp = UpLengths::size(); + + CK_TILE_HOST_DEVICE constexpr replicate() = default; + + CK_TILE_HOST_DEVICE constexpr replicate(const UpLengths& up_lengths) : up_lengths_{up_lengths} + { + } + + CK_TILE_HOST_DEVICE constexpr auto get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx&, const UpIdx&) const + { + static_assert(LowIdx::size() == 0 && UpIdx::size() == NDimUp, + "wrong! inconsistent # of dimension"); + } + + template + CK_TILE_HOST_DEVICE static void + update_lower_index(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&) + { + static_assert(LowIdxDiff::size() == 0 && UpIdxDiff::size() == NDimUp && + LowIdx::size() == 0 && UpIdx::size() == NDimUp, + "wrong! inconsistent # of dimension"); + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return true; + } + + template + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */) + { + return true; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value; + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("replicate{"); + + // + printf("up_lengths_: "); + print(up_lengths_); + + printf("}"); + } + + // + UpLengths up_lengths_; +}; + +template +struct slice : public base_transform<1, 1> +{ + using LowerIndex = multi_index<1>; + using UpperIndex = multi_index<1>; + + using UpLengths = decltype(make_tuple(SliceEnd{} - SliceBegin{})); + + UpLengths up_lengths_; + SliceBegin slice_begin_; + SliceEnd slice_end_; + + CK_TILE_HOST_DEVICE constexpr slice() = default; + + CK_TILE_HOST_DEVICE constexpr slice(const LowLength&, + const SliceBegin& slice_begin, + const SliceEnd& slice_end) + : up_lengths_{make_tuple(slice_end - slice_begin)}, + slice_begin_{slice_begin}, + slice_end_{slice_end} + { + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::size() == 1 && UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(number<0>{}) = idx_up[number<0>{}] + slice_begin_; + } + + template + CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&) + { + static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 && + UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = number<0>{}; + + idx_diff_low[I0] = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return true; + } + + template + CK_TILE_HOST_DEVICE constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx&) const + { + return true; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value && + ck_tile::is_known_at_compile_time::value && + ck_tile::is_known_at_compile_time::value; + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("slice{"); + + // + printf("up_lengths_: "); + print(up_lengths_); + printf(", "); + + // + printf("slice_begin_: "); + print(slice_begin_); + printf(", "); + + // + printf("slice_end_: "); + print(slice_end_); + + printf("}"); + } // namespace ck +}; // namespace ck + +/* + * \brief lower_idx = upper_idx % modulus. + * TODO: Need an improved implementation since the modulo operation is expensive. + */ +template +struct modulo : public base_transform<1, 1> +{ + using LowerIndex = multi_index<1>; + using UpperIndex = multi_index<1>; + using UpLengths = decltype(make_tuple(UpLength{})); + + Modulus modulus_; + UpLengths up_lengths_; + + CK_TILE_HOST_DEVICE constexpr modulo() = default; + + CK_TILE_HOST_DEVICE constexpr modulo(const Modulus& modulus, const UpLength& up_length) + : modulus_{modulus}, up_lengths_{make_tuple(up_length)} + { + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::size() == 1 && UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(number<0>{}) = idx_up[number<0>{}] % modulus_; + } + + template + CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& up_idx) const + { + static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 && + UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = number<0>{}; + + const auto idx_low_old = idx_low; + idx_low[I0] = (up_idx[I0] + idx_diff_up[I0]) % modulus_; + idx_diff_low[I0] = idx_low - idx_low_old; + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return true; + } + + template + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */) + { + return true; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value; + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("Modulus{"); + + // + printf("up_lengths_: "); + print(up_lengths_); + + printf("}"); + } +}; + +// 2D XOR, NOTE: "xor" is a keyword +template +struct xor_t : public base_transform<2, 2> +{ + static constexpr auto type_enum = cood_transform_enum::xor_t; + + using LowerIndex = multi_index<2>; + using UpperIndex = multi_index<2>; + + using UpLengths = LowLengths; + + UpLengths up_lengths_; + RightShift right_shift_; + + CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{}, right_shift_{} {} + + CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths, + const RightShift& right_shift) + : up_lengths_{low_lengths}, right_shift_{right_shift} + { + } + + CK_TILE_HOST_DEVICE static constexpr auto get_type_enum() { return cood_transform_enum::xor_t; } + + CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::size() == 2 && UpIdx::size() == 2, + "wrong! inconsistent # of dimension"); + + idx_low(number<0>{}) = idx_up[number<0>{}]; + + const auto idx_low_1_tmp = + (idx_up[number<1>{}] - idx_up[number<0>{}] * right_shift_) % up_lengths_[number<1>{}]; + + const auto idx_low_1 = + (idx_low_1_tmp >= 0) ? idx_low_1_tmp : up_lengths_[number<1>{}] + idx_low_1_tmp; + + idx_low(number<1>{}) = idx_low_1; + } + + template + CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff&, + LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdxDiff::size() == 2 && UpIdxDiff::size() == 2 && LowIdx::size() == 2 && + UpIdx::size() == 2, + "wrong! inconsistent # of dimension"); + + const auto idx_low_old = idx_low; + + calculate_lower_index(idx_low, idx_up); + + idx_diff_low = idx_low - idx_low_old; + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return true; + } + + template + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */) + { + return true; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value && + ck_tile::is_known_at_compile_time::value; + } + + // MUST be static function + template + CK_TILE_HOST_DEVICE constexpr auto calculate_upper_dimension_safe_vector_length_strides( + const LowVectorLengths& low_vector_lengths, + const LowVectorStrides& low_vector_strides) const + { + array up_vector_lengths = low_vector_lengths; + array up_vector_strides = low_vector_strides; + + if constexpr(ck_tile::is_known_at_compile_time::value) + { + if(low_vector_lengths[1] != -1) + { + up_vector_lengths(1) = math::gcd(low_vector_lengths[1], math::abs(right_shift_)); + } + } + + return make_tuple(up_vector_lengths, up_vector_strides); + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("xor_t{"); + + // + printf("up_lengths_: "); + print(up_lengths_); + printf(", "); + + // + printf("right_shift_: "); + print(right_shift_); + + printf("}"); + } +}; + +template +struct offset : public base_transform<1, 1> +{ + using LowerIndex = multi_index<1>; + using UpperIndex = multi_index<1>; + + using UpLengths = decltype(make_tuple(LowLength{})); + + UpLengths up_lengths_; + OffsetLength offset_length_; + + CK_TILE_HOST_DEVICE constexpr offset() = default; + + CK_TILE_HOST_DEVICE constexpr offset(const LowLength& low_length, + const OffsetLength& offset_length) + : up_lengths_{make_tuple(low_length)}, offset_length_{offset_length} + { + } + + CK_TILE_HOST_DEVICE static constexpr auto get_type_enum() + { + return cood_transform_enum::offset; + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::size() == 1 && UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(number<0>{}) = idx_up[number<0>{}] + offset_length_; + } + + template + CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&) + { + static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 && + UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = number<0>{}; + + idx_diff_low[I0] = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return true; + } + + template + CK_TILE_HOST_DEVICE constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx&) const + { + return true; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value && + ck_tile::is_known_at_compile_time::value; + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("offset{"); + + // + printf("up_lengths_: "); + print(up_lengths_); + printf(", "); + + // + printf("offset_length_: "); + print(offset_length_); + + printf("}"); + } +}; + +//******************************************************************************************************* + +template +CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength& low_length) +{ + return PassThrough{low_length}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto +make_pad_transform(const LowLength& low_length, + const left_pad& left_pad, + const right_pad& right_pad, + integral_constant = integral_constant{}) +{ + return pad{low_length, left_pad, right_pad}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto make_left_pad_transform( + const LowLength& low_length, + const LeftPadLength& left_pad, + integral_constant = integral_constant{}) +{ + return left_pad{low_length, left_pad}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto make_right_pad_transform( + const LowLength& low_length, + const RightPadLength& right_pad, + integral_constant = integral_constant{}) +{ + return right_pad{low_length, right_pad}; +} + +template ::type = false> +CK_TILE_HOST_DEVICE constexpr auto make_embed_transform(const UpLengths& up_lengths, + const Coefficients& coefficients) +{ + return embed{up_lengths, coefficients}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto +make_merge_transform_v2_magic_division(const LowLengths& low_lengths) +{ + return merge_v2_magic_division{low_lengths}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto +make_merge_transform_v3_division_mod(const LowLengths& low_lengths) +{ + return merge_v3_division_mod{low_lengths}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths& low_lengths) +{ + return make_merge_transform_v2_magic_division(low_lengths); +} + +template +CK_TILE_HOST_DEVICE constexpr auto make_unmerge_transform( + const UpLengths& up_lengths, + integral_constant = integral_constant{}) +{ + return unmerge{up_lengths}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto make_freeze_transform(const LowerIndex& low_idx) +{ + return freeze{low_idx}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto make_insert_transform(const UpperIndex& up_idx) +{ + return insert{up_idx}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto make_replicate_transform(const UpLengths& up_lengths) +{ + return replicate{up_lengths}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto make_slice_transform(const LowLength& low_length, + const SliceBegin& slice_begin, + const SliceEnd& slice_end) +{ + return slice{low_length, slice_begin, slice_end}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto make_modulo_transform(const Modulus& modulus, + const UpLength& up_length) +{ + return modulo{modulus, up_length}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths, + const RightShift& right_shift) +{ + return xor_t{low_lengths, right_shift}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_length, + const OffsetLength& offset_length) +{ + return offset{low_length, offset_length}; +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/algorithm/space_filling_curve.hpp b/include/ck_tile/core/algorithm/space_filling_curve.hpp new file mode 100644 index 0000000000..d9850f9b91 --- /dev/null +++ b/include/ck_tile/core/algorithm/space_filling_curve.hpp @@ -0,0 +1,165 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/multi_index.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template // # of scalars per access in each dimension +struct space_filling_curve +{ + static constexpr index_t TensorSize = + reduce_on_sequence(TensorLengths{}, multiplies{}, number<1>{}); + static_assert(0 < TensorSize, + "space_filling_curve should be used to access a non-empty tensor"); + + static constexpr index_t nDim = TensorLengths::size(); + + using Index = multi_index; + + static constexpr index_t ScalarPerVector = + reduce_on_sequence(ScalarsPerAccess{}, multiplies{}, number<1>{}); + + static constexpr auto access_lengths = TensorLengths{} / ScalarsPerAccess{}; + static constexpr auto dim_access_order = DimAccessOrder{}; + static constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + static constexpr auto to_index_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(ordered_access_lengths)), + make_tuple(typename arithmetic_sequence_gen<0, nDim, 1>::type{}), + make_tuple(sequence<0>{})); + + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_access() + { + static_assert(TensorLengths::size() == ScalarsPerAccess::size()); + static_assert(TensorLengths{} % ScalarsPerAccess{} == + typename uniform_sequence_gen::type{}); + + return reduce_on_sequence(TensorLengths{}, multiplies{}, number<1>{}) / ScalarPerVector; + } + + template + static CK_TILE_HOST_DEVICE constexpr auto get_step_between(number, + number) + { + static_assert(AccessIdx1dHead >= 0 && AccessIdx1dHead < get_num_of_access(), + "1D index out of range"); + static_assert(AccessIdx1dTail >= 0 && AccessIdx1dTail < get_num_of_access(), + "1D index out of range"); + + constexpr auto idx_head = get_index(number{}); + constexpr auto idx_tail = get_index(number{}); + return idx_tail - idx_head; + } + + template + static CK_TILE_HOST_DEVICE constexpr auto get_forward_step(number) + { + static_assert(AccessIdx1d < get_num_of_access(), "1D index should be larger than 0"); + return get_step_between(number{}, number{}); + } + + template + static CK_TILE_HOST_DEVICE constexpr auto get_backward_step(number) + { + static_assert(AccessIdx1d > 0, "1D index should be larger than 0"); + + return get_step_between(number{}, number{}); + } + + template + static CK_TILE_HOST_DEVICE constexpr Index get_index(number) + { +#if 0 + /* + * \todo: tensor_adaptor::calculate_bottom_index does NOT return constexpr as expected. + */ + constexpr auto ordered_access_idx = to_index_adaptor.calculate_bottom_index(make_multi_index(number{})); +#else + + constexpr auto access_strides = + container_reverse_exclusive_scan(ordered_access_lengths, multiplies{}, number<1>{}); + + constexpr auto idx_1d = number{}; + // Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the + // idim-th element of multidimensional index. + // All constexpr variables have to be captured by VALUE. + constexpr auto compute_index = [ idx_1d, access_strides ](auto idim) constexpr + { + constexpr auto compute_index_impl = [ idx_1d, access_strides ](auto jdim) constexpr + { + auto res = idx_1d.value; + auto id = 0; + + static_for<0, jdim.value + 1, 1>{}([&](auto kdim) { + id = res / access_strides[kdim].value; + res -= id * access_strides[kdim].value; + }); + + return id; + }; + + constexpr auto id = compute_index_impl(idim); + return number{}; + }; + + constexpr auto ordered_access_idx = generate_tuple(compute_index, number{}); +#endif + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto idim) { + index_t tmp = ordered_access_idx[I0]; + + static_for<1, idim, 1>{}( + [&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; }); + + forward_sweep_(idim) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate multi-dim tensor index + auto idx_md = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto idim) { + ordered_idx(idim) = + !SnakeCurved || forward_sweep[idim] + ? ordered_access_idx[idim] + : ordered_access_lengths[idim] - 1 - ordered_access_idx[idim]; + }); + + return container_reorder_given_old2new(ordered_idx, dim_access_order) * + ScalarsPerAccess{}; + }(); + return idx_md; + } + + // FIXME: rename this function + template + static CK_TILE_HOST_DEVICE constexpr auto get_index_tuple_of_number(number) + { + constexpr auto idx = get_index(number{}); + + return generate_tuple([&](auto i) { return number{}; }, number{}); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/core/arch/amd_address_space.hpp b/include/ck_tile/core/arch/amd_address_space.hpp new file mode 100644 index 0000000000..19a9ded568 --- /dev/null +++ b/include/ck_tile/core/arch/amd_address_space.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +// Address Space for AMDGCN +// https://llvm.org/docs/AMDGPUUsage.html#address-space + +namespace ck_tile { + +enum struct address_space_enum +{ + generic, + global, + lds, + sgpr, + vgpr, +}; + +} // namespace ck_tile diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp new file mode 100644 index 0000000000..6d922dc973 --- /dev/null +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -0,0 +1,2050 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/core/numeric/vector_type.hpp" + +namespace ck_tile { + +template +union buffer_resource +{ + CK_TILE_DEVICE constexpr buffer_resource() : content{} {} + + // 128 bit SGPRs to supply buffer resource in buffer instructions + // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions + int32x4_t content; + statically_indexed_array address; + statically_indexed_array range; + statically_indexed_array config; +}; + +template +CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(T* p_wave, index_t element_space_size) +{ + buffer_resource wave_buffer_resource; + + // wavewise base address (64 bit) + wave_buffer_resource.address(number<0>{}) = const_cast*>(p_wave); + // wavewise range (32 bit) + wave_buffer_resource.range(number<2>{}) = element_space_size * sizeof(T); + // wavewise setting (32 bit) + wave_buffer_resource.config(number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD; + + return wave_buffer_resource.content; +} + +template +CK_TILE_DEVICE int32x4_t make_wave_buffer_resource_with_default_range(T* p_wave) +{ + buffer_resource wave_buffer_resource; + + // wavewise base address (64 bit) + wave_buffer_resource.address(number<0>{}) = const_cast*>(p_wave); + // wavewise range (32 bit) + wave_buffer_resource.range(number<2>{}) = 0xffffffff; // max possible range + // wavewise setting (32 bit) + wave_buffer_resource.config(number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD; + + return wave_buffer_resource.content; +} + +// TODO: glc/slc/... +template +struct buffer_load; +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast" +// TODO: strict aliasing rule seems fail when reinterpret_cast between vector type +// (exp_vector_type(xxx)) +template <> +struct buffer_load<16> +{ + template + CK_TILE_DEVICE void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 0) + { + static_assert(sizeof(T) == 16); + using mubuf_t = float __attribute__((ext_vector_type(4))); + asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct buffer_load<8> +{ + template + CK_TILE_DEVICE void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 0) + { + static_assert(sizeof(T) == 8); + using mubuf_t = float __attribute__((ext_vector_type(2))); + asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct buffer_load<4> +{ + template + CK_TILE_DEVICE void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 0) + { + static_assert(sizeof(T) == 4); + using mubuf_t = float; + asm volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct buffer_load<2> +{ + template + CK_TILE_DEVICE void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 0) + { + static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually + using mubuf_t = float; + asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct buffer_load<1> +{ + template + CK_TILE_DEVICE void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 0) + { + static_assert(sizeof(T) == 4); + using mubuf_t = float; + asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template +struct buffer_load_if; + +template <> +struct buffer_load_if<16> +{ + template + CK_TILE_DEVICE void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 0) + { + static_assert(sizeof(T) == 16); + auto saved_exec = __builtin_amdgcn_read_exec(); + using mubuf_t = float __attribute__((ext_vector_type(4))); + static_assert(sizeof(mubuf_t) == sizeof(T)); + asm volatile( + "v_cmpx_le_u32 exec, 1, %5\n" + "buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4\n" + "s_mov_b64 exec %6" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + } +}; + +template <> +struct buffer_load_if<8> +{ + template + CK_TILE_DEVICE void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 0) + { + static_assert(sizeof(T) == 8); + auto saved_exec = __builtin_amdgcn_read_exec(); + using mubuf_t = float __attribute__((ext_vector_type(2))); + asm volatile( + "v_cmpx_le_u32 exec, 1, %5\n" + "buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4\n" + "s_mov_b64 exec %6" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + } +}; + +template <> +struct buffer_load_if<4> +{ + template + CK_TILE_DEVICE void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 0) + { + static_assert(sizeof(T) == 4); + auto saved_exec = __builtin_amdgcn_read_exec(); + using mubuf_t = float; + asm volatile( + "v_cmpx_le_u32 exec, 1, %5\n" + "buffer_load_dword %0, %1, %2, %3 offen offset:%4\n" + "s_mov_b64 exec %6" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + } +}; + +template <> +struct buffer_load_if<2> +{ + template + CK_TILE_DEVICE void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 0) + { + static_assert(sizeof(T) == 4); + auto saved_exec = __builtin_amdgcn_read_exec(); + using mubuf_t = float; + asm volatile( + "v_cmpx_le_u32 exec, 1, %5\n" + "buffer_load_ushort %0, %1, %2, %3 offen offset:%4\n" + "s_mov_b64 exec %6" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + } +}; + +template <> +struct buffer_load_if<1> +{ + template + CK_TILE_DEVICE void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 0) + { + static_assert(sizeof(T) == 4); + auto saved_exec = __builtin_amdgcn_read_exec(); + using mubuf_t = float; + asm volatile( + "v_cmpx_le_u32 exec, 1, %5\n" + "buffer_load_ubyte %0, %1, %2, %3 offen offset:%4\n" + "s_mov_b64 exec %6" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + } +}; +#pragma clang diagnostic pop // "-Wundefined-reinterpret-cast" +template +struct buffer_store; + +template <> +struct buffer_store<16> +{ + template + CK_TILE_DEVICE void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 1) + { + static_assert(sizeof(T) == 16); + asm volatile("buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4" + : + : "v"(value), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct buffer_store<8> +{ + template + CK_TILE_DEVICE void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 1) + { + static_assert(sizeof(T) == 8); + asm volatile("buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4" + : + : "v"(value), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct buffer_store<4> +{ + template + CK_TILE_DEVICE void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 1) + { + static_assert(sizeof(T) == 4); + asm volatile("buffer_store_dword %0, %1, %2, %3 offen offset:%4" + : + : "v"(value), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct buffer_store<2> +{ + template + CK_TILE_DEVICE void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 1) + { + static_assert(sizeof(T) == 2); + asm volatile("buffer_store_short %0, %1, %2, %3 offen offset:%4" + : + : "v"(value), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct buffer_store<1> +{ + template + CK_TILE_DEVICE void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 1) + { + static_assert(sizeof(T) == 1); + asm volatile("buffer_store_byte %0, %1, %2, %3 offen offset:%4" + : + : "v"(value), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template +struct buffer_store_if; + +template <> +struct buffer_store_if<16> +{ + template + CK_TILE_DEVICE void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 1) + { + static_assert(sizeof(T) == 16); + auto save_exec = __builtin_amdgcn_read_exec(); + asm volatile("v_cmpx_le_u32 exec, 1, %5\n" + "buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4\n" + "s_mov_b64 exec %6" + : + : "v"(value), + "v"(v_offset), + "s"(res), + "s"(s_offset), + "n"(i_offset), + "v"(flag), + "s"(save_exec) + : "memory"); + } +}; + +template <> +struct buffer_store_if<8> +{ + template + CK_TILE_DEVICE void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 1) + { + static_assert(sizeof(T) == 8); + auto save_exec = __builtin_amdgcn_read_exec(); + asm volatile("v_cmpx_le_u32 exec, 1, %5\n" + "buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4\n" + "s_mov_b64 exec %6" + : + : "v"(value), + "v"(v_offset), + "s"(res), + "s"(s_offset), + "n"(i_offset), + "v"(flag), + "s"(save_exec) + : "memory"); + } +}; + +template <> +struct buffer_store_if<4> +{ + template + CK_TILE_DEVICE void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 1) + { + static_assert(sizeof(T) == 4); + auto save_exec = __builtin_amdgcn_read_exec(); + asm volatile("v_cmpx_le_u32 exec, 1, %5\n" + "buffer_store_dword %0, %1, %2, %3 offen offset:%4\n" + "s_mov_b64 exec %6" + : + : "v"(value), + "v"(v_offset), + "s"(res), + "s"(s_offset), + "n"(i_offset), + "v"(flag), + "s"(save_exec) + : "memory"); + } +}; + +template <> +struct buffer_store_if<2> +{ + template + CK_TILE_DEVICE void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 1) + { + static_assert(sizeof(T) == 2); + auto save_exec = __builtin_amdgcn_read_exec(); + asm volatile("v_cmpx_le_u32 exec, 1, %5\n" + "buffer_store_short %0, %1, %2, %3 offen offset:%4\n" + "s_mov_b64 exec %6" + : + : "v"(value), + "v"(v_offset), + "s"(res), + "s"(s_offset), + "n"(i_offset), + "v"(flag), + "s"(save_exec) + : "memory"); + } +}; + +template <> +struct buffer_store_if<1> +{ + template + CK_TILE_DEVICE void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 1) + { + static_assert(sizeof(T) == 1); + auto save_exec = __builtin_amdgcn_read_exec(); + asm volatile("v_cmpx_le_u32 exec, 1, %5\n" + "buffer_store_byte %0, %1, %2, %3 offen offset:%4\n" + "s_mov_b64 exec %6" + : + : "v"(value), + "v"(v_offset), + "s"(res), + "s"(s_offset), + "n"(i_offset), + "v"(flag), + "s"(save_exec) + : "memory"); + } +}; + +CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0) +{ + asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); +} + +// clang-format off +namespace impl{ + +// can't use "+v" since there could be potential extra move(read/write) +// use "v" can help remove such duplicated moves +// besides, fake this as "memory" operation to force later valu after this fence +// TODO: may have scratch (because this is memory?) +// need to reduce extra move inside compiler +template +CK_TILE_DEVICE void insert_dummy_dep_per_dword(static_buffer_c& b) +{ + for (auto i = 0; i < b.size(); i++) asm volatile(" " : : "v"(b.get(i)) : "memory"); +} + +template<> +CK_TILE_DEVICE void insert_dummy_dep_per_dword<2>(static_buffer_c& b) +{ + asm volatile(" " : : "v"(b.get(0)), "v"(b.get(1)) : "memory"); +} + +template<> +CK_TILE_DEVICE void insert_dummy_dep_per_dword<3>(static_buffer_c& b) +{ + asm volatile(" " : : "v"(b.get(0)), "v"(b.get(1)), "v"(b.get(2)) : "memory"); +} + +template<> +CK_TILE_DEVICE void insert_dummy_dep_per_dword<4>(static_buffer_c& b) +{ + asm volatile(" " : : "v"(b.get(0)), "v"(b.get(1)), "v"(b.get(2)), "v"(b.get(3)) : "memory"); +} + +template<> +CK_TILE_DEVICE void insert_dummy_dep_per_dword<8>(static_buffer_c& b) +{ + asm volatile(" " : : "v"(b.get(0)), "v"(b.get(1)), "v"(b.get(2)), "v"(b.get(3)), "v"(b.get(4)), "v"(b.get(5)), "v"(b.get(6)), "v"(b.get(7)) : "memory"); +} + +template<> +CK_TILE_DEVICE void insert_dummy_dep_per_dword<16>(static_buffer_c& b) +{ + asm volatile(" " : : "v"(b.get(0)), "v"(b.get(1)), "v"(b.get(2)), "v"(b.get(3)), "v"(b.get(4)), "v"(b.get(5)), "v"(b.get(6)), "v"(b.get(7)), + "v"(b.get(8)), "v"(b.get(9)), "v"(b.get(10)), "v"(b.get(11)), "v"(b.get(12)), "v"(b.get(13)), "v"(b.get(14)), "v"(b.get(15)) : "memory"); +} + +template<> +CK_TILE_DEVICE void insert_dummy_dep_per_dword<32>(static_buffer_c& b) +{ + asm volatile(" " : : "v"(b.get(0)), "v"(b.get(1)), "v"(b.get(2)), "v"(b.get(3)), "v"(b.get(4)), "v"(b.get(5)), "v"(b.get(6)), "v"(b.get(7)), + "v"(b.get(8)), "v"(b.get(9)), "v"(b.get(10)), "v"(b.get(11)), "v"(b.get(12)), "v"(b.get(13)), "v"(b.get(14)), "v"(b.get(15)), + "v"(b.get(16)), "v"(b.get(17)), "v"(b.get(18)), "v"(b.get(19)), "v"(b.get(20)), "v"(b.get(21)), "v"(b.get(22)), "v"(b.get(23)), + "v"(b.get(24)), "v"(b.get(25)), "v"(b.get(26)), "v"(b.get(27)), "v"(b.get(28)), "v"(b.get(29)), "v"(b.get(30)), "v"(b.get(31)) : "memory"); +} + +CK_TILE_DEVICE void insert_dummy_dep() {} + +template +CK_TILE_DEVICE void insert_dummy_dep(T & buffer) +{ + // TODO: indeed we expect T to be multiple of dword. subdword is always buggy + using da_type = static_buffer_c; + auto & dummy = reinterpret_cast(buffer); + insert_dummy_dep_per_dword(dummy); +} + +template +CK_TILE_DEVICE void insert_dummy_dep(Tx& bx, Ty&... by) +{ + insert_dummy_dep(bx); + insert_dummy_dep(by...); +} +} +// clang-format on +template +CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0, T&... o) +{ + asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); + impl::insert_dummy_dep(o...); +} + +CK_TILE_DEVICE void buffer_store_fence(index_t cnt = 0) +{ + asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); +} + +// buffer load i8 +CK_TILE_DEVICE int8_t +llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8"); + +CK_TILE_DEVICE int8x2_t +llvm_amdgcn_raw_buffer_load_i8x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8"); + +CK_TILE_DEVICE int8x4_t +llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8"); + +// buffer load i16 +CK_TILE_DEVICE bhalf_t +llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16"); + +CK_TILE_DEVICE bhalf2_t +llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16"); + +CK_TILE_DEVICE bhalf4_t +llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i16"); + +// buffer load i32 +CK_TILE_DEVICE int32_t +llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32"); + +CK_TILE_DEVICE int32x2_t +llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32"); + +CK_TILE_DEVICE int32x4_t +llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32"); + +// buffer load fp16 +CK_TILE_DEVICE half_t +llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16"); + +CK_TILE_DEVICE half2_t +llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16"); + +CK_TILE_DEVICE half4_t +llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16"); + +// buffer load fp32 +CK_TILE_DEVICE float +llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32"); + +CK_TILE_DEVICE float2_t +llvm_amdgcn_raw_buffer_load_fp32x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32"); + +CK_TILE_DEVICE float4_t +llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32"); + +// buffer store i8 +CK_TILE_DEVICE void +llvm_amdgcn_raw_buffer_store_i8(int8_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8"); + +CK_TILE_DEVICE void +llvm_amdgcn_raw_buffer_store_i8x2(int8x2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8"); + +CK_TILE_DEVICE void +llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8"); + +// buffer store i16 +CK_TILE_DEVICE void +llvm_amdgcn_raw_buffer_store_i16(bhalf_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16"); + +CK_TILE_DEVICE void +llvm_amdgcn_raw_buffer_store_i16x2(bhalf2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16"); + +CK_TILE_DEVICE void +llvm_amdgcn_raw_buffer_store_i16x4(bhalf4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16"); + +// buffer store i32 +CK_TILE_DEVICE void +llvm_amdgcn_raw_buffer_store_i32(int32_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32"); + +CK_TILE_DEVICE void +llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32"); + +CK_TILE_DEVICE void +llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32"); + +// buffer store fp16 +CK_TILE_DEVICE void +llvm_amdgcn_raw_buffer_store_fp16(half_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16"); + +CK_TILE_DEVICE void +llvm_amdgcn_raw_buffer_store_fp16x2(half2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16"); + +CK_TILE_DEVICE void +llvm_amdgcn_raw_buffer_store_fp16x4(half4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16"); + +// buffer store fp32 +CK_TILE_DEVICE void +llvm_amdgcn_raw_buffer_store_fp32(float vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32"); + +CK_TILE_DEVICE void +llvm_amdgcn_raw_buffer_store_fp32x2(float2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32"); + +CK_TILE_DEVICE void +llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32"); + +// buffer atomic-add fp16 +CK_TILE_DEVICE half2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2( + half2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16"); + +// buffer atomic-add i32 +CK_TILE_DEVICE int32_t llvm_amdgcn_raw_buffer_atomic_add_i32( + int32_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32"); + +// buffer atomic-add fp32 +CK_TILE_DEVICE float llvm_amdgcn_raw_buffer_atomic_add_fp32( + float vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32"); + +// buffer atomic-max fp64 +CK_TILE_DEVICE double +llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, + int32x4_t rsrc, // dst_wave_buffer_resource + int voffset, // dst_thread_addr_offset + int soffset, // dst_wave_addr_offset + int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64"); + +CK_TILE_DEVICE void async_buffer_load_dword(void* smem, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t ioffset /*max 0xFFF*/, + index_t /*flag*/ = 0) +{ + asm volatile("buffer_load_dword %1, %2, %3 offen offset:%4 lds" + : "=r"(smem) /*dummy dependency for smem*/ + : "v"(voffset), "s"(rsrc), "s"(soffset), "n"(ioffset) + : "memory"); +} + +CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0) +{ + asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); +} + +// memory coherency bit for buffer store/load instruction +// check ISA manual for each GFX target +// e.g. for +// https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf, +// page 67~68 +enum struct amd_buffer_coherence_enum +{ + coherence_default = 0, // default value + glc = 1, + slc = 2, + glc_slc = 3, +}; + +template +CK_TILE_DEVICE typename vector_type::type +amd_buffer_load_impl_raw(int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset) +{ + static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64, + "wrong! not implemented"); + + if constexpr(N == 1) + { + return llvm_amdgcn_raw_buffer_load_i8(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 2) + { + + int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + return bit_cast(tmp); + } + else if constexpr(N == 4) + { + int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + return bit_cast(tmp); + } + else if constexpr(N == 8) + { + int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + return bit_cast(tmp); + } + else if constexpr(N == 16) + { + int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + return bit_cast(tmp); + } + else if constexpr(N == 32) + { + int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + int32x4_t tmp1 = + llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(int32_t), + static_cast(coherence)); + vector_type tmp; + + tmp.AsType()(number<0>{}) = tmp0; + tmp.AsType()(number<1>{}) = tmp1; + + return bit_cast(tmp); + } + else if constexpr(N == 64) + { + int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + int32x4_t tmp1 = + llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(int32_t), + static_cast(coherence)); + int32x4_t tmp2 = + llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 8 * sizeof(int32_t), + static_cast(coherence)); + int32x4_t tmp3 = + llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 12 * sizeof(int32_t), + static_cast(coherence)); + + vector_type tmp; + + tmp.AsType()(number<0>{}) = tmp0; + tmp.AsType()(number<1>{}) = tmp1; + tmp.AsType()(number<2>{}) = tmp2; + tmp.AsType()(number<3>{}) = tmp3; + + return bit_cast(tmp); + } +} + +#ifndef BUFFER_LOAD_USE_INLINEASM +#define BUFFER_LOAD_USE_INLINEASM 0 +#endif + +template +CK_TILE_DEVICE typename vector_type::type +amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset) +{ + static_assert( + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + "wrong! not implemented"); + + if constexpr(is_same::value) // fp32 + { + if constexpr(N == 1) + { + return llvm_amdgcn_raw_buffer_load_fp32(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 2) + { + return llvm_amdgcn_raw_buffer_load_fp32x2(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 4) + { + return llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 8) + { + vector_type tmp; + + tmp.AsType()(number<0>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + tmp.AsType()(number<1>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(float), + static_cast(coherence)); + + return tmp.AsType()(number<0>{}); + } + else if constexpr(N == 16) + { + vector_type tmp; + + tmp.AsType()(number<0>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + tmp.AsType()(number<1>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(float), + static_cast(coherence)); + + tmp.AsType()(number<2>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 8 * sizeof(float), + static_cast(coherence)); + + tmp.AsType()(number<3>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 12 * sizeof(float), + static_cast(coherence)); + + return tmp.AsType()(number<0>{}); + } + } + else if constexpr(is_same::value) // fp16 + { + if constexpr(N == 1) + { + return llvm_amdgcn_raw_buffer_load_fp16(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 2) + { + return llvm_amdgcn_raw_buffer_load_fp16x2(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 4) + { + return llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 8) + { + // use fp32 load to mimic fp16 load + float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + return bit_cast(tmp); + } + } + else if constexpr(is_same::value) // bf16 + { + if constexpr(N == 1) + { + return llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 2) + { + return llvm_amdgcn_raw_buffer_load_i16x2(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 4) + { + return llvm_amdgcn_raw_buffer_load_i16x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 8) + { + int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + return bit_cast(tmp); + } + } + else // other datatype + { + using r_t = typename vector_type::type; + + auto raw_data = amd_buffer_load_impl_raw( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset); + + return bit_cast(raw_data); + } +} + +template +CK_TILE_DEVICE void amd_buffer_load_raw_impl(typename vector_type::type& dst, + int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset, + index_t flag = 0) +{ + constexpr index_t bytes = sizeof(T) * N; + static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16, + "wrong! not supported by buffer_load instruction"); + + using type = typename vector_type::type; + if constexpr(oob_conditional_check) + { + buffer_load_if{}( + dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag); + } + else + { + buffer_load{}( + dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag); + } +} + +template +CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem, + int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset, + index_t src_immediate_addr_offset = 0) +{ + static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size"); + + async_buffer_load_dword(smem, + src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + src_immediate_addr_offset); +} + +template +CK_TILE_DEVICE void +amd_buffer_store_impl_raw(const typename vector_type::type src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64, + "wrong! not implemented"); + + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_i8(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 2) + { + + llvm_amdgcn_raw_buffer_store_i16(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_store_i32(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 8) + { + llvm_amdgcn_raw_buffer_store_i32x2(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 16) + { + llvm_amdgcn_raw_buffer_store_i32x4(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 32) + { + vector_type tmp{bit_cast(src_thread_data)}; + + llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 4, + static_cast(coherence)); + } + else if constexpr(N == 64) + { + vector_type tmp{bit_cast(src_thread_data)}; + + llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 4, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 8, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 12, + static_cast(coherence)); + } +} + +template +CK_TILE_DEVICE void amd_buffer_store_impl(const typename vector_type::type src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert( + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + "wrong! not implemented"); + + if constexpr(is_same::value) // fp32 + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_fp32(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_store_fp32x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_store_fp32x4(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 8) + { + vector_type tmp{src_thread_data}; + llvm_amdgcn_raw_buffer_store_fp32x4(tmp.AsType()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + llvm_amdgcn_raw_buffer_store_fp32x4(tmp.AsType()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 4 * sizeof(float), + static_cast(coherence)); + } + } + else if constexpr(is_same::value) // fp16 + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_fp16(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 8) + { +#if 0 + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 4 * sizeof(half_t), + static_cast(coherence)); +#else + llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); +#endif + } + } + else if constexpr(is_same::value) // bf16 + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_i16(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_store_i16x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_store_i16x4(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 8) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 4 * sizeof(bhalf_t), + static_cast(coherence)); + } + } + else + { + using r_t = typename vector_type::type; + + amd_buffer_store_impl_raw(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset); + } +} + +template +CK_TILE_DEVICE void +amd_buffer_store_raw_impl(const typename vector_type::type dst_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset, + index_t is_valid_element = 1) +{ + constexpr index_t bytes = sizeof(T) * N; + static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16, + "wrong! not supported by buffer_store instruction"); + + using type = typename vector_type::type; + if constexpr(oob_conditional_check) + { + buffer_store_if{}(dst_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0, + is_valid_element); + } + else + { + buffer_store{}(dst_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } +} + +template +CK_TILE_DEVICE void +amd_buffer_atomic_add_impl(const typename vector_type::type src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert((is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4)), + "wrong! not implemented"); + + if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_atomic_add_fp32(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(float), + 0); + } + else if constexpr(N == 4) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(float), + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 2 * sizeof(float), + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 3 * sizeof(float), + 0); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_atomic_add_fp16x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 4) + { + vector_type tmp{src_thread_data}; + + static_for<0, 2, 1>{}([&](auto i) { + llvm_amdgcn_raw_buffer_atomic_add_fp16x2(tmp.AsType()[i], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + i * sizeof(half2_t), + 0); + }); + } + else if constexpr(N == 8) + { + vector_type tmp{src_thread_data}; + + static_for<0, 4, 1>{}([&](auto i) { + llvm_amdgcn_raw_buffer_atomic_add_fp16x2(tmp.AsType()[i], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + i * sizeof(half2_t), + 0); + }); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_atomic_add_i32(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t), + 0); + } + else if constexpr(N == 4) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t), + 0); + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 2 * sizeof(int32_t), + 0); + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 3 * sizeof(int32_t), + 0); + } + } +} + +template +CK_TILE_DEVICE void +amd_buffer_atomic_max_impl(const typename vector_type::type src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert((is_same::value && (N == 1 || N == 2 || N == 4)), + "wrong! not implemented"); + if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_atomic_max_fp64(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(double), + 0); + } + else if constexpr(N == 4) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(double), + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 2 * sizeof(double), + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 3 * sizeof(double), + 0); + } + } +} + +// buffer_load requires: +// 1) p_src_wave must point to global memory space +// 2) p_src_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. +// oob_conditional_check : dynamic check if out-of-bound +template +CK_TILE_DEVICE typename vector_type_maker::type::type +amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, + index_t src_thread_element_offset, + bool src_thread_element_valid, + index_t src_element_space_size) +{ + const int32x4_t src_wave_buffer_resource = + make_wave_buffer_resource(p_src_wave, src_element_space_size); + + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + + constexpr index_t vector_size = scalar_type::vector_size; + +#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK + uint32_t src_addr_shift = [&]() { + if constexpr(oob_conditional_check) + return src_thread_element_valid ? 0 : 0x80000000; + else + return 0; + }(); + return amd_buffer_load_impl( + src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); +#else + vector_t tmp = amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0); + if constexpr(oob_conditional_check) + return src_thread_element_valid ? tmp : vector_t(0); + else + return tmp; +#endif +} + +// buffer_load requires: +// 1) p_src_wave must point to global memory space +// 2) p_src_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. +template +CK_TILE_DEVICE typename vector_type_maker::type::type +amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, + index_t src_thread_element_offset, + bool src_thread_element_valid, + index_t src_element_space_size, + T customized_value) +{ + const int32x4_t src_wave_buffer_resource = + make_wave_buffer_resource(p_src_wave, src_element_space_size); + + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + + constexpr index_t vector_size = scalar_type::vector_size; + + vector_t tmp = amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0); + + if constexpr(oob_conditional_check) + return src_thread_element_valid ? tmp : vector_t(customized_value); + else + return tmp; +} + +template +CK_TILE_DEVICE void amd_buffer_load_raw(typename vector_type_maker::type::type& dst, + const T* p_src_wave, + index_t src_thread_element_offset, + index_t src_element_space_size, + index_t is_valid_element = 0) +{ + const int32x4_t src_wave_buffer_resource = + make_wave_buffer_resource(p_src_wave, src_element_space_size); + + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + + constexpr index_t vector_size = scalar_type::vector_size; + + amd_buffer_load_raw_impl( + dst, src_wave_buffer_resource, src_thread_addr_offset, 0, is_valid_element); +} + +// unfortunately async copy can not make sure invalid data is zero inside LDS +// ... unless people manually write zero to LDS at the proper address. +// so not support invalid_element check for now. +// buffer_load OOB still working. +template +CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem, + const T* p_src_wave, + index_t src_thread_element_offset, + index_t src_element_space_size) +{ + const int32x4_t src_wave_buffer_resource = + make_wave_buffer_resource(p_src_wave, src_element_space_size); + + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + + amd_async_buffer_load_impl( + smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0); +} + +// buffer_store requires: +// 1) p_dst_wave must point to global memory +// 2) p_dst_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. +template +CK_TILE_DEVICE void +amd_buffer_store(const typename vector_type_maker::type::type src_thread_data, + T* p_dst_wave, + const index_t dst_thread_element_offset, + const bool dst_thread_element_valid, + const index_t dst_element_space_size) +{ + const int32x4_t dst_wave_buffer_resource = + make_wave_buffer_resource(p_dst_wave, dst_element_space_size); + + index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + constexpr index_t vector_size = scalar_type::vector_size; + +#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK + uint32_t dst_addr_shift = [&]() { + if constexpr(oob_conditional_check) + return dst_thread_element_valid ? 0 : 0x80000000; + else + return 0; + }(); + amd_buffer_store_impl( + src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); +#else + if constexpr(oob_conditional_check) + { + if(dst_thread_element_valid) + { + amd_buffer_store_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } + } + else + { + amd_buffer_store_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } +#endif +} + +template +CK_TILE_DEVICE void +amd_buffer_store_raw(const typename vector_type_maker::type::type src_thread_data, + T* p_dst_wave, + const index_t dst_thread_element_offset, + const bool dst_thread_element_valid, + const index_t dst_element_space_size) +{ + const int32x4_t dst_wave_buffer_resource = + make_wave_buffer_resource(p_dst_wave, dst_element_space_size); + + index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + constexpr index_t vector_size = scalar_type::vector_size; + + amd_buffer_store_raw_impl( + src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + 0, + dst_thread_element_valid); +} + +// buffer_atomic_add requires: +// 1) p_dst_wave must point to global memory +// 2) p_dst_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. +template +CK_TILE_DEVICE void +amd_buffer_atomic_add(const typename vector_type_maker::type::type src_thread_data, + T* p_dst_wave, + const index_t dst_thread_element_offset, + const bool dst_thread_element_valid, + const index_t dst_element_space_size) +{ + const int32x4_t dst_wave_buffer_resource = + make_wave_buffer_resource(p_dst_wave, dst_element_space_size); + + index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + constexpr index_t vector_size = scalar_type::vector_size; + +#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK + uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; + + amd_buffer_atomic_add_impl( + src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); +#else + if(dst_thread_element_valid) + { + amd_buffer_atomic_add_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } +#endif +} + +// buffer_atomic_max requires: +// 1) p_dst_wave must point to global memory +// 2) p_dst_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. +template +CK_TILE_DEVICE void +amd_buffer_atomic_max(const typename vector_type_maker::type::type src_thread_data, + T* p_dst_wave, + const index_t dst_thread_element_offset, + const bool dst_thread_element_valid, + const index_t dst_element_space_size) +{ + const int32x4_t dst_wave_buffer_resource = + make_wave_buffer_resource(p_dst_wave, dst_element_space_size); + + index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + constexpr index_t vector_size = scalar_type::vector_size; + +#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK + uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; + + amd_buffer_atomic_max_impl( + src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); +#else + if(dst_thread_element_valid) + { + amd_buffer_atomic_max_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } +#endif +} + +// Direct loads from global to LDS. +CK_TILE_DEVICE void +llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc, + __attribute__((address_space(3))) uint32_t* lds_ptr, + index_t size, + index_t voffset, + index_t soffset, + index_t offset, + index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds"); + +template +CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr, + const index_t global_offset, + T* lds_base_ptr, + const index_t lds_offset, + const bool is_valid, + const index_t src_element_space_size) +{ + // Direct loads require that each thread reads and writes exactly a single DWORD. + constexpr auto dword_bytes = 4; + constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread; + static_assert(bytes_per_thread == dword_bytes); + + const uint32_t* global_ptr = + reinterpret_cast(reinterpret_cast(global_base_ptr)); + const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size); + const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000; + +#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM + T* lds_ptr = lds_base_ptr + lds_offset; + auto const lds_ptr_sgpr = + __builtin_amdgcn_readfirstlane((reinterpret_cast(lds_ptr))); + asm volatile("s_mov_b32 m0, %0; \n\t" + "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), + "v"(global_offset_bytes), + "s"(src_resource)); +#else + // LDS pointer must be attributed with the LDS address space. + __attribute__((address_space(3))) uint32_t* lds_ptr = + reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( + reinterpret_cast(lds_base_ptr + lds_offset)); + + llvm_amdgcn_raw_buffer_load_lds( + src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0); +#endif +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp new file mode 100644 index 0000000000..d8e89b9190 --- /dev/null +++ b/include/ck_tile/core/config.hpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#ifdef __HIPCC__ +#define CK_TILE_HOST __host__ +#define CK_TILE_DEVICE __device__ +#define CK_TILE_HOST_DEVICE __host__ __device__ +#else +#define CK_TILE_HOST inline +#define CK_TILE_DEVICE inline +#define CK_TILE_HOST_DEVICE inline +#endif + +#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD 0 +#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1 +#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2 + +#ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT +#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE +#endif + +#define CK_TILE_FLOAT_TO_FP8_STANDARD 0 +#define CK_TILE_FLOAT_TO_FP8_STOCHASTIC 1 + +#ifndef CK_TILE_FLOAT_TO_FP8_DEFAULT +#define CK_TILE_FLOAT_TO_FP8_DEFAULT CK_TILE_FLOAT_TO_FP8_STANDARD +#endif + +#ifndef STATIC_ASSERT +#ifndef NDEBUG +#define STATIC_ASSERT(...) static_assert(__VA_ARGS__) +#else +#define STATIC_ASSERT(...) +#endif +#endif // #ifndef STATIC_ASSERT + +// in the old rocm period, we have to use tuple array implementation to implement this +// so turn on the _USE_TUPLE if meet compiler error, otherwise _USE_ARRAY by default. +#define CK_TILE_STATICALLY_INDEXED_ARRAY_USE_ARRAY 0 +#define CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE 1 +#ifndef CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT +#define CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT CK_TILE_STATICALLY_INDEXED_ARRAY_USE_ARRAY +#endif + +#ifndef CK_TILE_USE_LAUNCH_BOUNDS +#define CK_TILE_USE_LAUNCH_BOUNDS 1 +#endif + +#ifndef CK_TILE_TIME_KERNEL +#define CK_TILE_TIME_KERNEL 1 +#endif + +#define CK_TILE_MAX_THREAD_PER_BLOCK 256 +#define CK_TILE_MIN_BLOCK_PER_CU 2 diff --git a/include/ck_tile/core/container/array.hpp b/include/ck_tile/core/container/array.hpp new file mode 100644 index 0000000000..67d0379afc --- /dev/null +++ b/include/ck_tile/core/container/array.hpp @@ -0,0 +1,201 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" + +namespace ck_tile { + +// use aggregate initialization for this type +// e.g. array buf {0}; => {0, 0, 0, 0}, clean +// array buf {3, 2}; => {3, 2, 2, 2} (not {3,2,0,0}) +// use make_array_with({...}) to construct an array with compatible behavior as old ck +// TODO: manually added constructor same as old ck +template +struct array +{ + using value_type = T_; + static constexpr index_t N = N_; + value_type data[N]; + CK_TILE_HOST_DEVICE constexpr array() : data{} {} + // TODO: will initialize the data[] with the last value repeatedly + // behavior different from std + CK_TILE_HOST_DEVICE constexpr array(std::initializer_list ilist) + { + constexpr index_t list_size = std::initializer_list{}.size(); + static_assert(list_size <= N, "out of bound"); + + index_t i = 0; + value_type vlast = value_type{}; + + for(const value_type& val : ilist) + { + data[i] = val; + vlast = val; + ++i; + } + for(; i < N; ++i) + { + data[i] = vlast; + } + } + CK_TILE_HOST_DEVICE static constexpr auto size() { return N; } + CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v; } + + // clang-format off + CK_TILE_HOST_DEVICE constexpr auto& get() { return data; } + CK_TILE_HOST_DEVICE constexpr const auto& get() const { return data; } + CK_TILE_HOST_DEVICE constexpr auto& get(index_t i) { return data[i]; } + CK_TILE_HOST_DEVICE constexpr const auto& get(index_t i) const { return data[i]; } + template CK_TILE_HOST_DEVICE constexpr auto& get() { return data[I]; } + template CK_TILE_HOST_DEVICE constexpr const auto& get() const { return data[I]; } + template CK_TILE_HOST_DEVICE constexpr auto& get(number) { return data[I]; } + template CK_TILE_HOST_DEVICE constexpr const auto& get(number) const { return data[I]; } + + CK_TILE_HOST_DEVICE constexpr auto& at(index_t i) { return data[i]; } + CK_TILE_HOST_DEVICE constexpr const auto& at(index_t i) const { return data[i]; } + template CK_TILE_HOST_DEVICE constexpr auto& at() { return data[I]; } + template CK_TILE_HOST_DEVICE constexpr const auto& at() const { return data[I]; } + template CK_TILE_HOST_DEVICE constexpr auto& at(number) { return data[I]; } + template CK_TILE_HOST_DEVICE constexpr const auto& at(number) const { return data[I]; } + + CK_TILE_HOST_DEVICE constexpr const value_type& operator[](index_t i) const { return data[i]; } + CK_TILE_HOST_DEVICE constexpr value_type& operator[](index_t i) { return data[i]; } + CK_TILE_HOST_DEVICE constexpr value_type& operator()(index_t i) { return data[i]; } // TODO: compatible + + template + CK_TILE_HOST_DEVICE constexpr auto operator=(const T& a) + { + static_assert(T::size() == size(), "wrong! size not the same"); + for(index_t i = 0; i < size(); ++i) + { + data[i] = a[i]; + } + return *this; + } + + // type punning (strict aliasing) member functions for read/write + // aliasing this array of type "T", "N" elements + // as array of type "Tx", sizeof(T)*N/sizeof(Tx) elements +#define AR_AS_COM_() \ + static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \ + constexpr int vx = sizeof(value_type) * N / sizeof(Tx) + + template CK_TILE_HOST_DEVICE constexpr auto& get_as() + { AR_AS_COM_(); return reinterpret_cast&>(data); } + template CK_TILE_HOST_DEVICE constexpr const auto& get_as() const + { AR_AS_COM_(); return reinterpret_cast&>(data); } + + // below index is for index *AFTER* type convert, not before + template CK_TILE_HOST_DEVICE constexpr auto& get_as(index_t i) + { AR_AS_COM_(); return reinterpret_cast&>(data).at(i); } + template CK_TILE_HOST_DEVICE constexpr const auto& get_as(index_t i) const + { AR_AS_COM_(); return reinterpret_cast&>(data).at(i); } + template CK_TILE_HOST_DEVICE constexpr auto& get_as(number) + { AR_AS_COM_(); return reinterpret_cast&>(data).at(number{}); } + template CK_TILE_HOST_DEVICE constexpr const auto& get_as(number) const + { AR_AS_COM_(); return reinterpret_cast&>(data).at(number{}); } + + template CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x) + { AR_AS_COM_(); reinterpret_cast&>(data).at(i) = x; } + template CK_TILE_HOST_DEVICE constexpr void set_as(number, const Tx & x) + { AR_AS_COM_(); reinterpret_cast&>(data).at(number{}) = x; } +#undef AR_AS_COM_ + // clang-format on +}; + +// empty Array + +template +struct array +{ + using value_type = T; + + CK_TILE_HOST_DEVICE constexpr array() {} + CK_TILE_HOST_DEVICE static constexpr index_t size() { return 0; } + CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v; }; + CK_TILE_HOST_DEVICE void print() const { printf("array{size: 0, data: []}"); } +}; + +template +CK_TILE_HOST_DEVICE constexpr auto make_array(T&& x, Ts&&... xs) +{ + using value_type = remove_cvref_t; + return array{std::forward(x), std::forward(xs)...}; +} + +// make empty array +template +CK_TILE_HOST_DEVICE constexpr auto make_array() +{ + return array{}; +} + +// compatible with old ck's initializer, make an array and fill it withe the last element from +// initializer_list +#include +template +CK_TILE_HOST_DEVICE constexpr auto make_array_with(std::initializer_list ilist) +{ + constexpr index_t list_size = std::initializer_list{}.size(); + + static_assert(list_size <= Size, "out of bound"); + + index_t i = 0; + T vlast = T{}; + array arr; + + for(const T& val : ilist) + { + arr.data[i] = val; + vlast = val; + ++i; + } + + for(; i < Size; ++i) + { + arr.data[i] = vlast; + } + + return arr; +} + +template +CK_TILE_HOST_DEVICE constexpr bool operator==(const array& a, const array& b) +{ + bool same = true; + + for(index_t i = 0; i < Size; ++i) + { + if(a[i] != b[i]) + { + same = false; + break; + } + } + + return same; +} + +template +CK_TILE_HOST_DEVICE constexpr bool operator!=(const array& a, const array& b) +{ + return !(a == b); +} + +template +CK_TILE_HOST_DEVICE constexpr auto to_array(const X& x) +{ + STATIC_ASSERT(N <= X::size(), ""); + + array arr; + + static_for<0, N, 1>{}([&x, &arr](auto i) { arr(i) = x[i]; }); + + return arr; +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/container/container_helper.hpp b/include/ck_tile/core/container/container_helper.hpp new file mode 100644 index 0000000000..88405f6fcb --- /dev/null +++ b/include/ck_tile/core/container/container_helper.hpp @@ -0,0 +1,483 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/container/map.hpp" +#include "ck_tile/core/container/sequence.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/core/utility/functional.hpp" + +namespace ck_tile { + +template +CK_TILE_HOST_DEVICE constexpr auto container_push_back(const array& a, const TData& x) +{ + array r; + static_for<0, NSize, 1>{}([&r, &a ](auto i) constexpr { r(i) = a[i]; }); + r[number{}] = x; + return r; +} + +template +CK_TILE_HOST_DEVICE constexpr auto container_push_front(const tuple& a, const T& x) +{ + return container_concat(make_tuple(x), a); +} + +template +CK_TILE_HOST_DEVICE constexpr auto container_push_back(const tuple& a, const T& x) +{ + return container_concat(a, make_tuple(x)); +} + +// reorder array +template +CK_TILE_HOST_DEVICE constexpr auto +container_reorder_given_new2old(const array& old_array, sequence /*new2old*/) +{ + static_assert(NSize == sizeof...(IRs), "wrong! size not consistent"); + static_assert(is_valid_sequence_map>{}, "wrong! invalid reorder map"); + return make_array>(old_array[IRs]...); +} + +template +CK_TILE_HOST_DEVICE constexpr auto +container_reorder_given_old2new(const array& old_array, sequence old2new) +{ + return container_reorder_given_new2old( + old_array, typename sequence_map_inverse::type{}); +} + +// reorder array +template +CK_TILE_HOST_DEVICE constexpr auto +container_reorder_given_new2old(const array& old_array, + const map& new2old) +{ + array new_array; + + for(const auto& [new_pos, old_pos] : new2old) + { + new_array(new_pos) = old_array[old_pos]; + } + + return new_array; +} + +template +CK_TILE_HOST_DEVICE constexpr auto +container_reorder_given_old2new(const array& old_array, + const map& old2new) +{ + array new_array; + + for(const auto& [old_pos, new_pos] : old2new) + { + new_array(new_pos) = old_array[old_pos]; + } + + return new_array; +} + +// reorder tuple +template +CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(const tuple& old_tuple, + sequence /*new2old*/) +{ + static_assert(sizeof...(Ts) == sizeof...(IRs), "wrong! size not consistent"); + + static_assert(is_valid_sequence_map>{}, "wrong! invalid reorder map"); + + return make_tuple(old_tuple[number{}]...); +} + +template +CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_old2new(const tuple& old_tuple, + sequence old2new) +{ + return container_reorder_given_new2old( + old_tuple, typename sequence_map_inverse::type{}); +} + +// reorder sequence +template +CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(sequence /* old_seq */, + sequence /*new2old*/) +{ + static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent"); + + static_assert(is_valid_sequence_map>{}, "wrong! invalid reorder map"); + + return sequence::at(number{})...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_old2new(sequence old_seq, + sequence /* old2new */) +{ + static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent"); + + static_assert(is_valid_sequence_map>{}, "wrong! invalid reorder map"); + + constexpr auto new2old = typename sequence_map_inverse>::type{}; + + return container_reorder_given_new2old(old_seq, new2old); +} + +#if 0 +// rocm-4.1 compiler would crash for recursive lambda +template +CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container& x, + Reduce reduce, + Init init, + number = number<0>{}, + number = number{}, + number = number<1>{}) +{ + static_assert((IEnd - IBegin) % IStep == 0, "wrong!"); + + // f is recursive function, fs is a dummy of f + // i is index, y_old is current scan, r_old is current reduction + auto f = [&](auto fs, auto i, auto r_old) { + auto r_new = reduce(x[i], r_old); + + if constexpr(i.value < IEnd - IStep) + { + // recursively call f/fs + return fs(fs, i + number{}, r_new); + } + else + { + return r_new; + } + }; + + // start recursion + return f(f, number{}, init); +} +#else +// i is index, y_old is current scan, r_old is current reduction +template +CK_TILE_HOST_DEVICE constexpr auto container_reduce_impl( + const Container& x, Reduce reduce, ROld r_old, number i, number, number) +{ + auto r_new = reduce(x[i], r_old); + + if constexpr(i.value < IEnd - IStep) + { + return container_reduce_impl( + x, reduce, r_new, i + number{}, number{}, number{}); + } + else + { + return r_new; + } +} + +// rocm-4.1 compiler would crash for recursive lambda +// container reduce with initial value +template +CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container& x, + Reduce reduce, + Init init, + number = number<0>{}, + number = number{}, + number = number<1>{}) +{ + static_assert((IEnd - IBegin) % IStep == 0, "wrong!"); + + if constexpr(IEnd > IBegin) + { + return container_reduce_impl( + x, reduce, init, number{}, number{}, number{}); + } + else + { + return init; + } +} +#endif + +template +CK_TILE_HOST_DEVICE constexpr auto +container_reverse_inclusive_scan(const array& x, Reduce f, TData init) +{ + array y; + + TData r = init; + + static_for{}([&](auto i) { + r = f(r, x[i]); + y(i) = r; + }); + + r = f(r, x[number<0>{}]); + y(number<0>{}) = r; + + return y; +} + +template +CK_TILE_HOST_DEVICE constexpr auto +container_reverse_exclusive_scan(const array& x, Reduce f, Init init) +{ +#if 0 + array y; + + TData r = init; + + static_for{}([&](auto i) { + y(i) = r; + r = f(r, x[i]); + }); + + y(number<0>{}) = r; + + return y; +#else + array y; + + TData r = init; + + for(index_t i = NSize - 1; i > 0; --i) + { + y(i) = r; + r = f(r, x[i]); + } + + y(0) = r; + + return y; +#endif +} + +template +CK_TILE_HOST_DEVICE constexpr auto +container_reverse_exclusive_scan(const sequence& seq, Reduce f, number) +{ + return reverse_exclusive_scan_sequence(seq, f, number{}); +} + +#if 0 +// rocm4.1 compiler would crash with recursive lambda +template +CK_TILE_HOST_DEVICE constexpr auto +container_reverse_exclusive_scan(const tuple& x, Reduce reduce, Init init) +{ + constexpr index_t NSize = sizeof...(Xs); + + // f is recursive function, fs is a dummy of f + // i is index, y_old is current scan, r_old is current reduction + auto f = [&](auto fs, auto i, auto y_old, auto r_old) { + auto r_new = reduce(x[i], r_old); + + auto y_new = container_push_front(y_old, r_new); + + if constexpr(i.value > 1) + { + // recursively call f/fs + return fs(fs, i - number<1>{}, y_new, r_new); + } + else + { + return y_new; + } + }; + + // start recursion + return f(f, number{}, make_tuple(init), init); +} +#else +// i is index, y_old is current scan, r_old is current reduction +template +CK_TILE_HOST_DEVICE constexpr auto container_reverse_exclusive_scan_impl( + const tuple& x, Reduce reduce, number i, YOld y_old, ROld r_old) +{ + auto r_new = reduce(x[i], r_old); + + auto y_new = container_push_front(y_old, r_new); + + if constexpr(i.value > 1) + { + // recursively call f/fs + return container_reverse_exclusive_scan_impl(x, reduce, i - number<1>{}, y_new, r_new); + } + else + { + return y_new; + } +} + +template +CK_TILE_HOST_DEVICE constexpr auto +container_reverse_exclusive_scan(const tuple& x, Reduce reduce, Init init) +{ + constexpr index_t NSize = sizeof...(Xs); + + return container_reverse_exclusive_scan_impl( + x, reduce, number{}, make_tuple(init), init); +} +#endif + +// TODO: update to like container_reverse_exclusive_scan to deal with tuple of Numebr<> +template +CK_TILE_HOST_DEVICE constexpr auto +container_reverse_inclusive_scan(const tuple& x, Reduce f, TData init) +{ + constexpr index_t NSize = sizeof...(Xs); + + tuple y; + + TData r = init; + + static_for{}([&](auto i) { + r = f(r, x[i]); + y(i) = r; + }); + + r = f(r, x[number<0>{}]); + y(number<0>{}) = r; + + return y; +} + +template +CK_TILE_HOST_DEVICE constexpr auto container_concat(const X& x, const Ys&... ys) +{ + return container_concat(x, container_concat(ys...)); +} + +template +CK_TILE_HOST_DEVICE constexpr auto container_concat(const array& ax, const array& ay) +{ + return unpack2( + [&](auto&&... zs) { return make_array(std::forward(zs)...); }, ax, ay); +} + +template +CK_TILE_HOST_DEVICE constexpr auto container_concat(const tuple& tx, const tuple& ty) +{ + return unpack2( + [&](auto&&... zs) { return make_tuple(std::forward(zs)...); }, tx, ty); +} + +template +CK_TILE_HOST_DEVICE constexpr auto container_concat(const Container& x) +{ + return x; +} + +template +CK_TILE_HOST_DEVICE constexpr auto get_container_subset(const array& arr, sequence) +{ + static_assert(N >= sizeof...(Is), "wrong! size"); + + if constexpr(sizeof...(Is) > 0) + { + return make_array(arr[Is]...); + } + else + { + return array{}; + } +} + +template +CK_TILE_HOST_DEVICE constexpr auto get_container_subset(const tuple& tup, sequence) +{ + static_assert(sizeof...(Ts) >= sizeof...(Is), "wrong! size"); + + if constexpr(sizeof...(Is) > 0) + { + return make_tuple(tup[number{}]...); + } + else + { + return tuple<>{}; + } +} + +template +CK_TILE_HOST_DEVICE constexpr void +set_container_subset(array& y, sequence picks, const array& x) +{ + static_assert(N >= sizeof...(Is), "wrong! size"); + + if constexpr(sizeof...(Is) > 0) + { + for(index_t i = 0; i < picks.size(); ++i) + { + y(picks[i]) = x[i]; + } + } +} + +template +CK_TILE_HOST_DEVICE constexpr void set_container_subset(Y& y, sequence picks, const X& x) +{ + static_assert(Y::size() >= sizeof...(Is) && X::size() == sizeof...(Is), "wrong! size"); + + if constexpr(sizeof...(Is) > 0) + { + static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; }); + } +} + +// return the index of first occurance in the sequence. +// return seq.size(), if not found +template +constexpr index_t container_find(sequence seq, index_t value) +{ + for(auto i = 0; i < seq.size(); i++) + { + if(seq[i] == value) + return i; + } + + return seq.size(); +} + +template +CK_TILE_HOST_DEVICE constexpr auto sequence_to_tuple_of_number(sequence) +{ + using Seq = sequence; + + return generate_tuple( + [&](auto i) { + constexpr index_t tmp = Seq::at(i); + return number{}; + }, + number{}); +} + +#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes) \ + [a_of_b_impl, a_size, bs_sizes] { \ + return ck_tile::generate_tuple( \ + [=](auto i) { \ + constexpr auto b_impl = a_of_b_impl[i]; \ + constexpr index_t b_size = bs_sizes[i]; \ + constexpr auto b = TO_SEQUENCE(b_impl, b_size); \ + return b; \ + }, \ + ck_tile::number{}); \ + }() + +} // namespace ck_tile diff --git a/include/ck_tile/core/container/map.hpp b/include/ck_tile/core/container/map.hpp new file mode 100644 index 0000000000..25e065c3c1 --- /dev/null +++ b/include/ck_tile/core/container/map.hpp @@ -0,0 +1,164 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/container/sequence.hpp" +#include "ck_tile/core/container/tuple.hpp" + +namespace ck_tile { + +// naive map +template +struct map +{ + using pair_type = tuple; + using impl_type = array; + + impl_type impl_; + index_t size_; + + struct iterator + { + impl_type& impl_; + index_t pos_; + + CK_TILE_HOST_DEVICE constexpr iterator(impl_type& impl, index_t pos) + : impl_{impl}, pos_{pos} + { + } + + CK_TILE_HOST_DEVICE constexpr iterator& operator++() + { + pos_++; + return *this; + } + + CK_TILE_HOST_DEVICE constexpr bool operator!=(const iterator& other) const + { + return other.pos_ != pos_; + } + + CK_TILE_HOST_DEVICE constexpr pair_type& operator*() { return impl_.at(pos_); } + }; + + struct const_iterator + { + const impl_type& impl_; + index_t pos_; + + CK_TILE_HOST_DEVICE constexpr const_iterator(const impl_type& impl, index_t pos) + : impl_{impl}, pos_{pos} + { + } + + CK_TILE_HOST_DEVICE constexpr const_iterator& operator++() + { + pos_++; + + return *this; + } + + CK_TILE_HOST_DEVICE constexpr bool operator!=(const const_iterator& other) const + { + return other.pos_ != pos_; + } + + CK_TILE_HOST_DEVICE constexpr const pair_type& operator*() const { return impl_.at(pos_); } + }; + + CK_TILE_HOST_DEVICE constexpr map() : impl_{}, size_{0} {} + + CK_TILE_HOST_DEVICE constexpr index_t size() const { return size_; } + + CK_TILE_HOST_DEVICE void clear() { size_ = 0; } + + CK_TILE_HOST_DEVICE constexpr index_t find_position(const key& key) const + { + for(index_t i = 0; i < size(); i++) + { + if(impl_[i].template at<0>() == key) + { + return i; + } + } + + return size_; + } + + CK_TILE_HOST_DEVICE constexpr const_iterator find(const key& key) const + { + return const_iterator{impl_, find_position(key)}; + } + + CK_TILE_HOST_DEVICE constexpr iterator find(const key& key) + { + return iterator{impl_, find_position(key)}; + } + + CK_TILE_HOST_DEVICE constexpr const data& operator[](const key& key) const + { + const auto it = find(key); + + // FIXME + assert(it.pos_ < size()); + + return impl_[it.pos_].template at<1>(); + } + + CK_TILE_HOST_DEVICE constexpr data& operator()(const key& key) + { + auto it = find(key); + + // if entry not found + if(it.pos_ == size()) + { + impl_(it.pos_).template at<0>() = key; + size_++; + } + + // FIXME + assert(size_ <= max_size); + + return impl_(it.pos_).template at<1>(); + } + + // WARNING: needed by compiler for C++ range-based for loop only, don't use this function! + CK_TILE_HOST_DEVICE constexpr const_iterator begin() const { return const_iterator{impl_, 0}; } + + // WARNING: needed by compiler for C++ range-based for loop only, don't use this function! + CK_TILE_HOST_DEVICE constexpr const_iterator end() const + { + return const_iterator{impl_, size_}; + } + + // WARNING: needed by compiler for C++ range-based for loop only, don't use this function! + CK_TILE_HOST_DEVICE constexpr iterator begin() { return iterator{impl_, 0}; } + + // WARNING: needed by compiler for C++ range-based for loop only, don't use this function! + CK_TILE_HOST_DEVICE constexpr iterator end() { return iterator{impl_, size_}; } + + CK_TILE_HOST_DEVICE void print() const + { + printf("map{size_: %d, ", size_); + // + printf("impl_: ["); + // + for(const auto& [key, data] : *this) + { + printf("{key: "); + print(key); + printf(", data: "); + print(data); + printf("}, "); + } + // + printf("]"); + // + printf("}"); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/core/container/meta_data_buffer.hpp b/include/ck_tile/core/container/meta_data_buffer.hpp new file mode 100644 index 0000000000..7493b93d80 --- /dev/null +++ b/include/ck_tile/core/container/meta_data_buffer.hpp @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" +#include + +namespace ck_tile { + +// TODO: this structure is not intented to be used by user +template +struct meta_data_buffer +{ + CK_TILE_HOST_DEVICE constexpr meta_data_buffer() : buffer_{}, size_{0} {} + + template + CK_TILE_HOST_DEVICE constexpr meta_data_buffer(const X& x, const Xs&... xs) + : buffer_{}, size_{0} + { + push(x, xs...); + } + + template + CK_TILE_HOST_DEVICE constexpr void push(const T& data) + { + if constexpr(!std::is_empty_v) + { + constexpr index_t size = sizeof(T); + + auto tmp = bit_cast>(data); + + for(int i = 0; i < size; i++) + { + buffer_(size_) = tmp[i]; + + size_++; + } + } + } + + template + CK_TILE_HOST_DEVICE constexpr void push(const X& x, const Xs&... xs) + { + push(x); + push(xs...); + } + + template + CK_TILE_HOST_DEVICE constexpr T pop(index_t& pos) const + { + T data; + + if constexpr(!std::is_empty_v) + { + constexpr index_t size = sizeof(T); + + array tmp; + + for(int i = 0; i < size; i++) + { + tmp(i) = buffer_[pos]; + + pos++; + } + + data = bit_cast(tmp); + } + + return data; + } + + template + CK_TILE_HOST_DEVICE constexpr T get(index_t pos) const + { + constexpr index_t size = sizeof(T); + + array tmp; + + for(int i = 0; i < size; i++) + { + tmp(i) = buffer_[pos]; + + pos++; + } + + auto data = bit_cast(tmp); + + return data; + } + + // + array buffer_; + index_t size_ = 0; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/core/container/multi_index.hpp b/include/ck_tile/core/container/multi_index.hpp new file mode 100644 index 0000000000..b78c35a8a5 --- /dev/null +++ b/include/ck_tile/core/container/multi_index.hpp @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/container/sequence.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/core/utility/functional.hpp" + +namespace ck_tile { + +// deprecated, always use array instead +template +using multi_index = array; + +template +CK_TILE_HOST_DEVICE constexpr auto make_multi_index(Xs&&... xs) +{ + return make_array(index_t{xs}...); +} + +template +CK_TILE_HOST_DEVICE constexpr auto make_zero_multi_index() +{ + return unpack([](auto... xs) { return make_multi_index(xs...); }, + typename uniform_sequence_gen::type{}); +} + +template +CK_TILE_HOST_DEVICE constexpr auto to_multi_index(const T& x) +{ + return unpack([](auto... ys) { return make_multi_index(ys...); }, x); +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator+=(multi_index& y, const X& x) +{ + static_assert(X::size() == NSize, "wrong! size not the same"); + static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; }); + return y; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator-=(multi_index& y, const X& x) +{ + static_assert(X::size() == NSize, "wrong! size not the same"); + static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; }); + return y; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator+(const multi_index& a, const T& b) +{ + using type = multi_index; + static_assert(T::size() == NSize, "wrong! size not the same"); + type r; + static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] + b[i]; }); + return r; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator-(const multi_index& a, const T& b) +{ + using type = multi_index; + static_assert(T::size() == NSize, "wrong! size not the same"); + type r; + static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] - b[i]; }); + return r; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator*(const multi_index& a, const T& b) +{ + using type = multi_index; + static_assert(T::size() == NSize, "wrong! size not the same"); + type r; + static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] * b[i]; }); + return r; +} + +// multi_index = index_t * multi_index +template +CK_TILE_HOST_DEVICE constexpr auto operator*(index_t a, const multi_index& x) +{ + multi_index r; + static_for<0, NSize, 1>{}([&](auto i) { r[i] = a * x[i]; }); + return r; +} + +// multi_index = multi_index * index_t +template +CK_TILE_HOST_DEVICE constexpr auto operator*(const multi_index& x, index_t a) +{ + return a * x; +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/container/sequence.hpp b/include/ck_tile/core/container/sequence.hpp new file mode 100644 index 0000000000..dc1971a330 --- /dev/null +++ b/include/ck_tile/core/container/sequence.hpp @@ -0,0 +1,1114 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/utility/to_sequence.hpp" + +namespace ck_tile { + +template +struct static_for; + +template +struct sequence; + +template +struct sequence_split; + +template +struct sequence_reverse; + +template +struct sequence_map_inverse; + +template +struct is_valid_sequence_map; + +template +CK_TILE_HOST_DEVICE constexpr auto sequence_pop_front(sequence); + +template +CK_TILE_HOST_DEVICE constexpr auto sequence_pop_back(Seq); + +namespace impl { +// static_assert(__has_builtin(__type_pack_element), "can't find __type_pack_element"); +template +using at_index_t = __type_pack_element; +} // namespace impl + +// we could implement as below, similiar to std. But let's reduce the symbol name... +// template< class T, T... Ints > +// class integer_sequence; + +template +struct sequence +{ + using type = sequence; + using value_type = index_t; + + CK_TILE_HOST_DEVICE static constexpr index_t size() { return sizeof...(Is); } + CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }; + + CK_TILE_HOST_DEVICE static constexpr index_t at(index_t I) + { + // the last dummy element is to prevent compiler complain about empty array, when mSize = 0 + static_assert(I < size(), "wrong! I too large"); + const index_t mData[mSize + 1] = {Is..., 0}; + return mData[I]; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto get() + { + static_assert(I < size(), "wrong! I too large"); + return number...>{}>{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto get(number) + { + static_assert(I < size(), "wrong! I too large"); + return number{}; + } + + CK_TILE_HOST_DEVICE static constexpr index_t at(index_t I) + { + // the last dummy element is to prevent compiler complain about empty array, when mSize = 0 + static_assert(I < size(), "wrong! I too large"); + const index_t mData[mSize + 1] = {Is..., 0}; + return mData[I]; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto at() + { + static_assert(I < size(), "wrong! I too large"); + return number...>{}>{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto at(number) + { + static_assert(I < size(), "wrong! I too large"); + return number{}; + } + + template + CK_TILE_HOST_DEVICE constexpr auto operator[](I i) const + { + return at(i); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto reorder_new_to_old(sequence /*new2old*/) + { + static_assert(sizeof...(Is) == sizeof...(IRs), + "wrong! reorder map should have the same size as sequence to be rerodered"); + + static_assert(is_valid_sequence_map>::value, "wrong! invalid reorder map"); + + return sequence{})...>{}; + } + + // MapOld2New is sequence<...> + template + CK_TILE_HOST_DEVICE static constexpr auto reorder_old_to_new(MapOld2New) + { + static_assert(MapOld2New::size() == size(), + "wrong! reorder map should have the same size as sequence to be rerodered"); + + static_assert(is_valid_sequence_map::value, "wrong! invalid reorder map"); + + return reorder_new_to_old(typename sequence_map_inverse::type{}); + } + + CK_TILE_HOST_DEVICE static constexpr auto reverse() + { + return typename sequence_reverse::type{}; + } + + CK_TILE_HOST_DEVICE static constexpr auto front() + { + static_assert(size() > 0, "wrong!"); + return get(number<0>{}); + } + + CK_TILE_HOST_DEVICE static constexpr auto back() + { + static_assert(size() > 0, "wrong!"); + return get(number{}); + } + + CK_TILE_HOST_DEVICE static constexpr auto pop_front() { return sequence_pop_front(type{}); } + + CK_TILE_HOST_DEVICE static constexpr auto pop_back() { return sequence_pop_back(type{}); } + + template + CK_TILE_HOST_DEVICE static constexpr auto push_front(sequence) + { + return sequence{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto push_front(number...) + { + return sequence{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto push_back(sequence) + { + return sequence{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto push_back(number...) + { + return sequence{}; + } + + // pickup element at index + template + CK_TILE_HOST_DEVICE static constexpr auto extract(number...) + { + return sequence{})...>{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto extract(sequence) + { + return sequence{})...>{}; + } + + // modify element at index "I" with value "X" + template + CK_TILE_HOST_DEVICE static constexpr auto modify(number, number) + { + static_assert(I < size(), "wrong!"); + + using seq_split = sequence_split; + constexpr auto seq_left = typename seq_split::left_type{}; + constexpr auto seq_right = typename seq_split::right_type{}.pop_front(); + + return seq_left.push_back(number{}).push_back(seq_right); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto transform(F f) + { + return sequence{}; + } + + CK_TILE_HOST_DEVICE static void print() + { + printf("sequence{size: %d, data: [", size()); + ((printf("%d ", Is)), ...); + printf("]}"); + } +}; + +namespace impl { +template +struct __integer_sequence; + +template +struct __integer_sequence +{ + using seq_type = sequence; +}; +} // namespace impl + +// similiar +template +using make_index_sequence = + typename __make_integer_seq::seq_type; + +// merge sequence +template +struct sequence_merge +{ + using type = typename sequence_merge::type>::type; +}; + +template +struct sequence_merge, sequence> +{ + using type = sequence; +}; + +template +struct sequence_merge +{ + using type = Seq; +}; + +// generate sequence +template +struct sequence_gen +{ + template + struct sequence_gen_impl + { + static constexpr index_t NRemainLeft = NRemain / 2; + static constexpr index_t NRemainRight = NRemain - NRemainLeft; + static constexpr index_t IMiddle = IBegin + NRemainLeft; + + using type = typename sequence_merge< + typename sequence_gen_impl::type, + typename sequence_gen_impl::type>::type; + }; + + template + struct sequence_gen_impl + { + static constexpr index_t Is = G{}(number{}); + using type = sequence; + }; + + template + struct sequence_gen_impl + { + using type = sequence<>; + }; + + using type = typename sequence_gen_impl<0, NSize, F>::type; +}; + +// arithmetic sequence +template +struct arithmetic_sequence_gen +{ + struct F + { + CK_TILE_HOST_DEVICE constexpr index_t operator()(index_t i) const + { + return i * Increment + IBegin; + } + }; + + using type0 = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type; + using type1 = sequence<>; + + static constexpr bool kHasContent = + (Increment > 0 && IBegin < IEnd) || (Increment < 0 && IBegin > IEnd); + + using type = typename conditional::type; +}; + +template +struct arithmetic_sequence_gen<0, IEnd, 1> +{ + using type = make_index_sequence; +}; + +// uniform sequence +template +struct uniform_sequence_gen +{ + struct F + { + CK_TILE_HOST_DEVICE constexpr index_t operator()(index_t) const { return I; } + }; + + using type = typename sequence_gen::type; +}; + +// reverse inclusive scan (with init) sequence +template +struct sequence_reverse_inclusive_scan; + +template +struct sequence_reverse_inclusive_scan, Reduce, Init> +{ + using old_scan = typename sequence_reverse_inclusive_scan, Reduce, Init>::type; + + static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.front()); + + using type = typename sequence_merge, old_scan>::type; +}; + +template +struct sequence_reverse_inclusive_scan, Reduce, Init> +{ + using type = sequence; +}; + +template +struct sequence_reverse_inclusive_scan, Reduce, Init> +{ + using type = sequence<>; +}; + +// split sequence +template +struct sequence_split +{ + static constexpr index_t NSize = Seq{}.size(); + + using range0 = typename arithmetic_sequence_gen<0, I, 1>::type; + using range1 = typename arithmetic_sequence_gen::type; + + using left_type = decltype(Seq::extract(range0{})); + using right_type = decltype(Seq::extract(range1{})); +}; + +#if 0 +// reverse sequence +template +struct sequence_reverse +{ + static constexpr index_t NSize = Seq{}.size(); + + using seq_split = sequence_split; + using type = typename sequence_merge< + typename sequence_reverse::type, + typename sequence_reverse::type>::type; +}; + +template +struct sequence_reverse> +{ + using type = sequence; +}; + +template +struct sequence_reverse> +{ + using type = sequence; +}; +#endif + +namespace impl { +template +struct seq_reverse; + +template +struct seq_reverse, Ns...> +{ + template + using element = impl::at_index_t...>; + using type = sequence::value...>; +}; +} // namespace impl + +template +struct sequence_reverse> + : impl::seq_reverse, Ns...> +{ +}; + +template +using sequence_reverse_t = typename sequence_reverse::type; + +#if 1 +template +struct sequence_reduce +{ + using type = typename sequence_reduce::type>::type; +}; + +template +struct sequence_reduce, sequence> +{ + using type = sequence; +}; + +template +struct sequence_reduce +{ + using type = Seq; +}; +#endif + +template +struct sequence_sort_impl +{ + template + struct sorted_sequence_merge_impl + { + static constexpr bool choose_left = LeftValues::front() < RightValues::front(); + + static constexpr index_t chosen_value = + choose_left ? LeftValues::front() : RightValues::front(); + static constexpr index_t chosen_id = choose_left ? LeftIds::front() : RightIds::front(); + + using new_merged_values = decltype(MergedValues::push_back(number{})); + using new_merged_ids = decltype(MergedIds::push_back(number{})); + + using new_left_values = + typename conditional::type; + using new_left_ids = + typename conditional::type; + + using new_right_values = typename conditional::type; + using new_right_ids = + typename conditional::type; + + using merge = sorted_sequence_merge_impl; + // this is output + using merged_values = typename merge::merged_values; + using merged_ids = typename merge::merged_ids; + }; + + template + struct sorted_sequence_merge_impl, + sequence<>, + MergedValues, + MergedIds, + Comp> + { + using merged_values = typename sequence_merge::type; + using merged_ids = typename sequence_merge::type; + }; + + template + struct sorted_sequence_merge_impl, + sequence<>, + RightValues, + RightIds, + MergedValues, + MergedIds, + Comp> + { + using merged_values = typename sequence_merge::type; + using merged_ids = typename sequence_merge::type; + }; + + template + struct sorted_sequence_merge + { + using merge = sorted_sequence_merge_impl, + sequence<>, + Comp>; + + using merged_values = typename merge::merged_values; + using merged_ids = typename merge::merged_ids; + }; + + static constexpr index_t nsize = Values::size(); + + using split_unsorted_values = sequence_split; + using split_unsorted_ids = sequence_split; + + using left_unsorted_values = typename split_unsorted_values::left_type; + using left_unsorted_ids = typename split_unsorted_ids::left_type; + using left_sort = sequence_sort_impl; + using left_sorted_values = typename left_sort::sorted_values; + using left_sorted_ids = typename left_sort::sorted_ids; + + using right_unsorted_values = typename split_unsorted_values::right_type; + using right_unsorted_ids = typename split_unsorted_ids::right_type; + using right_sort = sequence_sort_impl; + using right_sorted_values = typename right_sort::sorted_values; + using right_sorted_ids = typename right_sort::sorted_ids; + + using merged_sorted = sorted_sequence_merge; + + using sorted_values = typename merged_sorted::merged_values; + using sorted_ids = typename merged_sorted::merged_ids; +}; + +template +struct sequence_sort_impl, sequence, Compare> +{ + static constexpr bool choose_x = Compare{}(ValueX, ValueY); + + using sorted_values = + typename conditional, sequence>::type; + using sorted_ids = typename conditional, sequence>::type; +}; + +template +struct sequence_sort_impl, sequence, Compare> +{ + using sorted_values = sequence; + using sorted_ids = sequence; +}; + +template +struct sequence_sort_impl, sequence<>, Compare> +{ + using sorted_values = sequence<>; + using sorted_ids = sequence<>; +}; + +template +struct sequence_sort +{ + using unsorted_ids = typename arithmetic_sequence_gen<0, Values::size(), 1>::type; + using sort = sequence_sort_impl; + + // this is output + using type = typename sort::sorted_values; + using sorted2unsorted_map = typename sort::sorted_ids; +}; + +template +struct sequence_unique_sort +{ + template + struct sorted_sequence_uniquify_impl + { + static constexpr index_t current_value = RemainValues::front(); + static constexpr index_t current_id = RemainIds::front(); + + static constexpr bool is_unique_value = (current_value != UniquifiedValues::back()); + + using new_remain_values = decltype(RemainValues::pop_front()); + using new_remain_ids = decltype(RemainIds::pop_front()); + + using new_uniquified_values = + typename conditional{})), + UniquifiedValues>::type; + + using new_uniquified_ids = + typename conditional{})), + UniquifiedIds>::type; + + using uniquify = sorted_sequence_uniquify_impl; + + // this is output + using uniquified_values = typename uniquify::uniquified_values; + using uniquified_ids = typename uniquify::uniquified_ids; + }; + + template + struct sorted_sequence_uniquify_impl, + sequence<>, + UniquifiedValues, + UniquifiedIds, + Eq> + { + using uniquified_values = UniquifiedValues; + using uniquified_ids = UniquifiedIds; + }; + + template + struct sorted_sequence_uniquify + { + using uniquify = sorted_sequence_uniquify_impl, + sequence, + Eq>; + + using uniquified_values = typename uniquify::uniquified_values; + using uniquified_ids = typename uniquify::uniquified_ids; + }; + + using sort = sequence_sort; + using sorted_values = typename sort::type; + using sorted_ids = typename sort::sorted2unsorted_map; + + using uniquify = sorted_sequence_uniquify; + + // this is output + using type = typename uniquify::uniquified_values; + using sorted2unsorted_map = typename uniquify::uniquified_ids; +}; + +template +struct is_valid_sequence_map : is_same::type, + typename sequence_sort>::type> +{ +}; + +template +struct sequence_map_inverse +{ + template + struct sequence_map_inverse_impl + { + static constexpr auto new_y2x = + WorkingY2X::modify(X2Y::get(number{}), number{}); + + using type = + typename sequence_map_inverse_impl:: + type; + }; + + template + struct sequence_map_inverse_impl + { + using type = WorkingY2X; + }; + + using type = + typename sequence_map_inverse_impl::type, + 0, + SeqMap::size()>::type; +}; + +template +CK_TILE_HOST_DEVICE constexpr bool operator==(sequence, sequence) +{ + return ((Xs == Ys) && ...); +} + +template +CK_TILE_HOST_DEVICE constexpr bool operator!=(sequence x, sequence y) +{ + return !(x == y); +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator+(sequence, sequence) +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); + + return sequence<(Xs + Ys)...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator-(sequence, sequence) +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); + + return sequence<(Xs - Ys)...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator*(sequence, sequence) +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); + + return sequence<(Xs * Ys)...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator/(sequence, sequence) +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); + + return sequence<(Xs / Ys)...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator%(sequence, sequence) +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); + + return sequence<(Xs % Ys)...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator+(sequence, number) +{ + return sequence<(Xs + Y)...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator-(sequence, number) +{ + return sequence<(Xs - Y)...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator*(sequence, number) +{ + return sequence<(Xs * Y)...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator/(sequence, number) +{ + return sequence<(Xs / Y)...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator%(sequence, number) +{ + return sequence<(Xs % Y)...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator+(number, sequence) +{ + return sequence<(Y + Xs)...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator-(number, sequence) +{ + return sequence<(Y - Xs)...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator*(number, sequence) +{ + return sequence<(Y * Xs)...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator/(number, sequence) +{ + return sequence<(Y / Xs)...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator%(number, sequence) +{ + return sequence<(Y % Xs)...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto sequence_pop_front(sequence) +{ + return sequence{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto sequence_pop_back(Seq) +{ + static_assert(Seq::size() > 0, "wrong! cannot pop an empty sequence!"); + return sequence_pop_front(Seq::reverse()).reverse(); +} + +template +CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...) +{ + return typename sequence_merge::type{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto transform_sequences(F f, sequence) +{ + return sequence{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto transform_sequences(F f, sequence, sequence) +{ + static_assert(sequence::size() == sequence::size(), "Dim not the same"); + + return sequence{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto +transform_sequences(F f, sequence, sequence, sequence) +{ + static_assert(sequence::size() == sequence::size() && + sequence::size() == sequence::size(), + "Dim not the same"); + + return sequence{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, number) +{ + return typename sequence_reverse_inclusive_scan::type{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto reverse_exclusive_scan_sequence(Seq, Reduce, number) +{ + return reverse_inclusive_scan_sequence(Seq::pop_front(), Reduce{}, number{}) + .push_back(number{}); +} + +template +CK_TILE_HOST_DEVICE constexpr auto inclusive_scan_sequence(Seq, Reduce, number) +{ + return reverse_inclusive_scan_sequence(Seq{}.reverse(), Reduce{}, number{}).reverse(); +} + +// e.g. Seq<2, 3, 4> --> Seq<0, 2, 5>, Init=0, Reduce=Add +// ResultSeq TargetSeq Reduce +template +struct sequence_exclusive_scan; + +template +struct sequence_exclusive_scan, sequence, Reduce> +{ + using old_scan = typename sequence_merge, + sequence{}.back())>>::type; + using type = typename sequence_exclusive_scan, Reduce>::type; +}; + +template +struct sequence_exclusive_scan, sequence, Reduce> +{ + using type = sequence; +}; + +template +struct sequence_exclusive_scan, sequence<>, Reduce> +{ + using type = sequence; +}; + +template +constexpr auto exclusive_scan_sequence(Seq, Reduce, number) +{ + // TODO: c++20 and later can pass in Reduce with a lambda expression + return typename sequence_exclusive_scan, Seq, Reduce>::type{}; +} + +template +constexpr auto prefix_sum_sequence(Seq) +{ + return typename sequence_exclusive_scan, + typename sequence_merge>::type, + math::plus>::type{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto pick_sequence_elements_by_ids(Seq, sequence /* ids */) +{ + return sequence{})...>{}; +} + +#if 1 +namespace detail { +template +struct pick_sequence_elements_by_mask_impl +{ + using new_work_seq = typename conditional::type; + + using type = + typename pick_sequence_elements_by_mask_impl::type; +}; + +template +struct pick_sequence_elements_by_mask_impl, sequence<>> +{ + using type = WorkSeq; +}; + +} // namespace detail + +template +CK_TILE_HOST_DEVICE constexpr auto pick_sequence_elements_by_mask(Seq, Mask) +{ + static_assert(Seq::size() == Mask::size(), "wrong!"); + + return typename detail::pick_sequence_elements_by_mask_impl, Seq, Mask>::type{}; +} + +namespace detail { +template +struct modify_sequence_elements_by_ids_impl +{ + using new_work_seq = decltype(WorkSeq::modify(RemainIds::front(), RemainValues::front())); + + using type = + typename modify_sequence_elements_by_ids_impl::type; +}; + +template +struct modify_sequence_elements_by_ids_impl, sequence<>> +{ + using type = WorkSeq; +}; +} // namespace detail + +template +CK_TILE_HOST_DEVICE constexpr auto modify_sequence_elements_by_ids(Seq, Values, Ids) +{ + static_assert(Values::size() == Ids::size() && Seq::size() >= Values::size(), "wrong!"); + + return typename detail::modify_sequence_elements_by_ids_impl::type{}; +} +#endif + +template +CK_TILE_HOST_DEVICE constexpr index_t +reduce_on_sequence(Seq, Reduce f, number /*initial_value*/) +{ + index_t result = Init; + + for(index_t i = 0; i < Seq::size(); ++i) + { + result = f(result, Seq::get(i)); + } + + return result; +} + +// TODO: a generic any_of for any container +template +CK_TILE_HOST_DEVICE constexpr bool sequence_any_of(Seq, F f) +{ + bool flag = false; + + for(index_t i = 0; i < Seq::size(); ++i) + { + flag = flag || f(Seq::get(i)); + } + + return flag; +} + +// TODO: a generic all_of for any container +template +CK_TILE_HOST_DEVICE constexpr bool sequence_all_of(Seq, F f) +{ + bool flag = true; + + for(index_t i = 0; i < Seq::size(); ++i) + { + flag = flag && f(Seq::get(i)); + } + + return flag; +} + +template +using sequence_merge_t = typename sequence_merge::type; + +template +using uniform_sequence_gen_t = typename uniform_sequence_gen::type; + +template +CK_TILE_HOST_DEVICE constexpr auto make_sequence(number...) +{ + return sequence{}; +} + +// F() returns index_t +// F use default constructor, so F cannot be lambda function +template +CK_TILE_HOST_DEVICE constexpr auto generate_sequence(F, number) +{ + return typename sequence_gen::type{}; +} + +// F() returns number<> +// F could be lambda function +template +CK_TILE_HOST_DEVICE constexpr auto generate_sequence_v2(F&& f, number) +{ + return unpack([&f](auto&&... xs) { return make_sequence(f(xs)...); }, + typename arithmetic_sequence_gen<0, N, 1>::type{}); +} + +// template +// CK_TILE_HOST_DEVICE constexpr auto to_sequence(Tuple...>) +// { +// return sequence{}; +// } + +namespace detail { +template +struct sorted_sequence_histogram; + +template +struct sorted_sequence_histogram, sequence> +{ + template + constexpr auto operator()(Histogram& h) + { + if constexpr(x < r) + { + h.template at() += 1; + sorted_sequence_histogram, sequence>{}(h); + } + else + { + h.template at() = 1; + sorted_sequence_histogram, sequence>{}(h); + } + } +}; + +template +struct sorted_sequence_histogram, sequence> +{ + template + constexpr auto operator()(Histogram& h) + { + if constexpr(x < r) + { + h.template at() += 1; + } + } +}; +} // namespace detail + +// SeqSortedSamples: <0, 2, 3, 5, 7>, SeqRange: <0, 3, 6, 9> -> SeqHistogram : <2, 2, 1> +template +constexpr auto histogram_sorted_sequence(SeqSortedSamples, sequence) +{ + constexpr auto bins = sizeof...(rs); // or categories + constexpr auto histogram = [&]() { + array h{0}; // make sure this can clear all element to zero + detail::sorted_sequence_histogram<0, SeqSortedSamples, sequence>{}(h); + return h; + }(); + + return TO_SEQUENCE(histogram, bins); +} + +template +CK_TILE_HOST_DEVICE constexpr auto generate_array(F&& f, number) +{ + using T = remove_cvref_t{}))>; + + return unpack([&f](auto&&... is) { return array{f(is)...}; }, + typename arithmetic_sequence_gen<0, N, 1>::type{}); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/container/span.hpp b/include/ck_tile/core/container/span.hpp new file mode 100644 index 0000000000..eeb1f226a9 --- /dev/null +++ b/include/ck_tile/core/container/span.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include +#include +#include + +namespace ck_tile { + +// implement the c++20 std::span, lightweight, non-owning reference to a sequence +// weather it is dynamic or static range. Or can be seen as a view of a contiguous sequence +// TODO: do we need in device consider this is pointer? +template +class span +{ + public: + using element_type = T; + using value_type = std::remove_cv_t; + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + using pointer = element_type*; + using const_pointer = const element_type*; + using reference = element_type&; + using const_reference = const element_type&; + using iterator = pointer; + using const_iterator = pointer; + + CK_TILE_HOST_DEVICE constexpr span() : span(nullptr, size_type{0}) {} + + CK_TILE_HOST_DEVICE constexpr span(pointer first, size_type count) : ptr_(first), size_(count) + { + } + + CK_TILE_HOST_DEVICE constexpr span(pointer first, pointer last) : span(first, last - first) {} + + template + CK_TILE_HOST_DEVICE constexpr span(element_type (&arr)[N]) noexcept : span(arr, N) + { + } + + template + CK_TILE_HOST_DEVICE constexpr span(std::array& arr) noexcept + : span(arr.data(), N) + { + } + + template + CK_TILE_HOST_DEVICE constexpr span(const Container& container) + : span(container.data(), container.size()) + { + } + + CK_TILE_HOST_DEVICE constexpr iterator begin() const noexcept { return ptr_; } + CK_TILE_HOST_DEVICE constexpr const_iterator cbegin() const noexcept { return begin(); } + + CK_TILE_HOST_DEVICE constexpr iterator end() const noexcept { return begin() + size(); } + CK_TILE_HOST_DEVICE constexpr const_iterator cend() const noexcept { return end(); } + + CK_TILE_HOST_DEVICE constexpr reference front() const { return *begin(); } + CK_TILE_HOST_DEVICE constexpr reference back() const { return *(--end()); } + + CK_TILE_HOST_DEVICE constexpr reference operator[](size_type idx) const + { + return *(begin() + idx); + } + CK_TILE_HOST_DEVICE constexpr pointer data() const noexcept { return ptr_; } + + CK_TILE_HOST_DEVICE constexpr size_type size() const noexcept { return size_; } + + private: + pointer ptr_; + size_type size_; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/core/container/statically_indexed_array.hpp b/include/ck_tile/core/container/statically_indexed_array.hpp new file mode 100644 index 0000000000..1542ad0768 --- /dev/null +++ b/include/ck_tile/core/container/statically_indexed_array.hpp @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/numeric/integer.hpp" + +namespace ck_tile { + +#if CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT == CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE +namespace detail { +template +struct tuple_concat; + +template +struct tuple_concat, tuple> +{ + using type = tuple; +}; + +template +struct statically_indexed_array_impl +{ + using type = + typename tuple_concat::type, + typename statically_indexed_array_impl::type>::type; +}; + +template +struct statically_indexed_array_impl +{ + using type = tuple<>; +}; + +template +struct statically_indexed_array_impl +{ + using type = tuple; +}; +} // namespace detail + +template +using statically_indexed_array = typename detail::statically_indexed_array_impl::type; + +#else + +// consider mark this struct as deprecated +template +using statically_indexed_array = array; + +#endif + +// consider always use ck_tile::array for this purpose + +template +CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs) +{ + return statically_indexed_array(x, static_cast(xs)...); +} + +// make empty statically_indexed_array +template +CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array() +{ + return statically_indexed_array(); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/container/tuple.hpp b/include/ck_tile/core/container/tuple.hpp new file mode 100644 index 0000000000..a93ff0f42f --- /dev/null +++ b/include/ck_tile/core/container/tuple.hpp @@ -0,0 +1,483 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/sequence.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/type_traits.hpp" +#include + +namespace ck_tile { + +namespace impl { + +template > +struct tuple_element +{ +}; + +template +struct tuple_element +{ + CK_TILE_HOST_DEVICE constexpr tuple_element() {} + CK_TILE_HOST_DEVICE constexpr tuple_element(const T&) {} +}; + +template +struct tuple_element +{ + CK_TILE_HOST_DEVICE constexpr tuple_element() {} + CK_TILE_HOST_DEVICE constexpr tuple_element(const T& e) : element(e) {} + T element; +}; + +template +CK_TILE_HOST_DEVICE constexpr T const& getv(tuple_element const& x) +{ + return x.element; +} + +template +CK_TILE_HOST_DEVICE constexpr T& getv(tuple_element& x) +{ + return x.element; +} + +template +CK_TILE_HOST_DEVICE constexpr T&& getv(tuple_element&& x) +{ + return static_cast(x.element); +} + +template +struct tuple_base; + +template +struct tuple_base, T...> : public tuple_element... +{ + CK_TILE_HOST_DEVICE constexpr tuple_base() {} + + template + CK_TILE_HOST_DEVICE constexpr explicit tuple_base(U const&... u) : tuple_element(u)... + { + } + + template + CK_TILE_HOST_DEVICE constexpr tuple_base(tuple_base, U...> const& u) + : tuple_element(getv(static_cast const&>(u)))... + { + } +}; +} // namespace impl + +template +struct tuple : impl::tuple_base, T...> +{ + CK_TILE_HOST_DEVICE + static constexpr auto size() { return sizeof...(T); } + using base = impl::tuple_base, T...>; + CK_TILE_HOST_DEVICE constexpr tuple() {} + + template + CK_TILE_HOST_DEVICE constexpr tuple(U const&... u) + : impl::tuple_base, T...>(u...) + { + } + + template + CK_TILE_HOST_DEVICE constexpr tuple(tuple const& u) + : impl::tuple_base, T...>( + static_cast const&>(u)) + { + } + + CK_TILE_HOST_DEVICE static constexpr bool is_static() + { + bool flag = true; + + static_for<0, sizeof...(Xs), 1>{}([&flag](auto i) { + flag &= is_static_v>>; + }); + + return flag; + } + +#define TP_COM_() static_assert(I < size(), "wrong! out of range") + // clang-format off + template CK_TILE_HOST_DEVICE constexpr const auto & get() const { TP_COM_(); return impl::getv(*this); } + template CK_TILE_HOST_DEVICE constexpr const auto & get(number) const { TP_COM_(); return get(); } + template CK_TILE_HOST_DEVICE constexpr auto & get() { TP_COM_(); return impl::getv(*this); } + template CK_TILE_HOST_DEVICE constexpr auto & get(number) { TP_COM_(); return get(); } + + template CK_TILE_HOST_DEVICE constexpr const auto & at() const { TP_COM_(); return impl::getv(*this); } + template CK_TILE_HOST_DEVICE constexpr const auto & at(number) const { TP_COM_(); return get(); } + template CK_TILE_HOST_DEVICE constexpr auto & at() { TP_COM_(); return impl::getv(*this); } + template CK_TILE_HOST_DEVICE constexpr auto & at(number) { TP_COM_(); return get(); } + + template CK_TILE_HOST_DEVICE constexpr auto & operator[](number) { TP_COM_(); return get(); } + template CK_TILE_HOST_DEVICE constexpr const auto & operator[](number) const { TP_COM_(); return get(); } + template CK_TILE_HOST_DEVICE constexpr auto & operator()(number) { TP_COM_(); return get(); } // TODO: compatible + // clang-format on +#undef TP_COM_ +}; + +// template +// CK_TILE_HOST_DEVICE constexpr +// tuple +// make_tuple(T const&... t) +// { +// return {t...}; +// } + +template +CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs&&... xs) +{ + return tuple...>(std::forward(xs)...); +} + +// https://en.cppreference.com/w/cpp/utility/tuple/tie +template +constexpr tuple tie(Args&... args) noexcept +{ + return {args...}; +} + +template +struct tuple_concat; + +template +struct tuple_concat, tuple> +{ + using type = tuple; +}; + +template +CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F&& f, number) +{ + return unpack([&f](auto&&... is) { return make_tuple(f(is)...); }, + typename arithmetic_sequence_gen<0, N, 1>::type{}); +} + +template +CK_TILE_HOST_DEVICE constexpr auto generate_tie(F&& f, number) +{ + return unpack([&f](auto&&... is) { return tie(f(is)...); }, + typename arithmetic_sequence_gen<0, N, 1>::type{}); +} + +// tx and ty are tuple of references, return type of will tuple of referennce (not rvalue) +template +CK_TILE_HOST_DEVICE constexpr auto concat_tuple_of_reference(const tuple& tx, + const tuple& ty) +{ + return unpack2( + [&](auto&&... zs) { return tuple{std::forward(zs)...}; }, + tx, + ty); +} + +template +CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple& tx, const tuple& ty) +{ + return unpack2( + [&](auto... zs) { return tuple{std::forward(zs)...}; }, + tx, + ty); +} + +// Support any number of tuples to concat (also 1) +template +CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple& tx) +{ + return tx; +} + +template +CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple& tx, const Tuples&... tuples) +{ + return concat_tuple(tx, concat_tuple(tuples...)); +} + +namespace detail { + +template +CK_TILE_HOST_DEVICE constexpr auto transform_tuples_impl(F f, const X& x, sequence) +{ + return make_tuple(f(x.at(number{}))...); +} + +template +CK_TILE_HOST_DEVICE constexpr auto +transform_tuples_impl(F f, const X& x, const Y& y, sequence) +{ + return make_tuple(f(x.at(number{}), y.at(number{}))...); +} + +template +CK_TILE_HOST_DEVICE constexpr auto +transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, sequence) +{ + return make_tuple(f(x.at(number{}), y.at(number{}), z.at(number{}))...); +} + +} // namespace detail + +template +CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x) +{ + return detail::transform_tuples_impl( + f, x, typename arithmetic_sequence_gen<0, X::size()(), 1>::type{}); +} + +template +CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y) +{ + return detail::transform_tuples_impl( + f, x, y, typename arithmetic_sequence_gen<0, X::size()(), 1>::type{}); +} + +template +CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z) +{ + return detail::transform_tuples_impl( + f, x, y, z, typename arithmetic_sequence_gen<0, X::size()(), 1>::type{}); +} + +// By default unroll to the flatten +template +CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<>& element) +{ + return element; +} + +template +CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const T& element) +{ + return make_tuple(element); +} + +template +CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple& tuple) +{ + if constexpr(Depth == MaxDepth) + { + return tuple; + } + else + { + return unpack( + [&](auto&&... ts) { + return concat_tuple(unroll_nested_tuple(ts)...); + }, + tuple); + } +} + +template +CK_TILE_HOST_DEVICE constexpr auto tuple_reverse(const tuple& tuple) +{ + return generate_tuple( + [&](auto i) { + using Idx = number::size()() - i - 1>; + return tuple.at(Idx{}); + }, + number::size()()>{}); +} + +// Reduce tuple values in specific range using Function +template +CK_TILE_HOST_DEVICE constexpr auto tuple_reduce(F&& f, const tuple& tuple) +{ + static_assert(Idx < End, "Wrong parameters for tuple_reduce"); + if constexpr(Idx + 1 == End) + { + return tuple.at(number{}); + } + else + { + return f(tuple.at(number{}), tuple_reduce(f, tuple)); + } +} + +template +using is_tuple = decltype(std::declval().IsTuple()); + +template +CK_TILE_HOST_DEVICE constexpr auto is_nested_tuple(const tuple&) +{ + return (is_detected::value || ...); +} + +template +CK_TILE_HOST_DEVICE constexpr auto tuple_depth(const T&) +{ + return depth; +} + +template +CK_TILE_HOST_DEVICE constexpr auto tuple_depth(const tuple&) +{ + return math::max(tuple_depth(Ts{})...); +} + +template +CK_TILE_HOST_DEVICE constexpr auto to_array_of_array(tuple t_of_s) +{ + constexpr index_t n0 = sizeof...(Seqs); + + constexpr index_t max_n1 = [&] { + index_t max_n1_ = 0; + + static_for<0, n0, 1>{}([&](auto i0) { + constexpr index_t n1 = t_of_s[i0].size()(); + + max_n1_ = max_n1_ < n1 ? n1 : max_n1_; + }); + + return max_n1_; + }(); + + array, n0> a_of_a{{-1}}; + + static_for<0, n0, 1>{}([&](auto i0) { + constexpr index_t n1 = t_of_s[i0].size()(); + + static_for<0, n1, 1>{}([&](auto i1) { a_of_a(i0)(i1) = t_of_s[i0][i1]; }); + }); + + return a_of_a; +} + +// Here should use MultiIndex, instead of tuple, although the former +// is the alias of the latter. This is because compiler cannot infer the NSize if +// using MultiIndex +// TODO: how to fix this? +template ::value && !std::is_floating_point::value, bool> = + false> +CK_TILE_HOST_DEVICE constexpr auto operator+=(tuple& y, const X& x) +{ + static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same"); + constexpr index_t NSize = sizeof...(Ys); + static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; }); + return y; +} + +template ::value && !std::is_floating_point::value, bool> = + false> +CK_TILE_HOST_DEVICE constexpr auto operator-=(tuple& y, const X& x) +{ + static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same"); + constexpr index_t NSize = sizeof...(Ys); + static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; }); + return y; +} + +template ::value && !std::is_floating_point::value, bool> = + false> +CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple& x, const Y& y) +{ + static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); + constexpr index_t NSize = sizeof...(Xs); + + tuple r; + static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] + y[i]; }); + return r; +} + +template ::value && !std::is_floating_point::value, bool> = + false> +CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple& x, const Y& y) +{ + static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); + constexpr index_t NSize = sizeof...(Xs); + + tuple r; + static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] - y[i]; }); + return r; +} + +template ::value && !std::is_floating_point::value, bool> = + false> +CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple& x, const Y& y) +{ + static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); + constexpr index_t NSize = sizeof...(Xs); + + tuple r; + static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] * y[i]; }); + return r; +} + +// MultiIndex = scalar * MultiIndex +template < + typename... Xs, + typename Y, + std::enable_if_t::value || std::is_floating_point::value, bool> = false> +CK_TILE_HOST_DEVICE constexpr auto operator*(Y a, const tuple& x) +{ + constexpr index_t NSize = sizeof...(Xs); + tuple r; + static_for<0, NSize, 1>{}([&](auto i) { r[i] = a * x[i]; }); + return r; +} + +// MultiIndex = MultiIndex * scalar +template < + typename... Xs, + typename Y, + std::enable_if_t::value || std::is_floating_point::value, bool> = false> +CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple& x, Y a) +{ + return a * x; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator/(const tuple& x, const tuple& y) +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!"); + constexpr index_t NSize = sizeof...(Xs); + return generate_tuple([&](auto i) { return x[i] / y[i]; }, number{}); +} + +} // namespace ck_tile + +// WARNING: needed by compiler for C++ structured binding support only, don't use this +namespace std { + +template +struct tuple_size> : std::integral_constant +{ +}; + +template +struct tuple_element> : ck_tile::tuple_element> +{ +}; + +template +struct tuple_size> : std::integral_constant +{ +}; + +template +struct tuple_element> + : ck_tile::tuple_element> +{ +}; + +} // namespace std diff --git a/include/ck_tile/core/numeric/arithmetic.hpp b/include/ck_tile/core/numeric/arithmetic.hpp new file mode 100644 index 0000000000..970ea9ff61 --- /dev/null +++ b/include/ck_tile/core/numeric/arithmetic.hpp @@ -0,0 +1,116 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +#include + +#pragma once + +#define CK_TILE_ARITHMETIC_USING_FLOAT(type_) \ + CK_TILE_HOST_DEVICE \ + bool operator==(const type_& x, const type_& y) \ + { \ + return static_cast(x) == static_cast(y); \ + } \ + CK_TILE_HOST_DEVICE \ + bool operator!=(const type_& x, const type_& y) \ + { \ + return static_cast(x) != static_cast(y); \ + } \ + CK_TILE_HOST_DEVICE \ + bool operator<(const type_& x, const type_& y) \ + { \ + return static_cast(x) < static_cast(y); \ + } \ + CK_TILE_HOST_DEVICE \ + bool operator<=(const type_& x, const type_& y) \ + { \ + return static_cast(x) <= static_cast(y); \ + } \ + CK_TILE_HOST_DEVICE \ + bool operator>(const type_& x, const type_& y) \ + { \ + return static_cast(x) > static_cast(y); \ + } \ + CK_TILE_HOST_DEVICE \ + bool operator>=(const type_& x, const type_& y) \ + { \ + return static_cast(x) >= static_cast(y); \ + } \ + CK_TILE_HOST_DEVICE \ + type_ operator+(const type_& x, const type_& y) \ + { \ + return type_(static_cast(x) + static_cast(y)); \ + } \ + CK_TILE_HOST_DEVICE \ + type_ operator-(const type_& x) \ + { \ + constexpr uint32_t bits = sizeof(type_) * 8; \ + constexpr uint32_t mask = 1 << (bits - 1); \ + type_ y = x; \ + y.data ^= static_cast(mask); \ + return y; \ + } \ + CK_TILE_HOST_DEVICE \ + type_ operator-(const type_& x, const type_& y) \ + { \ + return type_(static_cast(x) - static_cast(y)); \ + } \ + CK_TILE_HOST_DEVICE \ + type_ operator*(const type_& x, const type_& y) \ + { \ + return type_(static_cast(x) * static_cast(y)); \ + } \ + CK_TILE_HOST_DEVICE \ + type_ operator/(const type_& x, const type_& y) \ + { \ + return type_(static_cast(x) / static_cast(y)); \ + } \ + CK_TILE_HOST_DEVICE \ + type_& operator+=(type_& x, const type_& y) \ + { \ + x = type_(static_cast(x) + static_cast(y)); \ + return x; \ + } \ + CK_TILE_HOST_DEVICE \ + type_& operator-=(type_& x, const type_& y) \ + { \ + x = type_(static_cast(x) - static_cast(y)); \ + return x; \ + } \ + CK_TILE_HOST_DEVICE \ + type_& operator*=(type_& x, const type_& y) \ + { \ + x = type_(static_cast(x) * static_cast(y)); \ + return x; \ + } \ + CK_TILE_HOST_DEVICE \ + type_& operator/=(type_& x, const type_& y) \ + { \ + x = type_(static_cast(x) / static_cast(y)); \ + return x; \ + } \ + CK_TILE_HOST_DEVICE \ + type_& operator++(type_& x) \ + { \ + x = type_(static_cast(x) + 1.f); \ + return x; \ + } \ + CK_TILE_HOST_DEVICE \ + type_& operator--(type_& x) \ + { \ + x = type_(static_cast(x) - 1.f); \ + return x; \ + } \ + CK_TILE_HOST_DEVICE \ + type_ operator++(type_& x, int) \ + { \ + type_ y(x); \ + x = type_(static_cast(x) + 1.f); \ + return y; \ + } \ + CK_TILE_HOST_DEVICE \ + type_ operator--(type_& x, int) \ + { \ + type_ y(x); \ + x = type_(static_cast(x) - 1.f); \ + return y; \ + } diff --git a/include/ck_tile/core/numeric/bfloat16.hpp b/include/ck_tile/core/numeric/bfloat16.hpp new file mode 100644 index 0000000000..d69024883e --- /dev/null +++ b/include/ck_tile/core/numeric/bfloat16.hpp @@ -0,0 +1,263 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" +#include "ck_tile/core/numeric/arithmetic.hpp" +#include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include + +#pragma once + +namespace ck_tile { + +enum class bf16_rounding_mode +{ + standard = 0, // rtn + truncate_with_nan, + truncate, +}; + +template +CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant = {}); + +CK_TILE_HOST_DEVICE +float bf16_to_float_raw(uint16_t x); + +// HIP use __hip_bfloat16 as struct +struct alignas(2) bfloat16_t +{ + using raw_type = uint16_t; + raw_type data; + + CK_TILE_HOST_DEVICE + static bfloat16_t bit_cast(raw_type x) + { + bfloat16_t y; + y.data = x; + return y; + } + + // constructor + bfloat16_t() = default; + + // construct from float + CK_TILE_HOST_DEVICE + explicit bfloat16_t(const float& x) { data = float_to_bf16_raw(x); } + + // construct from int + CK_TILE_HOST_DEVICE + explicit bfloat16_t(const int& x) { data = float_to_bf16_raw(static_cast(x)); } + + // construct from unsigned int + CK_TILE_HOST_DEVICE + explicit bfloat16_t(const unsigned int& x) { data = float_to_bf16_raw(static_cast(x)); } + + // cast to float + CK_TILE_HOST_DEVICE + explicit operator float() const { return bf16_to_float_raw(data); } + + // cast to int + CK_TILE_HOST_DEVICE + explicit operator int() const { return static_cast(bf16_to_float_raw(data)); } + + // internal access + CK_TILE_HOST_DEVICE + raw_type& get() { return data; } + + CK_TILE_HOST_DEVICE + raw_type get() const { return data; } +}; + +// round to nearest +CK_TILE_HOST_DEVICE +uint16_t float_to_bf16_rtn_raw(float f) +{ + union + { + float fp32; + uint32_t int32; + } u = {f}; + if(~u.int32 & 0x7f800000) + { + // When the exponent bits are not all 1s, then the value is zero, normal, + // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus + // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). + // This causes the bfloat16's mantissa to be incremented by 1 if the 16 + // least significant bits of the float mantissa are greater than 0x8000, + // or if they are equal to 0x8000 and the least significant bit of the + // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when + // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already + // has the value 0x7f, then incrementing it causes it to become 0x00 and + // the exponent is incremented by one, which is the next higher FP value + // to the unrounded bfloat16 value. When the bfloat16 value is subnormal + // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up + // to a normal value with an exponent of 0x01 and a mantissa of 0x00. + // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, + // incrementing it causes it to become an exponent of 0xFF and a mantissa + // of 0x00, which is Inf, the next higher value to the unrounded value. + u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even + } + else if(u.int32 & 0xffff) + { + // When all of the exponent bits are 1, the value is Inf or NaN. + // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero + // mantissa bit. Quiet NaN is indicated by the most significant mantissa + // bit being 1. Signaling NaN is indicated by the most significant + // mantissa bit being 0 but some other bit(s) being 1. If any of the + // lower 16 bits of the mantissa are 1, we set the least significant bit + // of the bfloat16 mantissa, in order to preserve signaling NaN in case + // the bloat16's mantissa bits are all 0. + u.int32 |= 0x10000; // Preserve signaling NaN + } + return uint16_t(u.int32 >> 16); +} + +// Truncate instead of rounding, preserving SNaN +CK_TILE_HOST_DEVICE +uint16_t float_to_bf16_truc_nan_raw(float f) +{ + union + { + float fp32; + uint32_t int32; + } u = {f}; + return uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff)); +} + +// Fast truncate instead of rounding, RTZ +CK_TILE_HOST_DEVICE +uint16_t float_to_bf16_truc_raw(float f) +{ + union + { + float fp32; + uint32_t int32; + } u = {f}; + return uint16_t(u.int32 >> 16); +} + +template +CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant = {}) +{ + if constexpr(rounding == bf16_rounding_mode::standard) + return float_to_bf16_rtn_raw(f); + else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan) + return float_to_bf16_truc_nan_raw(f); + else + return float_to_bf16_truc_raw(f); +} + +CK_TILE_HOST_DEVICE +float bf16_to_float_raw(uint16_t x) +{ + union + { + uint32_t int32; + float fp32; + } u = {uint32_t(x) << 16}; + return u.fp32; +} + +template +CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant = {}) +{ + return bfloat16_t::bit_cast(float_to_bf16_raw(f, constant{})); +} + +CK_TILE_HOST_DEVICE +float bf16_to_float(bfloat16_t x) { return static_cast(x); } + +template +CK_TILE_HOST_DEVICE bfloat16_t fp16_to_bf16(half_t f, constant = {}) +{ + return bfloat16_t::bit_cast(float_to_bf16_raw(static_cast(f), constant{})); +} + +CK_TILE_HOST_DEVICE +float bf16_to_fp16(bfloat16_t x) { return float_to_fp16(static_cast(x)); } + +template +struct numeric_limits; + +template <> +struct numeric_limits +{ + // minimum finite value, or minimum positive normalized value for float + CK_TILE_HOST_DEVICE static constexpr bfloat16_t min() { return bfloat16_t::bit_cast(0x0080); } + + // minumum finite value + CK_TILE_HOST_DEVICE static constexpr bfloat16_t lowest() + { + return bfloat16_t::bit_cast(0xff7f); + } + + // maximum finite value + CK_TILE_HOST_DEVICE static constexpr bfloat16_t max() { return bfloat16_t::bit_cast(0x7f7f); } + + // difference between 1.0 and next value representable by float + CK_TILE_HOST_DEVICE static constexpr bfloat16_t epsilon() + { + return bfloat16_t::bit_cast(0x1000); + } + + // maximum rounding error + CK_TILE_HOST_DEVICE static constexpr bfloat16_t round_error() { return bfloat16_t(0.5f); } + + // positive infinity value + CK_TILE_HOST_DEVICE static constexpr bfloat16_t infinity() + { + return bfloat16_t::bit_cast(0x7f80); + } + + // quiet NaN + CK_TILE_HOST_DEVICE static constexpr bfloat16_t quiet_NaN() + { + return bfloat16_t::bit_cast(0x7FFF); + } + + // signaling NaN + CK_TILE_HOST_DEVICE static constexpr bfloat16_t signaling_NaN() + { + return bfloat16_t::bit_cast(0x7FFF); + } + + // smallest positive subnormal value + CK_TILE_HOST_DEVICE static constexpr bfloat16_t denorm_min() + { + return bfloat16_t::bit_cast(0x0001); + } +}; + +CK_TILE_ARITHMETIC_USING_FLOAT(bfloat16_t) + +// math +CK_TILE_HOST_DEVICE +bfloat16_t abs(const bfloat16_t& x) { return bfloat16_t::bit_cast(x.get() & 0x7fff); } + +CK_TILE_HOST_DEVICE +bool isnan(const bfloat16_t& x) +{ + uint16_t xx = x.get(); + return (xx & 0x7FFF) > 0x7C00; +} + +CK_TILE_DEVICE +bfloat16_t sqrt(bfloat16_t x) +{ + return static_cast(__builtin_amdgcn_sqrtf(static_cast(x))); +}; + +CK_TILE_DEVICE +bfloat16_t exp(bfloat16_t x) { return static_cast(__expf(static_cast(x))); }; + +CK_TILE_DEVICE +bfloat16_t exp2(bfloat16_t x) { return static_cast(exp2f(static_cast(x))); }; + +CK_TILE_DEVICE +bfloat16_t log(bfloat16_t x) { return static_cast(__logf(static_cast(x))); }; + +using bf16_t = bfloat16_t; + +} // namespace ck_tile diff --git a/include/ck_tile/core/numeric/float8.hpp b/include/ck_tile/core/numeric/float8.hpp new file mode 100644 index 0000000000..c9af63dd43 --- /dev/null +++ b/include/ck_tile/core/numeric/float8.hpp @@ -0,0 +1,735 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" +#include "ck_tile/core/utility/limits.hpp" +#include "ck_tile/core/utility/random.hpp" +#include "ck_tile/core/numeric/arithmetic.hpp" +#include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include +#include + +#pragma once + +namespace ck_tile { + +// fp8 rounding modes +// use standard for rounding to nearest, the faster one +// use stochastic for stochastic rounding, helps to avoid error accumulation +enum class fp8_rounding_mode +{ + standard = 0, + stochastic +}; + +/* + * ______________NANOO_________________ | ______________IEEE________________ + * e4m3 e5m2 | e4m3 e5m2 + * bias : 8 16 | 7 15 + * inf : 1.0000.000 1.00000.00 | N/A s.11111.00 + * Nan : 1.0000.000 1.00000.00 | s.1111.111 s.11111.{01, 10, 11} + * zero : 0.0000.000 0.00000.00 | s.0000.000 s.00000.00 + * Max(norm) : s.1111.111 (240) s.11111.11(57344) | s.1111.110(448) s.11110.11(57344) + * Max(snorm): s.0000.111 s.00000.11 | s.0000.111(448) s.00000.11(57344) + * 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05 + * Min(norm) : s.0001.000 s.00001.00 | s.0001.000 s.00001.00 + * 2^-7(0.00078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05) + * Min(snorm): s.0000.001 s.00000.01 | s.0000.001 s.00000.01 + * 2^-10(0.00097656) 2^-17(7.629395e-06)| 2^-9(0.001953125) 2^-16(1.52588e-05) + */ + +template (CK_TILE_FLOAT_TO_FP8_DEFAULT)> +CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float, constant = {}); + +template (CK_TILE_FLOAT_TO_FP8_DEFAULT)> +CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float, constant = {}); + +CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t); +CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t); + +struct alignas(1) float8_e4m3_t +{ + static constexpr int exponent = 4; + static constexpr int mantissa = 3; +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + static constexpr int bias = 1 << (exponent - 1); // NANOO +#else + static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE +#endif + using raw_type = uint8_t; + raw_type data; + + CK_TILE_HOST_DEVICE + static float8_e4m3_t bit_cast(raw_type x) + { + float8_e4m3_t y; + y.data = x; + return y; + } + + // constructor + float8_e4m3_t() = default; + + // construct from float + CK_TILE_HOST_DEVICE + explicit float8_e4m3_t(const float& x) { data = float_to_fp8_raw(x); } + + // construct from int + CK_TILE_HOST_DEVICE + explicit float8_e4m3_t(const int& x) { data = float_to_fp8_raw(static_cast(x)); } + + // construct from unsigned int + CK_TILE_HOST_DEVICE + explicit float8_e4m3_t(const unsigned int& x) + { + data = float_to_fp8_raw(static_cast(x)); + } + + // cast to float + CK_TILE_HOST_DEVICE + explicit operator float() const { return fp8_to_float_raw(data); } + + // cast to int + CK_TILE_HOST_DEVICE + explicit operator int() const { return static_cast(fp8_to_float_raw(data)); } + + // internal access + CK_TILE_HOST_DEVICE + raw_type& get() { return data; } + + CK_TILE_HOST_DEVICE + raw_type get() const { return data; } +}; + +struct alignas(1) float8_e5m2_t +{ + static constexpr int exponent = 5; + static constexpr int mantissa = 2; +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + static constexpr int bias = 1 << (exponent - 1); // NANOO +#else + static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE +#endif + using raw_type = uint8_t; + raw_type data; + + CK_TILE_HOST_DEVICE + static float8_e5m2_t bit_cast(raw_type x) + { + float8_e5m2_t y; + y.data = x; + return y; + } + + // constructor + float8_e5m2_t() = default; + + // construct from float + CK_TILE_HOST_DEVICE + explicit float8_e5m2_t(const float& x) { data = float_to_bf8_raw(x); } + + // construct from int + CK_TILE_HOST_DEVICE + explicit float8_e5m2_t(const int& x) { data = float_to_bf8_raw(static_cast(x)); } + + // construct from unsigned int + CK_TILE_HOST_DEVICE + explicit float8_e5m2_t(const unsigned int& x) + { + data = float_to_bf8_raw(static_cast(x)); + } + + // cast to float + CK_TILE_HOST_DEVICE + explicit operator float() const { return bf8_to_float_raw(data); } + + // cast to int + CK_TILE_HOST_DEVICE + explicit operator int() const { return static_cast(bf8_to_float_raw(data)); } + + // internal access + CK_TILE_HOST_DEVICE + raw_type& get() { return data; } + + CK_TILE_HOST_DEVICE + raw_type get() const { return data; } +}; + +// below is sw fp8 conversion, not utilizing hw instruction +namespace impl { + +template +CK_TILE_HOST_DEVICE Y run_cast_to_f8(X x, uint32_t rng) +{ + // fp8/bf8 exponent/mantissa layout + constexpr int out_exp = numeric_utils::exp; + constexpr int out_mant = numeric_utils::mant; + + // original type exponent/mantissa layout + constexpr int in_exp = numeric_utils::exp; + constexpr int in_mant = numeric_utils::mant; + + int exponent, bias; + uint32_t head, mantissa, sign; + // nan code is same for float and half + constexpr Y nan_code = 0x80; + constexpr uint32_t nan_mask = numeric_utils::nan_mask; + + // convert to bitwise + using T_bitwise = typename numeric_utils::bitwise_type; + T_bitwise x_bitwise = *(reinterpret_cast(&x)); + + // unpack the input, depends on datatype + head = x_bitwise & numeric_utils::head_mask; + mantissa = x_bitwise & numeric_utils::mant_mask; + exponent = (head >> in_mant) & numeric_utils::exp_mask; + sign = head >> (in_exp + in_mant); + bias = numeric_utils::bias; + + uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant); + uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1; + constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2); + + if constexpr(negative_zero_nan) + { + if((x_bitwise & nan_mask) == nan_mask) + return nan_code; + } + else + { + if((x_bitwise & nan_mask) == nan_mask) + return signed_inf + (mantissa != 0 ? 1 : 0); + } + + // check if x is 0.0 + if(x_bitwise == 0) + return 0; + + // First need to check if it is normal or denorm as there is a difference of implict 1 + // Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift + // The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for + // RNE, no need to add rng. Then probably need to check whether there is carry and adjust + // exponent and mantissa again3 + + // For IEEE bias mode, the bias is 2^(k-1)-1 where k is the width of exponent bits + const int out_bias = (1 << (out_exp - 1)) - 1 + (negative_zero_nan ? 1 : 0); + const int out_denormal_act_exponent = 1 - out_bias; // actual exponent of f8 denormal + // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) + // out_exponent is the converted f8 exponent with bias encoding + // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, + // the difference needs to be adjusted and mantissa shifted + int act_exponent, out_exponent, exponent_diff; + + if(exponent == 0) + { // fp32/fp16 is in denormal. + /* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16 +here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has +exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in +fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers +where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal. +In this case, the fp16 mantissa should be shift left by 1 */ + act_exponent = exponent - bias + 1; + exponent_diff = out_denormal_act_exponent - + act_exponent; // actual exponent is exponent-bias+1 as it is denormal + } + else + { // fp32/fp16 is normal with implicit 1 + act_exponent = exponent - bias; + if(act_exponent <= out_denormal_act_exponent) + { + /* This is the case where fp32/fp16 is normal but it is in f8 denormal range. + For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16 + actual exponent is -7, it is actually larger due to the implict 1, + Therefore it needs to be adjust to -6 and mantissa shift right by 1. + So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ + exponent_diff = out_denormal_act_exponent - act_exponent; + } + else + { // both fp32/fp16 and f8 are in normal range + exponent_diff = + 0; // exponent_diff=0 does not mean there is no difference for this case, + // act_exponent could be larger. Just that it does not need shift mantissa + } + mantissa += (1 << in_mant); // Add the implicit 1 into mantissa + } + + bool midpoint = (mantissa & ((1 << (in_mant - out_mant + exponent_diff)) - 1)) == + (1 << (in_mant - out_mant + exponent_diff - 1)); + /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we + shift right as shift right could rip off some residual part and make something not midpoint look + like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than + midpoint, but after shift right by 4 bits, it would look like midpoint. */ + + if(exponent_diff > 0) + mantissa >>= exponent_diff; + else if(exponent_diff == -1) + mantissa <<= -exponent_diff; + bool implicit_one = mantissa & (1 << in_mant); + // if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent + out_exponent = + (act_exponent + exponent_diff) /*actual f8 exponent*/ + out_bias - (implicit_one ? 0 : 1); + + // Now we have the exponent and mantissa adjusted + bool odd = + mantissa & + (1 << (in_mant - out_mant)); // if the least significant bit that is not truncated is 1 + mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask; + + // Now we deal with overflow + if(out_exponent == 0) + { + if((1 << in_mant) & mantissa) + { + out_exponent = 1; // denormal overflow to become normal, promote exponent + // No need to make 1 implicit now as it will be addressed later + } + } + else + { + if((1 << (in_mant + 1)) & mantissa) + { + mantissa >>= 1; + out_exponent++; + // No need to make 1 implicit now as it will be addressed later + } + } + + mantissa >>= (in_mant - out_mant); + + if(out_exponent > max_exp) + { + if(clip) + { + mantissa = (1 << out_mant) - 1; + out_exponent = max_exp; + } + else + { + return signed_inf; + } + } + + // check if x is 0.0 or -0.0 + if(out_exponent == 0 && mantissa == 0) + return negative_zero_nan ? 0 : (sign << (out_exp + out_mant)); + mantissa &= (1 << out_mant) - 1; + return (sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa; +} + +template +CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x) +{ + // fp8/bf8 exponent/mantissa layout + constexpr int in_exp = numeric_utils::exp; + constexpr int in_mant = numeric_utils::mant; + + // resulting type exponent/mantissa layout + constexpr int out_exp = numeric_utils::exp; + constexpr int out_mant = numeric_utils::mant; + + // prepare the codes + constexpr X nan_code = 0x80; + Y Inf, NegInf, NaN, Neg0; + using T_bitwise = typename numeric_utils::bitwise_type; + + constexpr T_bitwise Inf_bitwise = numeric_utils::Inf; + constexpr T_bitwise NegInf_bitwise = numeric_utils::NegInf; + constexpr T_bitwise NaN_bitwise = numeric_utils::NaN; + constexpr T_bitwise Neg0_bitwise = numeric_utils::Neg0; + + Inf = *(reinterpret_cast(&Inf_bitwise)); + NegInf = *(reinterpret_cast(&NegInf_bitwise)); + NaN = *(reinterpret_cast(&NaN_bitwise)); + Neg0 = *(reinterpret_cast(&Neg0_bitwise)); + + // check if x is 0.0 + if(x == 0) + return static_cast(0); + + // unpack the input + uint32_t sign = x >> (in_exp + in_mant); + uint32_t mantissa = x & ((1 << in_mant) - 1); + int exponent = (x & 0x7F) >> in_mant; + + constexpr int exp_low_cutoff = + (1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); + T_bitwise retval; + + if constexpr(negative_zero_nan) + { + if(x == nan_code) + return NaN; + } + else + { + if(x == nan_code) + return Neg0; + if(exponent == ((1 << in_exp) - 1)) + return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN; + } + + if((numeric_utils::mant == 10) && (numeric_utils::mant == 2) && !negative_zero_nan) + { + retval = x; + retval <<= 8; + return *(reinterpret_cast(&retval)); + } + + // subnormal input + if(exponent == 0) + { + // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above + int sh = 1 + clz(mantissa) - (32 - in_mant); + mantissa <<= sh; + exponent += 1 - sh; + mantissa &= ((1 << in_mant) - 1); + } + exponent += exp_low_cutoff - 1; + mantissa <<= out_mant - in_mant; + + // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) + if(exponent <= 0) + { + mantissa |= 1 << out_mant; + mantissa >>= 1 - exponent; + exponent = 0; + } + + retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa; + return *(reinterpret_cast(&retval)); +} + +template +CK_TILE_HOST_DEVICE Y cast_to_f8(X x, uint32_t rng) +{ + // check datatypes + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; + static_assert(is_half || is_float, "Only half and float can be casted."); + + return run_cast_to_f8(x, rng); +} + +template +CK_TILE_HOST_DEVICE Y cast_from_f8(X x) +{ + // check datatype + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; + static_assert(is_half || is_float, "only half and float are supported."); + + return run_cast_from_f8(x); +} +} // namespace impl + +CK_TILE_HOST_DEVICE uint8_t float_to_fp8_sr_raw(float x) +{ + constexpr int seed = 42; + uint32_t rng = prand_generator{}(reinterpret_cast(&x), x); +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + float max_fp8 = 240.0f; + x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); + union + { + float fval; + uint32_t i32val; + uint8_t i8val[4]; // not endian independent + } val; + val.fval = x; + uint32_t ival = 0; + ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos + val.i32val = ival; + return val.i8val[0]; // little endian +#else + constexpr bool negative_zero_nan = true; + constexpr bool clip = true; + constexpr fp8_rounding_mode rm = fp8_rounding_mode::stochastic; + return impl:: + cast_to_f8( + x, rng); +#endif +} + +CK_TILE_HOST_DEVICE uint8_t float_to_bf8_sr_raw(float x) +{ + constexpr int seed = 42; + uint32_t rng = prand_generator{}(reinterpret_cast(&x), x); +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + union + { + float fval; + uint32_t i32val; + uint8_t i8val[4]; // not endian independent + } val; + val.fval = x; + uint32_t ival = 0; + ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos + val.i32val = ival; + return val.i8val[0]; // little endian +#else + constexpr bool negative_zero_nan = true; + constexpr bool clip = true; + constexpr fp8_rounding_mode rm = fp8_rounding_mode::stochastic; + return impl:: + cast_to_f8( + x, rng); +#endif +} + +CK_TILE_HOST_DEVICE uint8_t float_to_fp8_rtn_raw(float x) +{ +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + float max_fp8 = 240.0f; + x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); + union + { + float fval; + uint32_t i32val; + uint8_t i8val[4]; // not endian independent + } val; + val.fval = x; + uint32_t ival = 0; + ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0 + val.i32val = ival; + return val.i8val[0]; +#else + constexpr bool negative_zero_nan = true; + constexpr bool clip = true; + constexpr fp8_rounding_mode rm = fp8_rounding_mode::standard; + constexpr uint32_t rng = 0; + return impl:: + cast_to_f8( + x, rng); +#endif +} +CK_TILE_HOST_DEVICE uint8_t float_to_bf8_rtn_raw(float x) +{ +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + union + { + float fval; + uint32_t i32val; + uint8_t i8val[4]; // not endian independent + } val; + val.fval = x; + uint32_t ival = 0; + ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0 + val.i32val = ival; + return val.i8val[0]; +#else + constexpr bool negative_zero_nan = true; + constexpr bool clip = true; + constexpr fp8_rounding_mode rm = fp8_rounding_mode::standard; + constexpr uint32_t rng = 0; + return impl:: + cast_to_f8( + x, rng); +#endif +} + +// clang-format off +template(CK_TILE_FLOAT_TO_FP8_DEFAULT)> +CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float x, constant = {}) +{ + if constexpr (rounding == fp8_rounding_mode::standard) return float_to_fp8_rtn_raw(x); + else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_fp8_sr_raw(x); + else return uint8_t{0}; +} + +template(CK_TILE_FLOAT_TO_FP8_DEFAULT)> +CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float x, constant = {}) +{ + if constexpr (rounding == fp8_rounding_mode::standard) return float_to_bf8_rtn_raw(x); + else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_bf8_sr_raw(x); + else return uint8_t{0}; +} + +CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t x) +{ +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + float fval; + uint32_t i32val = static_cast(x); + fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0); + // asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); + return fval; +#else + constexpr bool negative_zero_nan = true; + return impl::cast_from_f8(x); +#endif +} + +CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t x) +{ +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + float fval; + uint32_t i32val = static_cast(x); + fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0); + // asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); + return fval; +#else + constexpr bool negative_zero_nan = true; + return impl::cast_from_f8(x); +#endif +} + +template(CK_TILE_FLOAT_TO_FP8_DEFAULT)> +CK_TILE_HOST_DEVICE float8_e4m3_t float_to_fp8(float x, constant = {}) +{ + return float8_e4m3_t::bit_cast(float_to_fp8_raw(x, constant{})); +} + +template(CK_TILE_FLOAT_TO_FP8_DEFAULT)> +CK_TILE_HOST_DEVICE float8_e5m2_t float_to_bf8(float x, constant = {}) +{ + return float8_e5m2_t::bit_cast(float_to_bf8_raw(x, constant{})); +} + +CK_TILE_HOST_DEVICE float fp8_to_float(float8_e4m3_t x) +{ + return fp8_to_float_raw(x.get()); +} + +CK_TILE_HOST_DEVICE float bf8_to_float(float8_e5m2_t x) +{ + return bf8_to_float_raw(x.get()); +} + +// clang-format on +using fp8_t = float8_e4m3_t; +using bf8_t = float8_e5m2_t; + +template +struct numeric_utils; + +template <> +struct numeric_utils +{ + static constexpr int exp = fp8_t::exponent; + static constexpr int mant = fp8_t::mantissa; + static constexpr int bias = fp8_t::bias; +}; + +template <> +struct numeric_utils +{ + static constexpr int exp = bf8_t::exponent; + static constexpr int mant = bf8_t::mantissa; + static constexpr int bias = bf8_t::bias; +}; + +template +struct numeric_limits; + +template <> +struct numeric_limits +{ + // minimum finite value, or minimum positive normalized value for float + CK_TILE_HOST_DEVICE static constexpr fp8_t min() { return fp8_t::bit_cast(0x08); } + + // minumum finite value + CK_TILE_HOST_DEVICE static constexpr fp8_t lowest() { return fp8_t::bit_cast(0xff); } + + // maximum finite value + CK_TILE_HOST_DEVICE static constexpr fp8_t max() { return fp8_t::bit_cast(0x7f); } + + // difference between 1.0 and next value representable by float + CK_TILE_HOST_DEVICE static constexpr fp8_t epsilon() { return fp8_t::bit_cast(0x20); } + + // maximum rounding error + CK_TILE_HOST_DEVICE static constexpr fp8_t round_error() { return fp8_t(0.5f); } + + // positive infinity value + CK_TILE_HOST_DEVICE static constexpr fp8_t infinity() { return fp8_t::bit_cast(0x80); } + + // quiet NaN + CK_TILE_HOST_DEVICE static constexpr fp8_t quiet_NaN() { return fp8_t::bit_cast(0x80); } + + // signaling NaN + CK_TILE_HOST_DEVICE static constexpr fp8_t signaling_NaN() { return fp8_t::bit_cast(0x80); } + + // smallest positive subnormal value + CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min() { return fp8_t::bit_cast(0x01); } +}; + +template <> +struct numeric_limits +{ + // minimum finite value, or minimum positive normalized value for float + CK_TILE_HOST_DEVICE static constexpr bf8_t min() { return bf8_t::bit_cast(0x04); } + + // minumum finite value + CK_TILE_HOST_DEVICE static constexpr bf8_t lowest() { return bf8_t::bit_cast(0xff); } + + // maximum finite value + CK_TILE_HOST_DEVICE static constexpr bf8_t max() { return bf8_t::bit_cast(0x7f); } + + // difference between 1.0 and next value representable by float + CK_TILE_HOST_DEVICE static constexpr bf8_t epsilon() { return bf8_t::bit_cast(0x34); } + + // maximum rounding error + CK_TILE_HOST_DEVICE static constexpr bf8_t round_error() { return bf8_t(0.5f); } + + // positive infinity value + CK_TILE_HOST_DEVICE static constexpr bf8_t infinity() { return bf8_t::bit_cast(0x80); } + + // quiet NaN + CK_TILE_HOST_DEVICE static constexpr bf8_t quiet_NaN() { return bf8_t::bit_cast(0x80); } + + // signaling NaN + CK_TILE_HOST_DEVICE static constexpr bf8_t signaling_NaN() { return bf8_t::bit_cast(0x80); } + + // smallest positive subnormal value + CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min() { return bf8_t::bit_cast(0x01); } +}; + +CK_TILE_ARITHMETIC_USING_FLOAT(fp8_t) +CK_TILE_ARITHMETIC_USING_FLOAT(bf8_t) + +// math +CK_TILE_HOST_DEVICE +fp8_t abs(const fp8_t& x) { return fp8_t::bit_cast(x.get() & 0x7f); } + +CK_TILE_HOST_DEVICE +bool isnan(const fp8_t& x) +{ + uint8_t xx = x.get(); + return xx == 0x80; // TODO: NANOO +} + +CK_TILE_DEVICE +fp8_t sqrt(fp8_t x) { return static_cast(__builtin_amdgcn_sqrtf(static_cast(x))); }; + +CK_TILE_DEVICE +fp8_t exp(fp8_t x) { return static_cast(__expf(static_cast(x))); }; + +CK_TILE_DEVICE +fp8_t exp2(fp8_t x) { return static_cast(exp2f(static_cast(x))); }; + +CK_TILE_DEVICE +fp8_t log(fp8_t x) { return static_cast(__logf(static_cast(x))); }; + +CK_TILE_HOST_DEVICE +bf8_t abs(const bf8_t& x) { return bf8_t::bit_cast(x.get() & 0x7f); } + +CK_TILE_HOST_DEVICE +bool isnan(const bf8_t& x) +{ + uint8_t xx = x.get(); + return xx == 0x80; // TODO: NANOO +} + +CK_TILE_DEVICE +bf8_t sqrt(bf8_t x) { return static_cast(__builtin_amdgcn_sqrtf(static_cast(x))); }; + +CK_TILE_DEVICE +bf8_t exp(bf8_t x) { return static_cast(__expf(static_cast(x))); }; + +CK_TILE_DEVICE +bf8_t exp2(bf8_t x) { return static_cast(exp2f(static_cast(x))); }; + +CK_TILE_DEVICE +bf8_t log(bf8_t x) { return static_cast(__logf(static_cast(x))); }; + +} // namespace ck_tile diff --git a/include/ck_tile/core/numeric/half.hpp b/include/ck_tile/core/numeric/half.hpp new file mode 100644 index 0000000000..4a6fc59c28 --- /dev/null +++ b/include/ck_tile/core/numeric/half.hpp @@ -0,0 +1,278 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" +#include + +#pragma once + +namespace ck_tile { + +CK_TILE_HOST_DEVICE +float fp16_to_float_hip(const _Float16& x); + +CK_TILE_HOST_DEVICE +_Float16 float_to_fp16_hip(const float& x); + +// HIP use _Float16 as interchangable data type for float16 +struct alignas(2) half_t +{ + using raw_type = uint16_t; + raw_type data; + + CK_TILE_HOST_DEVICE + static half_t bit_cast(raw_type x) + { + half_t y; + y.data = x; + return y; + } + + CK_TILE_HOST_DEVICE + _Float16 to_fp16() const { return reinterpret_cast(data); } + + // constructor + half_t() = default; + + // construct from HIP half + CK_TILE_HOST_DEVICE + explicit half_t(const _Float16& x) : data(reinterpret_cast(x)) {} + + // construct from float + CK_TILE_HOST_DEVICE + explicit half_t(const float& x) : half_t(float_to_fp16_hip(x)) {} + + // construct from int + CK_TILE_HOST_DEVICE + explicit half_t(const int& x) : half_t(__int2half_rn(x)) {} + + // construct from unsigned int + CK_TILE_HOST_DEVICE + explicit half_t(const unsigned int& x) : half_t(__uint2half_rn(x)) {} + + // cast to float + CK_TILE_HOST_DEVICE + explicit operator float() const { return fp16_to_float_hip(to_fp16()); } + + // cast to int + CK_TILE_HOST_DEVICE + explicit operator int() const { return static_cast(fp16_to_float_hip(to_fp16())); } + + // internal access + CK_TILE_HOST_DEVICE + raw_type& get() { return data; } + + CK_TILE_HOST_DEVICE + raw_type get() const { return data; } +}; + +// conversions +CK_TILE_HOST_DEVICE +float fp16_to_float_hip(const _Float16& x) +{ + // return __half2float(x); + return static_cast(x); +} + +CK_TILE_HOST_DEVICE +_Float16 float_to_fp16_hip(const float& x) +{ + // return __float2half(x); + return static_cast<_Float16>(x); +} + +CK_TILE_HOST_DEVICE +float fp16_to_float(const half_t& x) { return static_cast(x); } + +CK_TILE_HOST_DEVICE +half_t float_to_fp16(const float& x) { return half_t{x}; } + +// limits +template +struct numeric_limits; + +template <> +struct numeric_limits +{ + // minimum finite value, or minimum positive normalized value for float + CK_TILE_HOST_DEVICE static constexpr half_t min() { return half_t::bit_cast(0x0400); } + + // minumum finite value + CK_TILE_HOST_DEVICE static constexpr half_t lowest() { return half_t::bit_cast(0xFBFF); } + + // maximum finite value + CK_TILE_HOST_DEVICE static constexpr half_t max() { return half_t::bit_cast(0x7BFF); } + + // difference between 1.0 and next value representable by float + CK_TILE_HOST_DEVICE static constexpr half_t epsilon() { return half_t::bit_cast(0x1800); } + + // maximum rounding error + CK_TILE_HOST_DEVICE static constexpr half_t round_error() { return half_t(0.5f); } + + // positive infinity value + CK_TILE_HOST_DEVICE static constexpr half_t infinity() { return half_t::bit_cast(0x7C00); } + + // quiet NaN + CK_TILE_HOST_DEVICE static constexpr half_t quiet_NaN() { return half_t::bit_cast(0x7FFF); } + + // signaling NaN + CK_TILE_HOST_DEVICE static constexpr half_t signaling_NaN() { return half_t::bit_cast(0x7FFF); } + + // smallest positive subnormal value + CK_TILE_HOST_DEVICE static constexpr half_t denorm_min() { return half_t::bit_cast(0x0001); } +}; + +template +struct numeric_utils; + +template <> +struct numeric_utils +{ + static constexpr int exp = 5; + static constexpr int mant = 10; + static constexpr int bias = 15; + static constexpr uint16_t nan_mask = 0x7C00; + static constexpr uint16_t head_mask = 0xFC00; + static constexpr uint16_t mant_mask = 0x3FF; + static constexpr uint16_t exp_mask = 0x1F; + static constexpr uint32_t Inf = 0x7C00; + static constexpr uint32_t NegInf = 0xFC00; + static constexpr uint32_t NaN = 0x7C01; + static constexpr uint32_t Neg0 = 0x8000; + using bitwise_type = uint16_t; +}; + +// arithmetic +CK_TILE_HOST_DEVICE +bool operator==(const half_t& x, const half_t& y) { return __heq(x.to_fp16(), y.to_fp16()); } + +CK_TILE_HOST_DEVICE +bool operator!=(const half_t& x, const half_t& y) { return __hne(x.to_fp16(), y.to_fp16()); } + +CK_TILE_HOST_DEVICE +bool operator<(const half_t& x, const half_t& y) { return __hlt(x.to_fp16(), y.to_fp16()); } + +CK_TILE_HOST_DEVICE +bool operator<=(const half_t& x, const half_t& y) { return __hle(x.to_fp16(), y.to_fp16()); } + +CK_TILE_HOST_DEVICE +bool operator>(const half_t& x, const half_t& y) { return __hgt(x.to_fp16(), y.to_fp16()); } + +CK_TILE_HOST_DEVICE +bool operator>=(const half_t& x, const half_t& y) { return __hge(x.to_fp16(), y.to_fp16()); } + +CK_TILE_HOST_DEVICE +half_t operator+(const half_t& x, const half_t& y) +{ + return half_t(__hadd(x.to_fp16(), y.to_fp16())); +} + +CK_TILE_HOST_DEVICE +half_t operator-(const half_t& x) { return half_t(__hneg(x.to_fp16())); } + +CK_TILE_HOST_DEVICE +half_t operator-(const half_t& x, const half_t& y) +{ + return half_t(__hsub(x.to_fp16(), y.to_fp16())); +} + +CK_TILE_HOST_DEVICE +half_t operator*(const half_t& x, const half_t& y) +{ + return half_t(__hmul(x.to_fp16(), y.to_fp16())); +} + +CK_TILE_HOST_DEVICE +half_t operator/(const half_t& x, const half_t& y) +{ + return half_t(__hdiv(x.to_fp16(), y.to_fp16())); +} + +CK_TILE_HOST_DEVICE +half_t& operator+=(half_t& x, const half_t& y) +{ + x = half_t(__hadd(x.to_fp16(), y.to_fp16())); + return x; +} + +CK_TILE_HOST_DEVICE +half_t& operator-=(half_t& x, const half_t& y) +{ + x = half_t(__hsub(x.to_fp16(), y.to_fp16())); + return x; +} + +CK_TILE_HOST_DEVICE +half_t& operator*=(half_t& x, const half_t& y) +{ + x = half_t(__hmul(x.to_fp16(), y.to_fp16())); + return x; +} + +CK_TILE_HOST_DEVICE +half_t& operator/=(half_t& x, const half_t& y) +{ + x = half_t(__hdiv(x.to_fp16(), y.to_fp16())); + return x; +} + +CK_TILE_HOST_DEVICE +half_t& operator++(half_t& x) +{ + x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16())); + return x; +} + +CK_TILE_HOST_DEVICE +half_t& operator--(half_t& x) +{ + x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16())); + return x; +} + +CK_TILE_HOST_DEVICE +half_t operator++(half_t& x, int) +{ + half_t y(x); + x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16())); + return y; +} + +CK_TILE_HOST_DEVICE +half_t operator--(half_t& x, int) +{ + half_t y(x); + x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16())); + return y; +} + +// math +CK_TILE_HOST_DEVICE +half_t abs(const half_t& x) { return half_t::bit_cast(x.get() & 0x7fff); } + +CK_TILE_HOST_DEVICE +bool isnan(const half_t& x) +{ + uint16_t xx = x.get(); + return (xx & 0x7FFF) > 0x7C00; +} + +CK_TILE_DEVICE +half_t sqrt(half_t x) +{ + return static_cast(__builtin_amdgcn_sqrtf(static_cast(x))); +}; + +CK_TILE_DEVICE +half_t exp(half_t x) { return static_cast(__expf(static_cast(x))); }; + +CK_TILE_DEVICE +half_t exp2(half_t x) { return static_cast(exp2f(static_cast(x))); }; + +CK_TILE_DEVICE +half_t log(half_t x) { return static_cast(__logf(static_cast(x))); }; + +using fp16_t = half_t; + +} // namespace ck_tile diff --git a/include/ck_tile/core/numeric/integer.hpp b/include/ck_tile/core/numeric/integer.hpp new file mode 100644 index 0000000000..3faf3020a6 --- /dev/null +++ b/include/ck_tile/core/numeric/integer.hpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include + +namespace ck_tile { + +using index_t = int32_t; +using long_index_t = int64_t; +using int8_t = int8_t; + +} // namespace ck_tile diff --git a/include/ck_tile/core/numeric/integral_constant.hpp b/include/ck_tile/core/numeric/integral_constant.hpp new file mode 100644 index 0000000000..9021b30efd --- /dev/null +++ b/include/ck_tile/core/numeric/integral_constant.hpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" + +namespace ck_tile { + +template +struct constant +{ + using value_type = decltype(v); + using type = constant; // using injected-class-name + static constexpr value_type value = v; + constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; } + constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; } +}; + +template +struct integral_constant : constant +{ + using value_type = T; + using type = integral_constant; // using injected-class-name + static constexpr T value = v; + // constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; } + // constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; } // +}; + +template +using number = constant; + +template +using long_number = integral_constant; + +template +using bool_constant = constant; + +#define CK_TILE_LEFT_UNARY_OP(OP) \ + template \ + CK_TILE_HOST_DEVICE constexpr auto operator OP(constant) \ + { \ + return constant<(OP x)>{}; \ + } + +#define CK_TILE_BINARY_OP(OP) \ + template \ + CK_TILE_HOST_DEVICE constexpr auto operator OP(constant, constant) \ + { \ + return constant<(x OP y)>{}; \ + } + +CK_TILE_LEFT_UNARY_OP(+) +CK_TILE_LEFT_UNARY_OP(-) +CK_TILE_LEFT_UNARY_OP(~) +CK_TILE_LEFT_UNARY_OP(!) +CK_TILE_LEFT_UNARY_OP(*) + +CK_TILE_BINARY_OP(+) +CK_TILE_BINARY_OP(-) +CK_TILE_BINARY_OP(*) +CK_TILE_BINARY_OP(/) +CK_TILE_BINARY_OP(%) +CK_TILE_BINARY_OP(&) +CK_TILE_BINARY_OP(|) +CK_TILE_BINARY_OP(^) +CK_TILE_BINARY_OP(<<) +CK_TILE_BINARY_OP(>>) +CK_TILE_BINARY_OP(&&) +CK_TILE_BINARY_OP(||) +CK_TILE_BINARY_OP(==) +CK_TILE_BINARY_OP(!=) +CK_TILE_BINARY_OP(>) +CK_TILE_BINARY_OP(<) +CK_TILE_BINARY_OP(>=) +CK_TILE_BINARY_OP(<=) + +#undef CK_TILE_LEFT_UNARY_OP +#undef CK_TILE_BINARY_OP + +} // namespace ck_tile diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp new file mode 100644 index 0000000000..98b3775285 --- /dev/null +++ b/include/ck_tile/core/numeric/math.hpp @@ -0,0 +1,309 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include +#include + +namespace ck_tile { + +template +struct scales +{ + CK_TILE_HOST_DEVICE constexpr T operator()(T a) const { return s * a; } +}; + +template +struct plus +{ + CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a + b; } +}; + +template +struct minus +{ + CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a - b; } +}; + +struct multiplies +{ + template + CK_TILE_HOST_DEVICE constexpr auto operator()(const A& a, const B& b) const + { + return a * b; + } +}; + +template +struct maximize +{ + CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a >= b ? a : b; } +}; + +template +struct minimize +{ + CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a <= b ? a : b; } +}; + +template +struct integer_divide_ceiler +{ + CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const + { + static_assert(std::is_same{} || std::is_same{}, "wrong type"); + return (a + b - number<1>{}) / b; + } +}; + +template +CK_TILE_HOST_DEVICE constexpr auto integer_divide_floor(X x, Y y) +{ + return x / y; +} + +template +CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y) +{ + return (x + y - number<1>{}) / y; +} + +template +CK_TILE_HOST_DEVICE constexpr auto integer_least_multiple(X x, Y y) +{ + return y * integer_divide_ceil(x, y); +} + +template +CK_TILE_HOST_DEVICE constexpr T max(T x) +{ + return x; +} + +template +CK_TILE_HOST_DEVICE constexpr T max(T x, T y) +{ + return x > y ? x : y; +} + +template +CK_TILE_HOST_DEVICE constexpr index_t max(number, index_t y) +{ + return X > y ? X : y; +} + +template +CK_TILE_HOST_DEVICE constexpr index_t max(index_t x, number) +{ + return x > Y ? x : Y; +} + +template +CK_TILE_HOST_DEVICE constexpr auto max(X x, Ys... ys) +{ + static_assert(sizeof...(Ys) > 0, "not enough argument"); + return max(x, max(ys...)); +} + +template +CK_TILE_HOST_DEVICE constexpr T min(T x) +{ + return x; +} + +template +CK_TILE_HOST_DEVICE constexpr T min(T x, T y) +{ + return x < y ? x : y; +} + +template +CK_TILE_HOST_DEVICE constexpr index_t min(number, index_t y) +{ + return X < y ? X : y; +} + +template +CK_TILE_HOST_DEVICE constexpr index_t min(index_t x, number) +{ + return x < Y ? x : Y; +} + +template +CK_TILE_HOST_DEVICE constexpr auto min(X x, Ys... ys) +{ + static_assert(sizeof...(Ys) > 0, "not enough argument"); + return min(x, min(ys...)); +} + +template +CK_TILE_HOST_DEVICE constexpr T clamp(const T& x, const T& lowerbound, const T& upperbound) +{ + return min(max(x, lowerbound), upperbound); +} + +// greatest common divisor, aka highest common factor +CK_TILE_HOST_DEVICE constexpr index_t gcd(index_t x, index_t y) +{ + if(x < 0) + { + return gcd(-x, y); + } + else if(y < 0) + { + return gcd(x, -y); + } + else if(x == y || x == 0) + { + return y; + } + else if(y == 0) + { + return x; + } + else if(x > y) + { + return gcd(x % y, y); + } + else + { + return gcd(x, y % x); + } +} + +template +CK_TILE_HOST_DEVICE constexpr auto gcd(number, number) +{ + constexpr auto r = gcd(X, Y); + + return number{}; +} + +template = 2, bool>::type = false> +CK_TILE_HOST_DEVICE constexpr auto gcd(X x, Ys... ys) +{ + return gcd(x, gcd(ys...)); +} + +// least common multiple +template +CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Y y) +{ + return (x * y) / gcd(x, y); +} + +template = 2, bool>::type = false> +CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Ys... ys) +{ + return lcm(x, lcm(ys...)); +} + +template +struct equal +{ + CK_TILE_HOST_DEVICE constexpr bool operator()(T x, T y) const { return x == y; } +}; + +template +struct less +{ + CK_TILE_HOST_DEVICE constexpr bool operator()(T x, T y) const { return x < y; } +}; + +CK_TILE_HOST_DEVICE constexpr int32_t next_power_of_two(int32_t x) +{ + // TODO: x need to be 2 ~ 0x7fffffff. 0, 1, or larger than 0x7fffffff will compile fail + return 1 << (32 - __builtin_clz(x - 1)); +} + +template +CK_TILE_HOST_DEVICE constexpr auto next_power_of_two() +{ + constexpr index_t y = next_power_of_two(X); + return number{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto next_power_of_two(number) +{ + constexpr index_t y = next_power_of_two(X); + return number{}; +} + +CK_TILE_HOST_DEVICE constexpr int32_t integer_log2_floor(int32_t x) +{ + // TODO: x need to be 1 ~ 0x7fffffff + // __builtin_clz will produce unexpected result if x is 0; + return 31 - __builtin_clz(x); +} + +CK_TILE_HOST_DEVICE constexpr bool is_power_of_two_integer(int32_t x) +{ + // TODO: x need to be 1 ~ 0x7fffffff + return x == (1 << integer_log2_floor(x)); +} + +#ifndef C_LOG2E +#define C_LOG2E 1.44269504088896340736 // log2(e) +#endif + +template +struct log2e; + +template <> +struct log2e +{ + static constexpr double value = C_LOG2E; +}; + +template <> +struct log2e +{ + static constexpr float value = C_LOG2E; +}; + +template +inline constexpr T log2e_v = log2e::value; + +// math +CK_TILE_HOST_DEVICE +float abs(const float& x) +{ + union + { + float f32; + uint32_t u32; + } y; + y.f32 = x; + y.u32 = y.u32 & 0x7fffffff; + return y.f32; +} + +CK_TILE_HOST_DEVICE +bool isnan(const float& x) +{ + uint32_t xx = reinterpret_cast(x); + return (xx & 0x7fffffff) > 0x7F800000; +} + +CK_TILE_DEVICE +float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); }; + +CK_TILE_DEVICE +float exp(float x) { return __expf(x); }; + +CK_TILE_DEVICE +float exp2(float x) { return exp2f(x); }; + +CK_TILE_DEVICE +float log(float x) { return __logf(x); }; + +} // namespace ck_tile diff --git a/include/ck_tile/core/numeric/type_convert.hpp b/include/ck_tile/core/numeric/type_convert.hpp new file mode 100644 index 0000000000..5eac399bf7 --- /dev/null +++ b/include/ck_tile/core/numeric/type_convert.hpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include + +namespace ck_tile { +#if 0 +// Convert X to Y, both X and Y are non-const data types. +template || std::is_const_v), bool> = false> +CK_TILE_HOST_DEVICE constexpr Y type_convert(X x) +{ + static_assert(!std::is_reference_v && !std::is_reference_v); + + return static_cast(x); +} + +// TODO: const version never called, we may never need +// Convert X to Y, either X or Y is a const data type. +template || std::is_const_v, bool> = false> +CK_TILE_HOST_DEVICE constexpr Y type_convert(X x) +{ + static_assert(!std::is_reference_v && !std::is_reference_v); + + using NonConstY = std::remove_const_t; + using NonConstX = std::remove_const_t; + return static_cast(type_convert(x)); +} +#else +// compatible way to call conversion operator and constructor of each custom data type +template +CK_TILE_HOST_DEVICE constexpr Y type_convert(X x) +{ + static_assert(!std::is_reference_v && !std::is_reference_v); + + return static_cast(x); +} +#endif +} // namespace ck_tile diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp new file mode 100644 index 0000000000..56baba0567 --- /dev/null +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -0,0 +1,304 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/numeric/float8.hpp" +#include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/numeric/bfloat16.hpp" + +namespace ck_tile { + +// TODO: the whole content of this file should consider deprecated! +template +struct vector_type +{ + static constexpr index_t N = N_; + using value_type = T_; + using type = value_type __attribute__((ext_vector_type(N))); // this is danguous + + CK_HOST_DEVICE constexpr vector_type() + { + for(auto i = 0; i < N; i++) + data[i] = static_cast(0); + } + CK_HOST_DEVICE constexpr vector_type(type v) + { + auto& r = reinterpret_cast&>(v); + for(auto i = 0; i < N; i++) + data[i] = r.get(i); + } + + value_type data[N]; + CK_HOST_DEVICE static constexpr auto size() { return N; } + CK_HOST_DEVICE auto& get() { return data; } + CK_HOST_DEVICE const auto& get() const { return data; } + CK_HOST_DEVICE auto& get(index_t i) { return data[i]; } + CK_HOST_DEVICE const auto& get(index_t i) const { return data[i]; } + + template + CK_HOST_DEVICE auto& operator[](number) + { + return data[I]; + } + template + CK_HOST_DEVICE const auto& operator[](number) const + { + return data[I]; + } + template + CK_HOST_DEVICE auto& operator()(number) + { + return data[I]; + } + + CK_HOST_DEVICE auto& at(index_t i) { return data[i]; } + CK_HOST_DEVICE const auto& at(index_t i) const { return data[i]; } + template + CK_HOST_DEVICE auto& at() + { + return data[I]; + } + template + CK_HOST_DEVICE const auto& at() const + { + return data[I]; + } + template + CK_HOST_DEVICE auto& at(number) + { + return data[I]; + } + template + CK_HOST_DEVICE const auto& at(number) const + { + return data[I]; + } + +#define _VT_COMMON_AS() \ + static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \ + constexpr int vx = sizeof(value_type) * N / sizeof(Tx) + + template + CK_HOST_DEVICE auto& get_as() + { + _VT_COMMON_AS(); + return reinterpret_cast&>(data); + } + template + CK_HOST_DEVICE const auto& get_as() const + { + _VT_COMMON_AS(); + return reinterpret_cast&>(data); + } + template + CK_HOST_DEVICE auto& get_as(index_t i) + { + _VT_COMMON_AS(); + return reinterpret_cast&>(data).get(i); + } + template + CK_HOST_DEVICE const auto& get_as(index_t i) const + { + _VT_COMMON_AS(); + return reinterpret_cast&>(data).get(i); + } +#undef _VT_COMMON_AS +}; + +template +struct vector_type_maker +{ + using type = vector_type; +}; + +template +struct vector_type_maker +{ + using type = vector_type; +}; + +template +struct vector_type_maker, N0> +{ + using type = vector_type; +}; + +template +using vector_type_maker_t = typename vector_type_maker::type; + +template +CK_HOST_DEVICE constexpr auto make_vector_type(number) +{ + return typename vector_type_maker::type{}; +} + +// scalar_type +template +struct scalar_type; + +// is_scalar_type +template +struct is_scalar_type +{ + static constexpr bool value = (scalar_type>::vector_size == 1); +}; + +// has_same_scalar_type +template +using has_same_scalar_type = is_same>::type, + typename scalar_type>::type>; + +template +struct scalar_type +{ + using type = T; + static constexpr index_t vector_size = N; +}; + +template +struct scalar_type> +{ + using type = T; + static constexpr index_t vector_size = N; +}; + +// +template <> +struct scalar_type +{ + using type = double; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = float; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = half_t; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = bhalf_t; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = int64_t; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = int32_t; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = int8_t; + static constexpr index_t vector_size = 1; +}; + +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +template <> +struct scalar_type +{ + using type = int4_t; + static constexpr index_t vector_size = 1; +}; +#endif + +template <> +struct scalar_type +{ + using type = fp8_t; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = bf8_t; + static constexpr index_t vector_size = 1; +}; + +// below are some pre-defines of ext_vector_type +// fp64 +using double2_t = typename vector_type::type; +using double4_t = typename vector_type::type; + +// fp32 +using float2_t = typename vector_type::type; +using float4_t = typename vector_type::type; +using float8_t = typename vector_type::type; +using float16_t = typename vector_type::type; +using float32_t = typename vector_type::type; +using float64_t = typename vector_type::type; + +// fp16 +using half2_t = typename vector_type::type; +using half4_t = typename vector_type::type; +using half8_t = typename vector_type::type; +using half16_t = typename vector_type::type; +using half32_t = typename vector_type::type; +using half64_t = typename vector_type::type; + +// bfp16 +using bhalf2_t = typename vector_type::type; +using bhalf4_t = typename vector_type::type; +using bhalf8_t = typename vector_type::type; +using bhalf16_t = typename vector_type::type; +using bhalf32_t = typename vector_type::type; +using bhalf64_t = typename vector_type::type; + +// i32 +using int32x2_t = typename vector_type::type; +using int32x4_t = typename vector_type::type; +using int32x8_t = typename vector_type::type; +using int32x16_t = typename vector_type::type; +using int32x32_t = typename vector_type::type; +using int32x64_t = typename vector_type::type; + +// i8 +using int8x2_t = typename vector_type::type; +using int8x4_t = typename vector_type::type; +using int8x8_t = typename vector_type::type; +using int8x16_t = typename vector_type::type; +using int8x32_t = typename vector_type::type; +using int8x64_t = typename vector_type::type; + +// f8 +using fp8x2_t = typename vector_type::type; +using fp8x4_t = typename vector_type::type; +using fp8x8_t = typename vector_type::type; +using fp8x16_t = typename vector_type::type; +using fp8x32_t = typename vector_type::type; +using fp8x64_t = typename vector_type::type; + +// bf8 +using bf8x2_t = typename vector_type::type; +using bf8x4_t = typename vector_type::type; +using bf8x8_t = typename vector_type::type; +using bf8x16_t = typename vector_type::type; +using bf8x32_t = typename vector_type::type; +using bf8x64_t = typename vector_type::type; + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp new file mode 100644 index 0000000000..bf75f9bffc --- /dev/null +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -0,0 +1,1041 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/arch/amd_address_space.hpp" +#include "ck_tile/core/arch/amd_buffer_addressing.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/numeric/float8.hpp" +#include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/numeric/bfloat16.hpp" + +namespace ck_tile { + +// T may be scalar or vector +// X may be scalar or vector +// T and X have same scalar type +// X contains multiple T +// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of +// transforms of tensor_view/Tensor +// FIXME: amd_buffer_coherence_enum is only meaningful for buffer addressing. Need to split +// BufferView definition for different memory address space (Global/GenericLds/Vgpr) +template +struct BufferView; + +// Address Space: generic +// T may be scalar or vector +// X may be scalar or vector +// T and X have same scalar type +// X contains multiple T +// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of +// transforms of tensor_view/Tensor +template +struct buffer_view +{ + using type = T; + + T* p_data_ = nullptr; + BufferSizeType buffer_size_; + remove_cvref_t invalid_element_value_ = T{0}; + + CK_TILE_HOST_DEVICE constexpr buffer_view() + : p_data_{}, buffer_size_{}, invalid_element_value_{} + { + } + + CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size) + : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0} + { + } + + CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, + BufferSizeType buffer_size, + T invalid_element_value) + : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value} + { + } + + CK_TILE_DEVICE static constexpr address_space_enum get_address_space() + { + return address_space_enum::generic; + } + + // i is offset of T + // FIXME: doesn't do is_valid check + CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; } + + // i is offset of T + // FIXME: doesn't do is_valid check + CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; } + + // i is offset of T, not X. i should be aligned to X + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + CK_TILE_DEVICE constexpr auto + get(index_t i, bool is_valid_element, bool_constant = {}) const + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + if(is_valid_element) + { +#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS + X tmp; + + __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X)); + + return tmp; +#else + return *c_style_pointer_cast(&p_data_[i]); +#endif + } + else + { + if constexpr(InvalidElementUseNumericalZeroValue) + { + return X{0}; + } + else + { + return X{invalid_element_value_}; + } + } + } + + // i is offset of T, not X. i should be aligned to X + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x) + { + if constexpr(Op == InMemoryDataOperationEnum::set) + { + this->template set(i, is_valid_element, x); + } + // FIXME: remove InMemoryDataOperationEnum::Add + else if constexpr(Op == InMemoryDataOperationEnum::Add) + { + auto tmp = this->template get(i, is_valid_element); + this->template set(i, is_valid_element, x + tmp); + } + } + + // i is offset of T, not X. i should be aligned to X + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x) + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + if(is_valid_element) + { +#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS + X tmp = x; + + __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X)); +#else + *c_style_pointer_cast(&p_data_[i]) = x; +#endif + } + } + + // FIXME: remove + CK_TILE_DEVICE static constexpr bool is_static_buffer() { return false; } + + // FIXME: remove + CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; } + + CK_TILE_HOST_DEVICE void print() const + { + printf("buffer_view{"); + + // AddressSpace + printf("AddressSpace: generic, "); + + // p_data_ + printf("p_data_: %p, ", static_cast(const_cast*>(p_data_))); + + // buffer_size_ + printf("buffer_size_: "); + print(buffer_size_); + printf(", "); + + // invalid_element_value_ + printf("invalid_element_value_: "); + print(invalid_element_value_); + + printf("}"); + } +}; + +// Address Space: Global +// T may be scalar or vector +// X may be scalar or vector +// T and X have same scalar type +// X contains multiple T +// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of +// transforms of tensor_view/Tensor +template +struct buffer_view +{ + using type = T; + + T* p_data_ = nullptr; + BufferSizeType buffer_size_; + remove_cvref_t invalid_element_value_ = T{0}; + + CK_TILE_HOST_DEVICE constexpr buffer_view() + : p_data_{}, buffer_size_{}, invalid_element_value_{} + { + } + + CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size) + : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0} + { + } + + CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, + BufferSizeType buffer_size, + T invalid_element_value) + : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value} + { + } + + CK_TILE_DEVICE static constexpr address_space_enum get_address_space() + { + return address_space_enum::global; + } + + // i is offset of T + // FIXME: doesn't do is_valid check + CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; } + + // i is offset of T + // FIXME: doesn't do is_valid check + CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; } + + // i is offset of T, not X. i should be aligned to X + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + CK_TILE_DEVICE constexpr auto + get(index_t i, bool is_valid_element, bool_constant = {}) const + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + +#if CK_USE_AMD_BUFFER_LOAD + bool constexpr use_amd_buffer_addressing = true; +#else + bool constexpr use_amd_buffer_addressing = false; +#endif + + if constexpr(use_amd_buffer_addressing) + { + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + if constexpr(InvalidElementUseNumericalZeroValue) + { + return amd_buffer_load_invalid_element_return_zero, + t_per_x, + Coherence, + oob_conditional_check>( + p_data_, i, is_valid_element, buffer_size_); + } + else + { + return amd_buffer_load_invalid_element_return_customized_value< + remove_cvref_t, + t_per_x, + Coherence, + oob_conditional_check>( + p_data_, i, is_valid_element, buffer_size_, invalid_element_value_); + } + } + else + { + if(is_valid_element) + { +#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS + X tmp; + + __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X)); + + return tmp; +#else + return *c_style_pointer_cast(&p_data_[i]); +#endif + } + else + { + if constexpr(InvalidElementUseNumericalZeroValue) + { + return X{0}; + } + else + { + return X{invalid_element_value_}; + } + } + } + } + + // i is offset of T, not X. i should be aligned to X + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + CK_TILE_DEVICE constexpr auto + get_raw(remove_cvref_t& dst, index_t i, bool is_valid_element) const + { + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_buffer_load_raw, t_per_x, Coherence, oob_conditional_check>( + dst, p_data_, i, buffer_size_, is_valid_element); + } + + // i is offset of T, not X. i should be aligned to X + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + CK_TILE_DEVICE constexpr auto + async_get(remove_cvref_t* smem, index_t i, bool /*is_valid_element*/) const + { + // X is vector of T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_async_buffer_load_with_oob, t_per_x, Coherence>( + smem, p_data_, i, buffer_size_); + } + + // i is offset of T, not X. i should be aligned to X + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x) + { + if constexpr(Op == InMemoryDataOperationEnum::set) + { + this->template set(i, is_valid_element, x); + } + else if constexpr(Op == InMemoryDataOperationEnum::atomic_add) + { + this->template atomic_add(i, is_valid_element, x); + } + else if constexpr(Op == InMemoryDataOperationEnum::atomic_max) + { + this->template atomic_max(i, is_valid_element, x); + } + // FIXME: remove InMemoryDataOperationEnum::Add + else if constexpr(Op == InMemoryDataOperationEnum::Add) + { + auto tmp = this->template get(i, is_valid_element); + this->template set(i, is_valid_element, x + tmp); + // tmp += x; + // this->template set(i, is_valid_element, tmp); + } + } + + // i is offset of T, not X. i should be aligned to X + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x) + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + +#if CK_USE_AMD_BUFFER_STORE + bool constexpr use_amd_buffer_addressing = true; +#else + bool constexpr use_amd_buffer_addressing = false; +#endif + + if constexpr(use_amd_buffer_addressing) + { + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_buffer_store, t_per_x, Coherence>( + x, p_data_, i, is_valid_element, buffer_size_); + } + else + { + if(is_valid_element) + { +#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS + X tmp = x; + + __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X)); +#else + *c_style_pointer_cast(&p_data_[i]) = x; +#endif + } + } + } + + // i is offset of T, not X. i should be aligned to X + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + CK_TILE_DEVICE void set_raw(index_t i, bool is_valid_element, const X& x) + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + amd_buffer_store_raw, t_per_x, Coherence, oob_conditional_check>( + x, p_data_, i, is_valid_element, buffer_size_); + } + + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + CK_TILE_DEVICE void atomic_add(index_t i, bool is_valid_element, const X& x) + { + using scalar_t = typename scalar_type>::type; + + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + static_assert(get_address_space() == address_space_enum::global, "only support global mem"); + +#if CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT + bool constexpr use_amd_buffer_addressing = + is_same_v, int32_t> || + is_same_v, float> || + (is_same_v, half_t> && scalar_per_x_vector % 2 == 0); +#elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT) + bool constexpr use_amd_buffer_addressing = is_same_v, int32_t>; +#elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT + bool constexpr use_amd_buffer_addressing = + is_same_v, float> || + (is_same_v, half_t> && scalar_per_x_vector % 2 == 0); +#else + bool constexpr use_amd_buffer_addressing = false; +#endif + + if constexpr(use_amd_buffer_addressing) + { + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_buffer_atomic_add, t_per_x>( + x, p_data_, i, is_valid_element, buffer_size_); + } + else + { + if(is_valid_element) + { + atomic_add(c_style_pointer_cast(&p_data_[i]), x); + } + } + } + + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + CK_TILE_DEVICE void atomic_max(index_t i, bool is_valid_element, const X& x) + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + static_assert(get_address_space() == address_space_enum::global, "only support global mem"); + +#if CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 + using scalar_t = typename scalar_type>::type; + bool constexpr use_amd_buffer_addressing = is_same_v, double>; +#else + bool constexpr use_amd_buffer_addressing = false; +#endif + + if constexpr(use_amd_buffer_addressing) + { + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_buffer_atomic_max, t_per_x>( + x, p_data_, i, is_valid_element, buffer_size_); + } + else if(is_valid_element) + { + atomic_max(c_style_pointer_cast(&p_data_[i]), x); + } + } + + // FIXME: remove + CK_TILE_DEVICE static constexpr bool is_static_buffer() { return false; } + + // FIXME: remove + CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; } + + CK_TILE_HOST_DEVICE void print() const + { + printf("buffer_view{"); + + // AddressSpace + printf("AddressSpace: Global, "); + + // p_data_ + printf("p_data_: %p, ", static_cast(const_cast*>(p_data_))); + + // buffer_size_ + printf("buffer_size_: "); + print(buffer_size_); + printf(", "); + + // invalid_element_value_ + printf("invalid_element_value_: "); + print(invalid_element_value_); + + printf("}"); + } +}; + +// Address Space: LDS +// T may be scalar or vector +// X may be scalar or vector +// T and X have same scalar type +// X contains multiple T +// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of +// transforms of tensor_view/Tensor +template +struct buffer_view +{ + using type = T; + + T* p_data_ = nullptr; + BufferSizeType buffer_size_; + remove_cvref_t invalid_element_value_ = T{0}; + + CK_TILE_HOST_DEVICE constexpr buffer_view() + : p_data_{}, buffer_size_{}, invalid_element_value_{} + { + } + + CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size) + : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0} + { + } + + CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, + BufferSizeType buffer_size, + T invalid_element_value) + : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value} + { + } + + CK_TILE_DEVICE static constexpr address_space_enum get_address_space() + { + return address_space_enum::lds; + } + + // i is offset of T + // FIXME: doesn't do is_valid check + CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; } + + // i is offset of T + // FIXME: doesn't do is_valid check + CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; } + + // i is offset of T, not X. i should be aligned to X + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + CK_TILE_DEVICE constexpr auto + get(index_t i, bool is_valid_element, bool_constant = {}) const + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + if(is_valid_element) + { +#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS + X tmp; + + __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X)); + + return tmp; +#else + return *c_style_pointer_cast(&p_data_[i]); +#endif + } + else + { + if constexpr(InvalidElementUseNumericalZeroValue) + { + return X{0}; + } + else + { + return X{invalid_element_value_}; + } + } + } + + // i is offset of T, not X. i should be aligned to X + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x) + { + if constexpr(Op == InMemoryDataOperationEnum::set) + { + this->template set(i, is_valid_element, x); + } + // FIXME: remove InMemoryDataOperationEnum::Add + else if constexpr(Op == InMemoryDataOperationEnum::Add) + { + auto tmp = this->template get(i, is_valid_element); + this->template set(i, is_valid_element, x + tmp); + } + } + + // i is offset of T, not X. i should be aligned to X + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x) + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + +#if CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE + bool constexpr workaround_int8_ds_write_issue = true; +#else + bool constexpr workaround_int8_ds_write_issue = false; +#endif + + if constexpr(is_same>::type, int8_t>::value && + workaround_int8_ds_write_issue) + { + if(is_valid_element) + { + // HACK: compiler would lower IR "store address_space(3)" into inefficient + // ISA, so I try to let compiler emit IR "store" which would be lower to + // ds_write_b128 + // TODO: remove this after compiler fix + static_assert((is_same, int8_t>::value && + is_same, int8_t>::value) || + (is_same, int8_t>::value && + is_same, int8x2_t>::value) || + (is_same, int8_t>::value && + is_same, int8x4_t>::value) || + (is_same, int8_t>::value && + is_same, int8x8_t>::value) || + (is_same, int8_t>::value && + is_same, int8x16_t>::value) || + (is_same, int8x4_t>::value && + is_same, int8x4_t>::value) || + (is_same, int8x8_t>::value && + is_same, int8x8_t>::value) || + (is_same, int8x16_t>::value && + is_same, int8x16_t>::value), + "wrong! not implemented for this combination, please add " + "implementation"); + + if constexpr(is_same, int8_t>::value && + is_same, int8_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8_t>::value && + is_same, int8x2_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8_t>::value && + is_same, int8x4_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8_t>::value && + is_same, int8x8_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8_t>::value && + is_same, int8x16_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8x4_t>::value && + is_same, int8x4_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8x8_t>::value && + is_same, int8x8_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8x16_t>::value && + is_same, int8x16_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + } + } + else + { + if(is_valid_element) + { +#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS + X tmp = x; + + __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X)); +#else + *c_style_pointer_cast(&p_data_[i]) = x; +#endif + } + } + } + + // FIXME: remove + CK_TILE_DEVICE static constexpr bool is_static_buffer() { return false; } + + // FIXME: remove + CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; } + + CK_TILE_HOST_DEVICE void print() const + { + printf("buffer_view{"); + + // AddressSpace + printf("AddressSpace: Lds, "); + + // p_data_ + printf("p_data_: %p, ", static_cast(const_cast*>(p_data_))); + + // buffer_size_ + printf("buffer_size_: "); + print(buffer_size_); + printf(", "); + + // invalid_element_value_ + printf("invalid_element_value_: "); + print(invalid_element_value_); + + printf("}"); + } +}; + +// Address Space: Vgpr +// T may be scalar or vector +// X may be scalar or vector +// T and X have same scalar type +// X contains multiple T +// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of +// transforms of tensor_view/Tensor +template +struct buffer_view +{ + using type = T; + + T* p_data_ = nullptr; + BufferSizeType buffer_size_; + remove_cvref_t invalid_element_value_ = T{0}; + + CK_TILE_HOST_DEVICE constexpr buffer_view() + : p_data_{}, buffer_size_{}, invalid_element_value_{} + { + } + + CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size) + : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0} + { + } + + CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, + BufferSizeType buffer_size, + T invalid_element_value) + : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value} + { + } + + CK_TILE_DEVICE static constexpr address_space_enum get_address_space() + { + return address_space_enum::vgpr; + } + + // i is offset of T + // FIXME: doesn't do is_valid check + CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; } + + // i is offset of T + // FIXME: doesn't do is_valid check + CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; } + + // i is offset of T, not X. i should be aligned to X + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + CK_TILE_DEVICE constexpr auto + get(index_t i, bool is_valid_element, bool_constant = {}) const + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + if(is_valid_element) + { +#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS + X tmp; + + __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X)); + + return tmp; +#else + return *c_style_pointer_cast(&p_data_[i]); +#endif + } + else + { + if constexpr(InvalidElementUseNumericalZeroValue) + { + return X{0}; + } + else + { + return X{invalid_element_value_}; + } + } + } + + // i is offset of T, not X. i should be aligned to X + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x) + { + if constexpr(Op == InMemoryDataOperationEnum::set) + { + this->template set(i, is_valid_element, x); + } + // FIXME: remove InMemoryDataOperationEnum::Add + else if constexpr(Op == InMemoryDataOperationEnum::Add) + { + auto tmp = this->template get(i, is_valid_element); + this->template set(i, is_valid_element, x + tmp); + } + } + + // i is offset of T, not X. i should be aligned to X + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x) + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + if(is_valid_element) + { +#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS + X tmp = x; + + __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X)); +#else + *c_style_pointer_cast(&p_data_[i]) = x; +#endif + } + } + + // FIXME: remove + CK_TILE_DEVICE static constexpr bool is_static_buffer() { return false; } + + // FIXME: remove + CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; } + + CK_TILE_HOST_DEVICE void print() const + { + printf("buffer_view{"); + + // AddressSpace + printf("AddressSpace: Vgpr, "); + + // p_data_ + printf("p_data_: %p, ", static_cast(const_cast*>(p_data_))); + + // buffer_size_ + printf("buffer_size_: "); + print(buffer_size_); + printf(", "); + + // invalid_element_value_ + printf("invalid_element_value_: "); + print(invalid_element_value_); + + printf("}"); + } +}; + +template +CK_TILE_HOST_DEVICE constexpr auto make_buffer_view(T* p, BufferSizeType buffer_size) +{ + return buffer_view{p, buffer_size}; +} + +template < + address_space_enum BufferAddressSpace, + amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default, + typename T, + typename BufferSizeType, + typename X, + typename enable_if, remove_cvref_t>::value, bool>::type = false> +CK_TILE_HOST_DEVICE constexpr auto +make_buffer_view(T* p, BufferSizeType buffer_size, X invalid_element_value) +{ + return buffer_view{ + p, buffer_size, invalid_element_value}; +} + +} // namespace ck_tile \ No newline at end of file diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp new file mode 100644 index 0000000000..1d9c7c9c79 --- /dev/null +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/tensor/tile_window.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template +CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution& tile_window, + bool_constant = {}) +{ + return tile_window.load(bool_constant{}); +} + +template +CK_TILE_DEVICE auto load_tile_raw(T& tile, + const tile_window_with_static_distribution& tile_window, + bool_constant = {}) +{ + tile_window.load_raw(tile, bool_constant{}); +} + +template +CK_TILE_DEVICE auto +async_load_tile_raw(LdsTileWindow_&& lds_tile, + const tile_window_with_static_distribution& tile_window) +{ + return tile_window.async_load(lds_tile); +} + +CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0) +{ + asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); +} + +template +CK_TILE_DEVICE auto load_tile(const NullTileWindow&) +{ + return NullTensor{}; +} + +template +CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const NullTileWindow&) +{ +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/null_tensor.hpp b/include/ck_tile/core/tensor/null_tensor.hpp new file mode 100644 index 0000000000..565ff87dff --- /dev/null +++ b/include/ck_tile/core/tensor/null_tensor.hpp @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck_tile { + +struct null_tensor +{ +}; + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/null_tile_window.hpp b/include/ck_tile/core/tensor/null_tile_window.hpp new file mode 100644 index 0000000000..ad7dd072dc --- /dev/null +++ b/include/ck_tile/core/tensor/null_tile_window.hpp @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/tensor/tile_window.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +// placeholder type if we want to opt-out a tile window parameter +template +struct null_tile_window +{ + using BottomTensorView = null_tensor_view; + using WindowLengths = remove_cvref_t; + + using BottomTensorIndex = array; + + CK_TILE_DEVICE constexpr null_tile_window() = default; + + CK_TILE_DEVICE constexpr null_tile_window(const WindowLengths& window_lengths) + : window_lengths_{window_lengths} + { + } + + CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; } + + CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return null_tensor_view{}; } + + CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; } + + WindowLengths window_lengths_; +}; + +// utility to check if this is a Null Tile Window +namespace impl { +template +struct is_null_tile_window : public std::false_type +{ +}; + +template +struct is_null_tile_window> : public std::true_type +{ +}; +} // namespace impl + +template +CK_TILE_DEVICE constexpr auto is_null_tile_window(const T&) +{ + return impl::is_null_tile_window>::value; +} + +template +CK_TILE_DEVICE constexpr auto make_null_tile_window(const WindowLengths& window_lengths) +{ + static_assert(ck_tile::is_known_at_compile_time::value, + "wrong! lengths should be static"); + + return null_tile_window>{window_lengths}; +} + +template +CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, + const WindowLengths& window_lengths, + const multi_index& /*origin*/, + Ts&&...) +{ + static_assert(ck_tile::is_known_at_compile_time::value, + "wrong! lengths should be static"); + + return null_tile_window>{window_lengths}; +} + +template +CK_TILE_DEVICE void +move_tile_window(null_tile_window&, + const typename null_tile_window::BottomTensorIndex&) +{ +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/shuffle_tile.hpp b/include/ck_tile/core/tensor/shuffle_tile.hpp new file mode 100644 index 0000000000..43a8a38e89 --- /dev/null +++ b/include/ck_tile/core/tensor/shuffle_tile.hpp @@ -0,0 +1,171 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/container/statically_indexed_array.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/core/tensor/tile_elementwise.hpp" + +namespace ck_tile { +namespace detail { + +template +CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InTensor& in_tensor) +{ + constexpr auto I0 = number<0>{}; + + using DataType = typename InTensor::DataType; + + constexpr auto y_in_desc = InTensor::get_tile_distribution().get_ys_to_d_descriptor(); + constexpr auto y_out_desc = OutTensor::get_tile_distribution().get_ys_to_d_descriptor(); + + // y_dim_out_to_in + constexpr auto get_rh_major_minor_to_y = [](auto dstr_tensor) { + using DstrEncode = typename decltype(dstr_tensor.get_tile_distribution())::DstrEncode; + + map, index_t> rh_major_minor_to_y_; + + static_for<0, DstrEncode::NDimY, 1>{}([&](auto i) { + constexpr index_t rh_major = DstrEncode::ys_to_rhs_major_[i]; + constexpr index_t rh_minor = DstrEncode::ys_to_rhs_minor_[i]; + + rh_major_minor_to_y_({rh_major, rh_minor}) = i; + }); + + return rh_major_minor_to_y_; + }; + + constexpr auto rh_major_minor_to_y_in = get_rh_major_minor_to_y(InTensor{}); + constexpr auto rh_major_minor_to_y_out = get_rh_major_minor_to_y(OutTensor{}); + + constexpr auto y_dim_out_to_in = [&] { + map y_dim_out_to_in_; + + for(const auto& [rh_major_minor, y_out] : rh_major_minor_to_y_out) + { + y_dim_out_to_in_(y_out) = rh_major_minor_to_y_in[rh_major_minor]; + } + + return y_dim_out_to_in_; + }(); + + // + constexpr index_t NDimY = InTensor::get_tile_distribution().GetNumOfDimensionY(); + + constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths()); + + // input and output vector dim in the order of input Y dims + constexpr index_t y_dim_vec_in = NDimY - 1; + constexpr index_t y_dim_vec_out = y_dim_out_to_in[NDimY - 1]; + + // vector lengths + constexpr index_t vec_length_in = y_lengths[y_dim_vec_in]; + constexpr index_t vec_length_out = y_lengths[y_dim_vec_out]; + + // # of vectors + constexpr index_t num_vec_in = vec_length_out; + constexpr index_t num_vec_out = vec_length_in; + + using InVec = vector_type; + using OutVec = vector_type; + + using InVecType = typename InVec::type; + using OutVecType = typename OutVec::type; + + // SFC + constexpr auto scalars_per_access_arr = generate_array( + [&](auto i) { return (i == y_dim_vec_in or i == y_dim_vec_out) ? y_lengths[i] : 1; }, + number{}); + + constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY); + + using SFC_Y = space_filling_curve::type, + decltype(scalars_per_access)>; + + constexpr index_t num_access = SFC_Y::get_num_of_access(); + + static_assert(num_access > 0, "wrong! num_access should be larger than 0"); + + // in/out vectors to be transposed + statically_indexed_array in_vectors; + statically_indexed_array out_vectors; + + // loop over SFC and do transpose + static_for<0, num_access, 1>{}([&](auto iAccess) { + // data index [y0, y1, ...] in the order of input tensor + constexpr auto idx_y_start = SFC_Y::get_index(iAccess); + + // get input vectors + static_for<0, num_vec_in, 1>{}([&](auto i) { + constexpr auto idx_y_in = generate_array( + [&](auto ii) { + return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii]; + }, + number{}); + + constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in); + + in_vectors(i).template AsType()(I0) = + in_tensor.get_thread_buffer().template get_as(number{}); + }); + + // transpose + transpose_vectors{}(in_vectors, out_vectors); + + // set output vectors + static_for<0, num_vec_out, 1>{}([&](auto i) { + constexpr auto idx_y_out_tmp = generate_array( + [&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; }, + number{}); + + constexpr auto idx_y_out = + container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in); + + constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out); + + out_tensor.get_thread_buffer().template set_as( + number{}, + out_vectors[i].template AsType()[I0]); + }); + }); +} + +} // namespace detail + +template +CK_TILE_DEVICE void shuffle_tile(OutTensor& out, const InTensor& in) +{ + using InDataType = typename InTensor::DataType; + using OutDataType = typename OutTensor::DataType; + + using InDstrEncode = typename InTensor::StaticTileDistribution::DstrEncode; + using OutDstrEncode = typename OutTensor::StaticTileDistribution::DstrEncode; + + // type convert + const auto in_tmp = tile_elementwise_in(type_convert, in); + + // shuffle + if constexpr(InDstrEncode::rs_lengths_ == OutDstrEncode::rs_lengths_ && + InDstrEncode::hs_lengthss_ == OutDstrEncode::hs_lengthss_ && + InDstrEncode::ps_to_rhss_major_ == OutDstrEncode::ps_to_rhss_major_ && + InDstrEncode::ps_to_rhss_minor_ == OutDstrEncode::ps_to_rhss_minor_ && + InDstrEncode::NDimY == OutDstrEncode::NDimY) + { + detail::shuffle_tile_impl_in_thread(out, in_tmp); + } + else + { + // NOT implemented + } +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/slice_tile.hpp b/include/ck_tile/core/tensor/slice_tile.hpp new file mode 100644 index 0000000000..54d937a8d0 --- /dev/null +++ b/include/ck_tile/core/tensor/slice_tile.hpp @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/tensor/tile_window.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { +namespace tile_program { + +template +CK_TILE_DEVICE constexpr auto +get_slice_tile(const tile_window_with_static_lengths& tile, + sequence slice_begins, + sequence slice_ends) +{ + using TileWindow = tile_window_with_static_lengths; + // NOTE: This API will override the origin of the tile window! + static_assert(sizeof...(SliceBegins) == sizeof...(SliceEnds)); + static_assert(sizeof...(SliceBegins) == TileWindow::get_num_of_dimension()); + + constexpr auto slice_lengths = slice_ends - slice_begins; + + return make_tile_window(tile.GetBottomTensorView(), + sequence_to_tuple_of_number(slice_lengths), + to_multi_index(slice_begins)); +} + +template +CK_TILE_DEVICE constexpr auto +get_slice_tile(const static_distributed_tensor& tile, + sequence slice_begins, + sequence slice_ends) +{ + using DataType = remove_cvref_t; + using Distribution = remove_cvref_t; + + constexpr auto sliced_dstr_yidx_ylen = + detail::slice_distribution_from_x(Distribution{}, slice_begins, slice_ends); + + constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template at<0>(); + constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>(); + constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>(); + + auto sliced_tensor = make_static_distributed_tensor(sliced_dstr); + + sliced_tensor.get_thread_buffer() = + tile.get_y_sliced_thread_data(sliced_y_origins, sliced_y_lengths); + + return sliced_tensor; +} + +template +CK_TILE_DEVICE constexpr auto +set_slice_tile(static_distributed_tensor& dst_tile, + const static_distributed_tensor& src_tile, + sequence slice_begins, + sequence slice_ends) +{ + using DstDistribution = remove_cvref_t; + + constexpr auto sliced_dstr_yidx_ylen = + detail::slice_distribution_from_x(DstDistribution{}, slice_begins, slice_ends); + + constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template at<0>(); + constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>(); + constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>(); + + static_assert(is_same_v, "wrong!"); + + dst_tile.SetSlicedThreadData(sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer()); +} + +} // namespace tile_program +} // namespace ck_tile \ No newline at end of file diff --git a/include/ck_tile/core/tensor/static_distributed_tensor.hpp b/include/ck_tile/core/tensor/static_distributed_tensor.hpp new file mode 100644 index 0000000000..1f90ff2b95 --- /dev/null +++ b/include/ck_tile/core/tensor/static_distributed_tensor.hpp @@ -0,0 +1,180 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template +struct static_distributed_tensor +{ + using DataType = remove_cvref_t; + using StaticTileDistribution = remove_cvref_t; + + static_assert(StaticTileDistribution::is_static(), + "wrong! StaticTileDistribution should be known at compile tile"); + + using ThreadTensorDesc = + remove_cvref_t; + + static constexpr index_t kThreadElementSpaceSize = ThreadTensorDesc{}.get_element_space_size(); + + CK_TILE_HOST_DEVICE static constexpr auto get_num_of_dimension() + { + return StaticTileDistribution::get_num_of_dimension_x(); + } + + CK_TILE_HOST_DEVICE static constexpr auto get_lengths() + { + return StaticTileDistribution::get_lengths(); + } + + CK_TILE_HOST_DEVICE static constexpr auto get_tile_distribution() + { + return StaticTileDistribution{}; + } + + CK_TILE_HOST_DEVICE static constexpr auto get_distributed_spans() + { + return StaticTileDistribution::get_distributed_spans(); + } + + CK_TILE_HOST_DEVICE void initialize(const DataType& x) { thread_buf_.initialize(x); } + + CK_TILE_HOST_DEVICE constexpr const auto& get_thread_buffer() const { return thread_buf_; } + + CK_TILE_HOST_DEVICE constexpr auto& get_thread_buffer() { return thread_buf_; } + + CK_TILE_HOST_DEVICE static constexpr index_t get_thread_buffer_size() + { + return kThreadElementSpaceSize; + } + + template + CK_TILE_HOST_DEVICE auto get_y_sliced_thread_data(sequence, + sequence) const + { + static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY && + sizeof...(YSliceLengths) == StaticTileDistribution::NDimY, + "wrong!"); + + constexpr auto sliced_thread_tensor_desc = + make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...)); + + array sliced_thread_data; + + static_ford>{}([&](auto idx) { + constexpr auto idx_ys = idx + sequence{}; + + sliced_thread_data(number{}) = + thread_buf_[number{}]; + }); + + return sliced_thread_data; + } + + template + CK_TILE_HOST_DEVICE void + set_y_sliced_thread_data(sequence, + sequence, + const array& sliced_thread_data) + { + static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY && + sizeof...(YSliceLengths) == StaticTileDistribution::NDimY, + "wrong!"); + + constexpr auto sliced_thread_tensor_desc = + make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...)); + + static_ford>{}([&](auto idx) { + constexpr auto idx_ys = idx + sequence{}; + + thread_buf_(number{}) = + sliced_thread_data[number{}]; + }); + } + + template + CK_TILE_HOST_DEVICE constexpr const DataType& operator[](TileDistributedIndices) const + { + static_assert(is_static_v, + "wrong! Tile Distributed Indices should be static"); + + constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices( + TileDistributedIndices{}); + + return thread_buf_[number{}]; + } + + template + CK_TILE_HOST_DEVICE constexpr DataType& operator()(TileDistributedIndices) + { + static_assert(is_static_v, + "wrong! Tile Distributed Indices should be static"); + + constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices( + TileDistributedIndices{}); + + return thread_buf_(number{}); + } + + // + array thread_buf_; +}; + +template +CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution&) +{ + return static_distributed_tensor, + remove_cvref_t>{}; +} + +// get X indices from tuple of tile_distributed_index<> +template +CK_TILE_HOST_DEVICE constexpr auto +get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, + DistributedIndices distributed_indices) +{ + const auto partition_index = detail::get_partition_index(tile_distribution); + constexpr auto y_indices = + tile_distribution.get_y_indices_from_distributed_indices(distributed_indices); + + const auto x_coord = make_tensor_adaptor_coordinate( + tile_distribution.get_ps_ys_to_xs_adaptor(), + container_concat(partition_index, to_array(y_indices))); + + return x_coord.get_bottom_index(); +} + +template +CK_TILE_HOST_DEVICE void +set_tile_if(static_distributed_tensor& out_tensor, + DataType value, + XIndicesPredicate predicate) +{ + constexpr auto out_spans = + static_distributed_tensor::get_distributed_spans(); + sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) { + constexpr auto distributed_indices = make_tuple(idx0, idx1); + const auto x_indices = get_x_indices_from_distributed_indices(StaticTileDistribution{}, + distributed_indices); + + if(predicate(x_indices)) + { + out_tensor(distributed_indices) = value; + } + }); + }); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/store_tile.hpp b/include/ck_tile/core/tensor/store_tile.hpp new file mode 100644 index 0000000000..6563f75a06 --- /dev/null +++ b/include/ck_tile/core/tensor/store_tile.hpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/tensor/tile_window.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template +CK_TILE_DEVICE void +store_tile(tile_window_with_static_lengths& tile_window_tmp, + const static_distributed_tensor& dstr_tensor) +{ + using DataType = remove_cvref_t; + using TileDstr = remove_cvref_t; + + static_assert(is_same_v, DataType>, "wrong!"); + + constexpr auto tile_dstr = TileDstr{}; + + auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(), + tile_window_tmp.get_window_lengths(), + tile_window_tmp.get_window_origin(), + tile_dstr); + + tile_window.store(dstr_tensor); +} + +template +CK_TILE_DEVICE void +store_tile_raw(tile_window_with_static_lengths& tile_window_tmp, + const static_distributed_tensor& dstr_tensor) +{ + using DataType = remove_cvref_t; + using TileDstr = remove_cvref_t; + + static_assert(is_same_v, DataType>, "wrong!"); + + constexpr auto tile_dstr = TileDstr{}; + + auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(), + tile_window_tmp.get_window_lengths(), + tile_window_tmp.get_window_origin(), + tile_dstr); + + tile_window.store_raw(dstr_tensor); +} + +template +CK_TILE_DEVICE void +store_tile(tile_window_with_static_distribution& tile_window, + const static_distributed_tensor& dstr_tensor) +{ + tile_window.store(dstr_tensor); +} + +template +CK_TILE_DEVICE void +store_tile_raw(tile_window_with_static_distribution& tile_window, + const static_distributed_tensor& dstr_tensor) +{ + tile_window.store_raw(dstr_tensor); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/sweep_tile.hpp b/include/ck_tile/core/tensor/sweep_tile.hpp new file mode 100644 index 0000000000..f1511f11d2 --- /dev/null +++ b/include/ck_tile/core/tensor/sweep_tile.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +// sweep over a span of a distribted tile and apply lambda function F +template + typename F // signature: F(tile_distributed_index<...>) + > +CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F& f) +{ + using DstrSpan = remove_cvref_t; + + static_ford{}([&](auto dstr_idx_impl) { + constexpr auto dstr_idx = detail::make_tile_distributed_index(dstr_idx_impl); + + f(dstr_idx); + }); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tensor_adaptor.hpp b/include/ck_tile/core/tensor/tensor_adaptor.hpp new file mode 100644 index 0000000000..872bf6531e --- /dev/null +++ b/include/ck_tile/core/tensor/tensor_adaptor.hpp @@ -0,0 +1,942 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +// Transforms: Tuple +// LowerDimensionHiddenIdss : Tuple, ...> +// UpperDimensionHiddenIdss : Tuple, ...> +// BottomDimensionHiddenIds : Sequence<...> +// TopDimensionHiddenIds : Sequence<...> +template +struct tensor_adaptor +{ + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_transform() + { + return Transforms::size(); + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_transforms() const { return transforms_; } + + CK_TILE_HOST_DEVICE static constexpr auto get_lower_dimension_hidden_idss() + { + return LowerDimensionHiddenIdss{}; + } + + CK_TILE_HOST_DEVICE static constexpr auto get_upper_dimension_hidden_idss() + { + return UpperDimensionHiddenIdss{}; + } + + CK_TILE_HOST_DEVICE static constexpr auto get_bottom_dimension_hidden_ids() + { + return BottomDimensionHiddenIds{}; + } + + CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_hidden_ids() + { + return TopDimensionHiddenIds{}; + } + + CK_TILE_HOST_DEVICE static constexpr auto initialize_element_size(const Transforms& transforms) + { + const auto lengths = generate_tuple( + [&](auto idim_top) { + constexpr index_t idim_hidden = TopDimensionHiddenIds::at(idim_top); + + constexpr auto tmp = get_transform_and_its_upper_dimension(number{}); + + constexpr index_t itran = tmp[number<0>{}]; + constexpr index_t idim_up = tmp[number<1>{}]; + constexpr bool found = tmp[number<2>{}]; + + static_assert(found == true, + "wrong! not found matching transformation and upper-dimension"); + + const auto length = + transforms[number{}].get_upper_lengths()[number{}]; + + return length; + }, + number{}); + + // TODO: make container_reduce support tuple of number and index_t + return container_reduce(lengths, multiplies{}, number<1>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto + get_transform_and_its_upper_dimension(number) + { + // FIXME: length of bottom dimension is not known, since info about lower dim length are not + // saved in transformation + static_assert(IDimHidden >= ndim_bottom_, "wrong! not implemented"); + + index_t itran_found = 0; + index_t idim_up_found = 0; + bool found = false; + + static_for<0, ntransform_, 1>{}([&](auto itran) { + constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran]; + + static_for<0, up_dim_ids.size(), 1>{}([&](auto idim_up) { + if constexpr(up_dim_ids[idim_up] == IDimHidden) + { + itran_found = itran; + idim_up_found = idim_up; + found = true; + } + }); + }); + + return make_tuple(itran_found, idim_up_found, found); + } + + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_bottom_dimension() + { + return BottomDimensionHiddenIds::size(); + } + + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_top_dimension() + { + return TopDimensionHiddenIds::size(); + } + + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_hidden_dimension() + { + constexpr auto all_low_dim_ids = unpack( + [](auto&&... xs) constexpr { return merge_sequences(xs...); }, + LowerDimensionHiddenIdss{}); + + constexpr auto all_up_dim_ids = unpack( + [](auto&&... xs) constexpr { return merge_sequences(xs...); }, + UpperDimensionHiddenIdss{}); + + constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids); + + using unique_sort_all_dim_ids = typename sequence_unique_sort, + equal>::type; + + return unique_sort_all_dim_ids::size(); + } + + constexpr static index_t ntransform_ = get_num_of_transform(); + constexpr static index_t ndim_hidden_ = get_num_of_hidden_dimension(); + constexpr static index_t ndim_bottom_ = get_num_of_bottom_dimension(); + constexpr static index_t ndim_top_ = get_num_of_top_dimension(); + + using HiddenIndex = multi_index; + using BottomIndex = multi_index; + using TopIndex = multi_index; + + // may be index_t or number<> + using ElementSize = remove_cv_t; + + public: + CK_TILE_HOST_DEVICE constexpr tensor_adaptor() = default; + + CK_TILE_HOST_DEVICE constexpr tensor_adaptor(const Transforms& transforms) + : transforms_{transforms}, element_size_{initialize_element_size(transforms)} + { + static_assert(Transforms::size() == ntransform_ && + LowerDimensionHiddenIdss::size() == ntransform_ && + UpperDimensionHiddenIdss::size() == ntransform_, + "wrong! inconsistent # of transformations"); + + // TODO check dependency of dimensions is valid + } + + CK_TILE_HOST_DEVICE constexpr auto get_element_size() const { return element_size_; } + + // FIXME: this logic is wrong when getting bottome dimension lengths + template + CK_TILE_HOST_DEVICE constexpr auto get_hidden_dimension_length(number) const + { + static_assert(IDimHidden >= 0 && IDimHidden < ndim_hidden_, "wrong! out of range"); + + constexpr auto tmp = get_transform_and_its_upper_dimension(number{}); + + constexpr index_t itran = tmp[number<0>{}]; + constexpr index_t idim_up = tmp[number<1>{}]; + constexpr bool found = tmp[number<2>{}]; + + static_assert(found == true, + "wrong! not found matching transformation and upper-dimension"); + + return transforms_[number{}].get_upper_lengths()[number{}]; + } + + template + CK_TILE_HOST_DEVICE constexpr auto get_top_dimension_length(number idim_top) const + { + return get_hidden_dimension_length(TopDimensionHiddenIds::at(idim_top)); + } + +#if 0 + // FIXME: get_hidden_dimension_length is wrong when getting bottome dimension lengths + template + CK_TILE_HOST_DEVICE constexpr index_t + get_bottom_dimension_length(number idim_bottom) const + { + return get_hidden_dimension_length(TopDimensionHiddenIds::at(idim_bottom)); + } +#endif + + CK_TILE_HOST_DEVICE constexpr auto get_top_dimension_lengths() const + { + return generate_tuple([&](auto i) { return get_top_dimension_length(i); }, + number{}); + } + +#if 0 + // FIXME: get_hidden_dimension_length is wrong when getting bottome dimension lengths + CK_TILE_HOST_DEVICE constexpr auto GetBottomDimensionLengths() const + { + return generate_tuple([&](auto i) { return get_bottom_dimension_length(i); }, + number{}); + } +#endif + + template + CK_TILE_HOST_DEVICE constexpr auto calculate_bottom_index(const TopIdx& idx_top) const + { + static_assert(TopIdx::size() == TopDimensionHiddenIds::size(), + "wrong! # of dimension inconsistent"); + + constexpr index_t ntransform = get_num_of_transform(); + constexpr index_t ndim_hidden = get_num_of_hidden_dimension(); + + multi_index idx_hidden; + + // initialize uppest index + set_container_subset(idx_hidden, get_top_dimension_hidden_ids(), idx_top); + + // calculate hidden index + static_for{}([&](auto itran_p1) { + auto itran = itran_p1 - number<1>{}; + const auto& tran = get_transforms().at(itran); + constexpr auto dims_low = get_lower_dimension_hidden_idss().at(itran); + constexpr auto dims_up = get_upper_dimension_hidden_idss().at(itran); + + const auto idx_up = get_container_subset(idx_hidden, dims_up); + + multi_index idx_low; + + tran.calculate_lower_index(idx_low, idx_up); + + set_container_subset(idx_hidden, dims_low, idx_low); + }); + + return get_container_subset(idx_hidden, BottomDimensionHiddenIds{}); + } + + CK_TILE_HOST_DEVICE static constexpr bool is_static() + { + bool is_known = true; + + static_for<0, Transforms::size(), 1>{}([&](auto i) { + is_known &= remove_cvref_t::is_known_at_compile_time(); + }); + + return is_known && ck_tile::is_known_at_compile_time::value; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { return is_static(); } + + CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides( + const array& guaranteed_vector_lengths, + const array& guaranteed_vector_strides) + { + auto vector_lengths = guaranteed_vector_lengths; + auto vector_strides = guaranteed_vector_strides; + + static_for<0, get_num_of_transform(), 1>{}([&](auto itran) { + constexpr auto low_dims = get_lower_dimension_hidden_idss().at(itran); + constexpr auto up_dims = get_upper_dimension_hidden_idss().at(itran); + + const auto up_guaranteed_vector_lengths = + get_container_subset(guaranteed_vector_lengths, up_dims); + const auto up_guaranteed_vector_strides = + get_container_subset(guaranteed_vector_strides, up_dims); + + // only need type of transform + auto [up_vector_lengths, up_vector_strides] = + Transforms{}.at(itran).calculate_upper_dimension_safe_vector_length_strides( + get_container_subset(vector_lengths, low_dims), + get_container_subset(vector_strides, low_dims)); + + if constexpr(up_dims.size() > 0) + { + for(index_t i = 0; i < up_dims.size(); ++i) + { + up_vector_lengths(i) = (up_guaranteed_vector_lengths[i] != -1) + ? up_guaranteed_vector_lengths[i] + : up_vector_lengths[i]; + + up_vector_strides(i) = (up_guaranteed_vector_strides[i] != -1) + ? up_guaranteed_vector_strides[i] + : up_vector_strides[i]; + } + } + + set_container_subset(vector_lengths, up_dims, up_vector_lengths); + set_container_subset(vector_strides, up_dims, up_vector_strides); + }); + + constexpr auto top_dims = TopDimensionHiddenIds{}; + + return make_tuple(get_container_subset(vector_lengths, top_dims), + get_container_subset(vector_strides, top_dims)); + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("tensor_adaptor{"); + + // + printf("transforms: "); + print(transforms_); + printf(", "); + + // + printf("LowerDimensionHiddenIds: "); + print(LowerDimensionHiddenIdss{}); + printf(", "); + + // + printf("UpperDimensionHiddenIds: "); + print(UpperDimensionHiddenIdss{}); + printf(", "); + + // + printf("BottomDimensionHiddenIds: "); + print(BottomDimensionHiddenIds{}); + printf(", "); + + // + printf("TopDimensionHiddenIds: "); + print(TopDimensionHiddenIds{}); + + printf("}"); + } + + private: + Transforms transforms_; + ElementSize element_size_; +}; + +// Transforms: Tuple +// LowerDimensionOldTopIdss: Tuple, ...> +// UpperDimensionNewTopIdss: Tuple, ...> +template +CK_TILE_HOST_DEVICE constexpr auto make_single_stage_tensor_adaptor(const Transforms& transforms, + LowerDimensionOldTopIdss, + UpperDimensionNewTopIdss) +{ + constexpr index_t ntransform = Transforms::size(); + + static_assert(LowerDimensionOldTopIdss::size() == ntransform && + UpperDimensionNewTopIdss::size() == ntransform, + "wrong!"); + + // sanity check on LowerDimensionOldTopIdss and UpperDimensionNewTopIdss + constexpr auto all_low_dim_old_top_ids = unpack( + [](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionOldTopIdss{}); + + constexpr auto all_up_dim_new_top_ids = unpack( + [](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionNewTopIdss{}); + + static_assert(is_valid_sequence_map::value && + is_valid_sequence_map::value, + "wrong!"); + + constexpr index_t ndim_old_top = all_low_dim_old_top_ids.size(); + constexpr index_t ndim_new_top = all_up_dim_new_top_ids.size(); + + // low_dim_hidden_idss + constexpr auto low_dim_hidden_idss = LowerDimensionOldTopIdss{}; + + // up_dim_hidden_idss: shift UpperDimensionNewTopIdss by ndim_bottom + constexpr auto up_dim_hidden_idss = generate_tuple( + [](auto itran) { return UpperDimensionNewTopIdss{}[itran] + number{}; }, + number{}); + + // bottom_dim_hidden_ids + constexpr auto bottom_dim_hidden_ids = + typename arithmetic_sequence_gen<0, ndim_old_top, 1>::type{}; + + // top_dim_hidden_ids + constexpr auto top_dim_hidden_ids = + typename arithmetic_sequence_gen<0, ndim_new_top, 1>::type{} + number{}; + + return tensor_adaptor, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t>{transforms}; +} + +// TODO: How to fix this? It uses an struct instead of lambda because lambda +// doesn't have constructor, and to put it outside the scope where it is used +// (transform_tensor_adaptor) because template cannot be defined inside a function +// template +template +struct lambda_get_up_dim_num +{ + template + CK_TILE_HOST_DEVICE constexpr auto operator()(I) const + { + using Tran = remove_reference_t; + return number{}; + } +}; + +template +CK_TILE_HOST_DEVICE constexpr auto +transform_tensor_adaptor(const OldTensorAdaptor& old_tensor_adaptor, + const NewTransforms& new_transforms, + NewLowerDimensionOldTopIdss, + NewUpperDimensionNewTopIdss) +{ + // sanity check + { + static_assert(NewTransforms::size() == NewLowerDimensionOldTopIdss::size() && + NewTransforms::size() == NewUpperDimensionNewTopIdss::size(), + "wrong! inconsitent number of transform"); + + constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); }, + NewLowerDimensionOldTopIdss{}); + + constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); }, + NewUpperDimensionNewTopIdss{}); + + static_assert(is_valid_sequence_map::value && + is_valid_sequence_map::value, + "wrong!"); + } + + // lower dimension's hidden idss + // convert lower dimension top idss (tuple of sequences) to hidden idss (tuple of + // sequences) + constexpr auto low_dim_hidden_idss = transform_tuples( + // convert lower dimension top ids (a sequence) to hidden ids (a sequence) + [](auto low_dim_top_ids) constexpr { + return transform_sequences( + // convert lower dimension top id to hidden id + [](auto low_dim_top_id) constexpr { + return OldTensorAdaptor::get_top_dimension_hidden_ids()[low_dim_top_id]; + }, + low_dim_top_ids); + }, + NewLowerDimensionOldTopIdss{}); + + constexpr index_t num_new_transform = NewTransforms::size(); + + // upper dimension's hidden idss + constexpr index_t old_hidden_dim_number = OldTensorAdaptor::get_num_of_hidden_dimension(); + + constexpr auto up_dim_numbers = + generate_sequence(lambda_get_up_dim_num{}, number{}); + + constexpr auto up_dim_numbers_scan = merge_sequences( + Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, plus{}, number<0>{})); + + constexpr auto up_dim_hidden_idss = generate_tuple( + [ old_hidden_dim_number, up_dim_numbers_scan ](auto i) constexpr { + return + typename arithmetic_sequence_gen::type{}; + }, + number{}); + + // new top dimension's hidden ids + constexpr auto unordered_new_top_dim_hidden_ids = unpack( + [](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss); + + constexpr auto new_top_dim_unordered2ordered = unpack( + [](auto... xs) constexpr { return merge_sequences(xs...); }, NewUpperDimensionNewTopIdss{}); + + constexpr auto new_top_dim_hidden_ids = + unordered_new_top_dim_hidden_ids.reorder_old_to_new(new_top_dim_unordered2ordered); + + // put everything together + const auto all_transforms = + container_concat(old_tensor_adaptor.get_transforms(), new_transforms); + + constexpr auto all_low_dim_hidden_idss = + container_concat(OldTensorAdaptor::get_lower_dimension_hidden_idss(), low_dim_hidden_idss); + + constexpr auto all_up_dim_hidden_idss = + container_concat(OldTensorAdaptor::get_upper_dimension_hidden_idss(), up_dim_hidden_idss); + + return tensor_adaptor< + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t>{all_transforms}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& adaptor0, + const TensorAdaptor1& adaptor1) +{ + static_assert(TensorAdaptor0::get_num_of_top_dimension() == + TensorAdaptor1::get_num_of_bottom_dimension(), + "wrong!"); + + // all_transforms = transform0 + transform1 + const auto all_transforms = + container_concat(adaptor0.get_transforms(), adaptor1.get_transforms()); + + // shift + constexpr index_t adaptor0_max_hidden_id = [&]() { + index_t adaptor0_max_hidden_id_ = NumericLimits::Min(); + + static_for<0, TensorAdaptor0::get_num_of_transform(), 1>{}([&](auto itran) { + constexpr index_t ndim_low = + TensorAdaptor0{}.get_transforms()[itran].get_num_of_lower_dimension(); + + static_for<0, ndim_low, 1>{}([&](auto idim_low) { + adaptor0_max_hidden_id_ = + max(adaptor0_max_hidden_id_, + TensorAdaptor0::get_lower_dimension_hidden_idss()[itran][idim_low].value); + }); + + constexpr index_t ndim_up = + TensorAdaptor0{}.get_transforms()[itran].get_num_of_upper_dimension(); + + static_for<0, ndim_up, 1>{}([&](auto idim_up) { + adaptor0_max_hidden_id_ = + max(adaptor0_max_hidden_id_, + TensorAdaptor0::get_upper_dimension_hidden_idss()[itran][idim_up].value); + }); + }); + + return adaptor0_max_hidden_id_; + }(); + + constexpr index_t adaptor1_min_hidden_id = [&]() { + index_t adaptor1_min_hidden_id_ = NumericLimits::Max(); + + static_for<0, TensorAdaptor1::get_num_of_transform(), 1>{}([&](auto itran) { + constexpr index_t ndim_low = + TensorAdaptor1{}.get_transforms()[itran].get_num_of_lower_dimension(); + + // get the min of all lower dimenions, but not bottom dimension (because their id will + // be matched with top id from adaptor0) + static_for<0, ndim_low, 1>{}([&](auto idim_low) { + constexpr index_t low_dim_hidden_id = + TensorAdaptor1::get_lower_dimension_hidden_idss()[itran][idim_low].value; + + bool is_bottom_dim = false; + static_for<0, TensorAdaptor1::get_num_of_bottom_dimension(), 1>{}([&](auto i) { + if constexpr(low_dim_hidden_id == + TensorAdaptor1::get_bottom_dimension_hidden_ids()[i]) + { + is_bottom_dim = true; + } + }); + + if(!is_bottom_dim) + { + adaptor1_min_hidden_id_ = min(adaptor1_min_hidden_id_, low_dim_hidden_id); + } + }); + + constexpr index_t ndim_up = + TensorAdaptor1{}.get_transforms()[itran].get_num_of_upper_dimension(); + + // get the min of all upper dimensions + static_for<0, ndim_up, 1>{}([&](auto idim_up) { + adaptor1_min_hidden_id_ = + min(adaptor1_min_hidden_id_, + TensorAdaptor1::get_upper_dimension_hidden_idss()[itran][idim_up].value); + }); + }); + + return adaptor1_min_hidden_id_; + }(); + + constexpr index_t adaptor1_hidden_id_shift = + adaptor0_max_hidden_id + 1 - adaptor1_min_hidden_id; + + constexpr index_t ndim_bottom_1 = TensorAdaptor1::get_num_of_bottom_dimension(); + + // all_low_dim_hidden_idss = + // low_dim_hidden_idss_0 + match_hidden_id_for_1(shift_hidden_id_for_1(low_dim_hiden_idss_1)) + constexpr auto low_dim_hidden_idss_1 = generate_tuple( + // generate sequence of ids for a transform + [&](auto itran) { + constexpr auto ndim_low_1 = + TensorAdaptor1::get_lower_dimension_hidden_idss()[itran].size(); + + constexpr auto low_dim_hidden_ids_1 = + TensorAdaptor1::get_lower_dimension_hidden_idss()[itran]; + + // sequence in, sequence out + constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr + { + auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1); + + // shift hidden id so every dim id is unique + static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) { + low_dim_hidden_ids_1_mod_(idim_low_1) += adaptor1_hidden_id_shift; + }); + + // match hidden id + static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) { + static_for<0, ndim_bottom_1, 1>{}([&](auto idim_bottom_1) { + // if this low dim is bottom dim, then do id matching + if constexpr(low_dim_hidden_ids_1[idim_low_1] == + TensorAdaptor1::get_bottom_dimension_hidden_ids() + [idim_bottom_1]) + { + low_dim_hidden_ids_1_mod_(idim_low_1) = + TensorAdaptor0::get_top_dimension_hidden_ids()[idim_bottom_1]; + } + }); + }); + + return low_dim_hidden_ids_1_mod_; + } + (); + + return generate_sequence_v2( + [&](auto i) constexpr { return number{}; }, + number{}); + }, + number{}); + + constexpr auto all_low_dim_hidden_idss = + container_concat(TensorAdaptor0::get_lower_dimension_hidden_idss(), low_dim_hidden_idss_1); + + // all_up_dim_hidden_idss = + // up_dim_hidden_idss_0 + shift_hidden_id_for_1(up_dim_hiden_idss_1) + constexpr auto up_dim_hidden_idss_1 = generate_tuple( + // generate sequence of ids for a transform + [&](auto itran) { + constexpr auto ndim_up_1 = + TensorAdaptor1::get_upper_dimension_hidden_idss()[itran].size(); + + constexpr auto up_dim_hidden_ids_1 = + TensorAdaptor1::get_upper_dimension_hidden_idss()[itran]; + + // sequence in, constexpr tuple out + constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr + { + auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1); + + // shift hidden id + static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) { + up_dim_hidden_ids_1_mod_(idim_up_1) += adaptor1_hidden_id_shift; + }); + + return up_dim_hidden_ids_1_mod_; + } + (); + + // constexpr tuple to sequence + return generate_sequence_v2( + [&](auto i) constexpr { return number{}; }, + number{}); + }, + number{}); + + constexpr auto all_up_dim_hidden_idss = + container_concat(TensorAdaptor0::get_upper_dimension_hidden_idss(), up_dim_hidden_idss_1); + + // bottom_dim_hidden_ids = bottom_dim_hidden_ids_0 + constexpr auto bottom_dim_hidden_ids = TensorAdaptor0::get_bottom_dimension_hidden_ids(); + + // top_dim_hidden_ids = shift_hidden_id(top_dim_hidden_ids_1) + constexpr auto top_dim_hidden_ids = + TensorAdaptor1::get_top_dimension_hidden_ids() + number{}; + + // put everything together + return tensor_adaptor, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t>{all_transforms}; +} + +template = 2, bool>::type = false> +CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs) +{ + return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...)); +} + +} // namespace ck_tile + +// Macro function +// construct constexpr tensor_adaptor from constexpr encoding +// encoded_tensor_adaptor are Tuple of following objects: +// 1. encoded transforms (array of fixed size). Each encoded transform is a Tuple of following: +// 1.1 name (cood_transform_enum) +// 1.2 meta data for constructor of the transform +// 1.3 num of lower dimension (index_t) +// 1.4 lower dimension Ids (array of fixed size) +// 1.5 num of up dimension (index_t) +// 1.6 upper dimension Ids (array of fixed size) +// 2. num of transforms (index_t) +// 3. encoded bottom dimension Ids (array of fixed size) +// 4. num of bottom dimension (index_t) +// 5. encoded top dimension Ids (array of fixed size) +// 6. num of top dimension (index_t) +#define CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor) \ + [encoded_tensor_adaptor]() { \ + using namespace ck_tile; \ + \ + constexpr auto encoded_transforms = encoded_tensor_adaptor.template at<0>(); \ + constexpr index_t num_transform = encoded_tensor_adaptor.template at<1>(); \ + constexpr auto encoded_bottom_dims = encoded_tensor_adaptor.template at<2>(); \ + constexpr index_t num_bottom_dim = encoded_tensor_adaptor.template at<3>(); \ + constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \ + constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \ + \ + constexpr auto trans = [&encoded_transforms, &num_transform]() { \ + return generate_tuple( \ + [&encoded_transforms](auto i) constexpr { \ + constexpr auto name = encoded_transforms[i].template at<0>(); \ + constexpr auto meta_data = encoded_transforms[i].template at<1>(); \ + constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \ + constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \ + \ + STATIC_ASSERT(name == cood_transform_enum::PassThrough || \ + name == cood_transform_enum::pad || \ + name == cood_transform_enum::embed || \ + name == cood_transform_enum::merge || \ + name == cood_transform_enum::unmerge || \ + name == cood_transform_enum::replicate, \ + ""); \ + \ + if constexpr(name == cood_transform_enum::PassThrough) \ + { \ + index_t pos = 0; \ + auto low_len = meta_data.template pop(pos); \ + \ + return make_pass_through_transform(low_len); \ + } \ + else if constexpr(name == cood_transform_enum::pad) \ + { \ + index_t pos = 0; \ + auto low_len = meta_data.template pop(pos); \ + auto left_pad = meta_data.template pop(pos); \ + auto right_pad = meta_data.template pop(pos); \ + \ + return make_pad_transform(low_len, left_pad, right_pad); \ + } \ + else if constexpr(name == cood_transform_enum::embed) \ + { \ + index_t pos = 0; \ + auto up_lens = meta_data.template pop>(pos); \ + auto coefficients = \ + meta_data.template pop>(pos); \ + \ + return make_embed_transform(up_lens, coefficients); \ + } \ + else if constexpr(name == cood_transform_enum::merge) \ + { \ + index_t pos = 0; \ + auto low_lens = meta_data.template pop>(pos); \ + \ + return make_merge_transform(low_lens); \ + } \ + else if constexpr(name == cood_transform_enum::unmerge) \ + { \ + index_t pos = 0; \ + auto up_lens = meta_data.template pop>(pos); \ + \ + return make_unmerge_transform(up_lens); \ + } \ + else if constexpr(name == cood_transform_enum::replicate) \ + { \ + index_t pos = 0; \ + auto up_lens = meta_data.template pop>(pos); \ + \ + return make_replicate_transform(up_lens); \ + } \ + }, \ + number{}); \ + }(); \ + \ + constexpr auto low_dim_idss = [&encoded_transforms, &num_transform]() { \ + return generate_tuple( \ + [&encoded_transforms](auto i) { \ + constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \ + constexpr auto low_dims = encoded_transforms[i].template at<3>(); \ + \ + return TO_SEQUENCE(low_dims, num_low_dim); \ + }, \ + number()); \ + }(); \ + \ + constexpr auto up_dim_idss = [&encoded_transforms, &num_transform] { \ + return generate_tuple( \ + [&encoded_transforms](auto i) { \ + constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \ + constexpr auto up_dims = encoded_transforms[i].template at<5>(); \ + \ + return TO_SEQUENCE(up_dims, num_up_dim); \ + }, \ + number()); \ + }(); \ + \ + constexpr auto bottom_dim_ids = TO_SEQUENCE(encoded_bottom_dims, num_bottom_dim); \ + constexpr auto top_dim_ids = TO_SEQUENCE(encoded_top_dims, num_top_dim); \ + \ + return tensor_adaptor, \ + remove_cvref_t, \ + remove_cvref_t, \ + remove_cvref_t, \ + remove_cvref_t>{trans}; \ + }() + +// Macro function +// construct static tensor_adaptor from constexpr encoding +// encoded_tensor_adaptor are Tuple of following objects: +// 1. encoded transforms (array of fixed size). Each encoded transform is a Tuple of following: +// 1.1 name (cood_transform_enum) +// 1.2 meta data for constructor of the transform +// 1.3 num of lower dimension (index_t) +// 1.4 lower dimension Ids (array of fixed size) +// 1.5 num of up dimension (index_t) +// 1.6 upper dimension Ids (array of fixed size) +// 2. num of transforms (index_t) +// 3. encoded bottom dimension Ids (array of fixed size) +// 4. num of bottom dimension (index_t) +// 5. encoded top dimension Ids (array of fixed size) +// 6. num of top dimension (index_t) +#define CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor) \ + [encoded_tensor_adaptor]() { \ + using namespace ck_tile; \ + \ + constexpr auto encoded_transforms = encoded_tensor_adaptor.template at<0>(); \ + constexpr index_t num_transform = encoded_tensor_adaptor.template at<1>(); \ + constexpr auto encoded_bottom_dims = encoded_tensor_adaptor.template at<2>(); \ + constexpr index_t num_bottom_dim = encoded_tensor_adaptor.template at<3>(); \ + constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \ + constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \ + \ + constexpr auto trans = [&encoded_transforms, &num_transform]() { \ + return generate_tuple( \ + [&encoded_transforms](auto i) constexpr { \ + constexpr auto name = encoded_transforms[i].template at<0>(); \ + constexpr auto meta_data = encoded_transforms[i].template at<1>(); \ + constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \ + constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \ + \ + STATIC_ASSERT(name == cood_transform_enum::PassThrough || \ + name == cood_transform_enum::pad || \ + name == cood_transform_enum::embed || \ + name == cood_transform_enum::merge || \ + name == cood_transform_enum::unmerge || \ + name == cood_transform_enum::replicate, \ + ""); \ + \ + if constexpr(name == cood_transform_enum::PassThrough) \ + { \ + constexpr index_t low_len = meta_data.template get(0); \ + \ + return make_pass_through_transform(number{}); \ + } \ + else if constexpr(name == cood_transform_enum::pad) \ + { \ + constexpr index_t low_len = meta_data.template get(0); \ + \ + constexpr index_t left_pad = \ + meta_data.template get(sizeof(low_len)); \ + \ + constexpr index_t right_pad = \ + meta_data.template pop(sizeof(low_len) + sizeof(left_pad)); \ + \ + return make_pad_transform( \ + number{}, number{}, number{}); \ + } \ + else if constexpr(name == cood_transform_enum::embed) \ + { \ + constexpr auto up_lens = \ + meta_data.template get>(0); \ + \ + constexpr auto coefficients = \ + meta_data.template get>(sizeof(up_lens)); \ + \ + return make_embed_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim), \ + TO_TUPLE_OF_NUMBER(coefficients, num_up_dim)); \ + } \ + else if constexpr(name == cood_transform_enum::merge) \ + { \ + constexpr auto low_lens = \ + meta_data.template get>(0); \ + \ + return make_merge_transform(TO_TUPLE_OF_NUMBER(low_lens, num_low_dim)); \ + } \ + else if constexpr(name == cood_transform_enum::unmerge) \ + { \ + constexpr auto up_lens = \ + meta_data.template get>(0); \ + \ + return make_unmerge_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim)); \ + } \ + else if constexpr(name == cood_transform_enum::replicate) \ + { \ + constexpr auto up_lens = \ + meta_data.template get>(0); \ + \ + return make_replicate_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim)); \ + } \ + }, \ + number{}); \ + }(); \ + \ + constexpr auto low_dim_idss = [&encoded_transforms, &num_transform]() { \ + return generate_tuple( \ + [&encoded_transforms](auto i) { \ + constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \ + constexpr auto low_dims = encoded_transforms[i].template at<3>(); \ + \ + return TO_SEQUENCE(low_dims, num_low_dim); \ + }, \ + number()); \ + }(); \ + \ + constexpr auto up_dim_idss = [&encoded_transforms, &num_transform] { \ + return generate_tuple( \ + [&encoded_transforms](auto i) { \ + constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \ + constexpr auto up_dims = encoded_transforms[i].template at<5>(); \ + \ + return TO_SEQUENCE(up_dims, num_up_dim); \ + }, \ + number()); \ + }(); \ + \ + constexpr auto bottom_dim_ids = TO_SEQUENCE(encoded_bottom_dims, num_bottom_dim); \ + constexpr auto top_dim_ids = TO_SEQUENCE(encoded_top_dims, num_top_dim); \ + \ + return tensor_adaptor, \ + remove_cvref_t, \ + remove_cvref_t, \ + remove_cvref_t, \ + remove_cvref_t>{trans}; \ + }() diff --git a/include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp b/include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp new file mode 100644 index 0000000000..c4528fbc4b --- /dev/null +++ b/include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp @@ -0,0 +1,257 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/tensor/tensor_adaptor.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/container/multi_index.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template +struct tensor_adaptor_coordinate +{ + static constexpr index_t ndim_bottom_ = BottomDimensionHiddenIds::size(); + static constexpr index_t ndim_top_ = TopDimensionHiddenIds::size(); + + using HiddenIndex = multi_index; + using BottomIndex = multi_index; + using TopIndex = multi_index; + + public: + CK_TILE_HOST_DEVICE constexpr tensor_adaptor_coordinate() = default; + + CK_TILE_HOST_DEVICE constexpr tensor_adaptor_coordinate(const HiddenIndex& idx_hidden) + : idx_hidden_{idx_hidden} + { + } + + CK_TILE_HOST_DEVICE constexpr auto get_top_index() const + { + return get_container_subset(idx_hidden_, TopDimensionHiddenIds{}); + } + + CK_TILE_HOST_DEVICE constexpr auto get_bottom_index() const + { + return get_container_subset(idx_hidden_, BottomDimensionHiddenIds{}); + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_hidden_index() const { return idx_hidden_; } + + CK_TILE_HOST_DEVICE constexpr auto& get_hidden_index() { return idx_hidden_; } + + // + HiddenIndex idx_hidden_; +}; + +template +CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate(const Adaptor& adaptor, + const TopIndex& idx_top) +{ + static_assert(Adaptor::get_num_of_top_dimension() == TopIndex::size(), + "wrong! # of dimension inconsistent"); + + constexpr index_t ntransform = Adaptor::get_num_of_transform(); + constexpr index_t ndim_hidden = Adaptor::get_num_of_hidden_dimension(); + constexpr auto bottom_dim_ids = Adaptor::get_bottom_dimension_hidden_ids(); + constexpr auto top_dim_ids = Adaptor::get_top_dimension_hidden_ids(); + + multi_index idx_hidden; + + // initialize visible index + set_container_subset(idx_hidden, top_dim_ids, idx_top); + + // calculate hidden index + static_for{}([&adaptor, &idx_hidden](auto itran_p1) { + auto itran = itran_p1 - number<1>{}; + const auto& tran = adaptor.get_transforms().at(itran); + constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran); + constexpr auto dims_up = Adaptor::get_upper_dimension_hidden_idss().at(itran); + + const auto idx_up = get_container_subset(idx_hidden, dims_up); + + multi_index idx_low; + + tran.calculate_lower_index(idx_low, idx_up); + + set_container_subset(idx_hidden, dims_low, idx_low); + }); + + return tensor_adaptor_coordinate, + remove_cvref_t>{idx_hidden}; +} + +template +CK_TILE_HOST_DEVICE constexpr void move_tensor_adaptor_coordinate(const Adaptor& adaptor, + AdaptorCoord& coord, + const TopIndex& idx_diff_top, + BottomIndex& idx_diff_bottom) +{ + constexpr index_t ndim_hidden = Adaptor::get_num_of_hidden_dimension(); + constexpr index_t ndim_top = Adaptor::get_num_of_top_dimension(); + // constexpr index_t ndim_bottom = Adaptor::get_num_of_bottom_dimension(); + constexpr index_t ntransform = Adaptor::get_num_of_transform(); + + // STATIC_ASSERT(TopIndex::size() == ndim_top && BottomIndex::size() == ndim_bottom, ""); + + // judge whether calculation of lower diff is needed for each transform + // use index_t for boolean type + auto do_transforms = make_zero_multi_index(); + + if constexpr(JudgeDoTransforms) + { + auto is_non_zero_diff = make_zero_multi_index(); + + // decide do_transform by checkout non-zero index diff components + multi_index non_zero_diff_pick_top; + + static_for<0, ndim_top, 1>{}( + [&](auto i) { non_zero_diff_pick_top(i) = (idx_diff_top[i] != 0); }); + + set_container_subset( + is_non_zero_diff, Adaptor::get_top_dimension_hidden_ids(), non_zero_diff_pick_top); + + static_for{}([&](auto itran) { + constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran); + constexpr auto dims_up = Adaptor::get_upper_dimension_hidden_idss().at(itran); + + const auto non_zero_diff_pick_up = get_container_subset(is_non_zero_diff, dims_up); + + multi_index non_zero_diff_pick_low; + + // if any of upper index diff components is non-zero, then + // 1) Need to do this transform + // 2) all components of lower index diff will assume to be non-zero and need to be + // computed + const bool idx_diff_up_has_non_zero = container_reduce( + non_zero_diff_pick_up, [](auto a, auto b) constexpr { return a or b; }, false); + + do_transforms(itran) = idx_diff_up_has_non_zero; + + static_for<0, dims_low.size(), 1>{}( + [&](auto i) { non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero; }); + + set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low); + }); + } + else + { + static_for{}([&](auto itran) { do_transforms(itran) = 1; }); + } + + // this is what needs to be calculated + auto idx_diff_hidden = make_zero_multi_index(); + + // initialize top index diff + set_container_subset(idx_diff_hidden, Adaptor::get_top_dimension_hidden_ids(), idx_diff_top); + + // this is what needs to be updated + auto& idx_hidden = coord.get_hidden_index(); + + // update top index + auto idx_hidden_pick_top = + get_container_subset(idx_hidden, Adaptor::get_top_dimension_hidden_ids()); + + idx_hidden_pick_top += idx_diff_top; + + set_container_subset(idx_hidden, Adaptor::get_top_dimension_hidden_ids(), idx_hidden_pick_top); + + // update rest of hidden index + static_for{}([&](auto itran) { + if(do_transforms[itran]) + { + const auto& tran = adaptor.get_transforms().at(itran); + constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran); + constexpr auto dims_up = Adaptor::get_upper_dimension_hidden_idss().at(itran); + + const auto idx_up_new = get_container_subset(idx_hidden, dims_up); + auto idx_low = get_container_subset(idx_hidden, dims_low); + const auto idx_diff_up = get_container_subset(idx_diff_hidden, dims_up); + + multi_index idx_diff_low; + + tran.update_lower_index(idx_diff_low, idx_diff_up, idx_low, idx_up_new); + + set_container_subset(idx_diff_hidden, dims_low, idx_diff_low); + set_container_subset(idx_hidden, dims_low, idx_low); + } + }); + + // set bottom index diff + idx_diff_bottom = + get_container_subset(idx_diff_hidden, Adaptor::get_bottom_dimension_hidden_ids()); +} + +template +CK_TILE_HOST_DEVICE constexpr void move_tensor_adaptor_coordinate(const Adaptor& adaptor, + AdaptorCoord& coord, + const TopIndex& idx_diff_top) +{ + constexpr index_t ndim_bottom = Adaptor::get_num_of_bottom_dimension(); + + multi_index tmp; + + move_tensor_adaptor_coordinate(adaptor, coord, idx_diff_top, tmp); +} + +template +CK_TILE_HOST_DEVICE constexpr bool +adaptor_coordinate_is_valid_assuming_top_index_is_valid(const Adaptor& adaptor, + const AdaptorCoord& coord) +{ + bool valid = true; + + constexpr index_t ntransform = Adaptor::get_num_of_transform(); + + const auto& idx_hidden = coord.get_hidden_index(); + + static_for{}([&adaptor, &idx_hidden, &valid](auto itran) { + const auto tran = adaptor.get_transforms().at(itran); + + // check validity, only if current transformation does not always has a valid mapping + if constexpr(!decltype(tran)::is_valid_upper_index_always_mapped_to_valid_lower_index()) + { + const auto idx_up = get_container_subset( + idx_hidden, Adaptor::get_upper_dimension_hidden_idss().at(itran)); + + // Comment: using valid = valid && .. will result in weird control flow in ISA + valid &= tran.is_valid_upper_index_mapped_to_valid_lower_index(idx_up); + } + }); + + return valid; +} + +template +CK_TILE_HOST_DEVICE constexpr bool adaptor_coordinate_is_valid(const Adaptor& adaptor, + const AdpatorCoord& coord) +{ + // check top index + const auto& idx_top = coord.get_top_index(); + + bool is_top_index_valid = true; + + static_for<0, Adaptor::get_num_of_dimension(), 1>{}( + [&is_top_index_valid, &idx_top, &adaptor](auto i) { + is_top_index_valid = + is_top_index_valid && (idx_top[i] >= 0 && idx_top[i] < adaptor.get_length(i)); + }); + + // check other hidden index + return is_top_index_valid && + adaptor_coordinate_is_valid_assuming_top_index_is_valid(adaptor, coord); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tensor_coordinate.hpp b/include/ck_tile/core/tensor/tensor_coordinate.hpp new file mode 100644 index 0000000000..9b8fe731fd --- /dev/null +++ b/include/ck_tile/core/tensor/tensor_coordinate.hpp @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/tensor/tensor_adaptor.hpp" +#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/container/multi_index.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template +struct tensor_coordinate + : public tensor_adaptor_coordinate, TopDimensionHiddenIds> +{ + using Base = tensor_adaptor_coordinate, TopDimensionHiddenIds>; + + // TODO make these private + static constexpr index_t ndim_top_ = TopDimensionHiddenIds::size(); + + using HiddenIndex = multi_index; + using TopIndex = multi_index; + + public: + CK_TILE_HOST_DEVICE constexpr tensor_coordinate() = default; + + CK_TILE_HOST_DEVICE constexpr tensor_coordinate(const HiddenIndex& idx_hidden) + : Base{idx_hidden} + { + } + + // construct from TensorAdaptorCoordinte base class + CK_TILE_HOST_DEVICE constexpr tensor_coordinate(const Base& adaptor_coord) : Base{adaptor_coord} + { + } + + CK_TILE_HOST_DEVICE constexpr auto get_index() const { return Base::get_top_index(); } + + CK_TILE_HOST_DEVICE constexpr index_t get_offset() const + { + return Base::get_bottom_index()[number<0>{}]; + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_hidden_index() const + { + return Base::get_hidden_index(); + } + + CK_TILE_HOST_DEVICE auto& get_hidden_index() { return Base::get_hidden_index(); } +}; + +template +CK_TILE_HOST_DEVICE constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc, + const TopIndex& idx_top) +{ + const auto adaptor_coord = make_tensor_adaptor_coordinate(tensor_desc, idx_top); + + return tensor_coordinate>{ + adaptor_coord}; +} + +template +CK_TILE_HOST_DEVICE constexpr void +move_tensor_coordinate(const TensorDesc& tensor_desc, TensorCoord& coord, const Index& coord_step) +{ + move_tensor_adaptor_coordinate(tensor_desc, coord, coord_step); +} + +template +CK_TILE_HOST_DEVICE constexpr bool +coordinate_has_valid_offset_assuming_top_index_is_valid(const TensorDesc& tensor_desc, + const TensorCoord& coord) +{ + return adaptor_coordinate_is_valid_assuming_top_index_is_valid(tensor_desc, coord); +} + +template +CK_TILE_HOST_DEVICE constexpr bool coordinate_has_valid_offset(const TensorDesc& tensor_desc, + const TensorCoord& coord) +{ + return adaptor_coordinate_is_valid(tensor_desc, coord); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tensor_descriptor.hpp b/include/ck_tile/core/tensor/tensor_descriptor.hpp new file mode 100644 index 0000000000..697988de10 --- /dev/null +++ b/include/ck_tile/core/tensor/tensor_descriptor.hpp @@ -0,0 +1,472 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/tensor/tensor_adaptor.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/container/multi_index.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +// Transforms: Tuple +// LowerDimensionHiddenIdss : Tuple, ...> +// UpperDimensionHiddenIdss : Tuple, ...> +// TopDimensionHiddenIds> : sequence<...> +template +struct tensor_descriptor : public tensor_adaptor, + TopDimensionHiddenIds> +{ + using Base = tensor_adaptor, + TopDimensionHiddenIds>; + + using ElementSpaceSizeType = ElementSpaceSize; + + constexpr static index_t ntransform_ = Base::get_num_of_transform(); + constexpr static index_t ndim_hidden_ = Base::get_num_of_hidden_dimension(); + constexpr static index_t ndim_top_ = Base::get_num_of_top_dimension(); + + using GuaranteedVectorLengths = GuaranteedVectorLengths_; + using GuaranteedVectorStrides = GuaranteedVectorSrides_; + + static_assert(GuaranteedVectorLengths::size() == ndim_hidden_ && + GuaranteedVectorStrides::size() == ndim_hidden_, + "wrong! inconsistent # of hidden dimensions"); + + using TopIndex = multi_index; + using HiddenIndex = multi_index; + + public: + CK_TILE_HOST_DEVICE constexpr tensor_descriptor() = default; + + CK_TILE_HOST_DEVICE constexpr tensor_descriptor(const Transforms& transforms, + ElementSpaceSize element_space_size) + : Base{transforms}, element_space_size_{element_space_size} + + { + static_assert(Transforms::size() == ntransform_ && + LowerDimensionHiddenIdss::size() == ntransform_ && + UpperDimensionHiddenIdss::size() == ntransform_, + "wrong! inconsistent # of transformations"); + + // TODO check dependency of dimensions is valid + } + + // construct from tensor_adaptor base class + CK_TILE_HOST_DEVICE constexpr tensor_descriptor(const Base& adaptor, + ElementSpaceSize element_space_size) + : Base{adaptor}, element_space_size_{element_space_size} + { + } + + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension() + { + return Base::get_num_of_top_dimension(); + } + + template + CK_TILE_HOST_DEVICE constexpr auto get_length(number idim) const + { + return Base::get_top_dimension_length(idim); + } + + CK_TILE_HOST_DEVICE constexpr auto get_lengths() const + { + return Base::get_top_dimension_length(); + } + + CK_TILE_HOST_DEVICE constexpr auto get_element_space_size() const + { + return element_space_size_; + } + + template + CK_TILE_HOST_DEVICE constexpr index_t calculate_offset(const Idx& idx) const + { + return Base::calculate_bottom_index(idx)[number<0>{}]; + } + + // TODO make these private + CK_TILE_HOST_DEVICE constexpr const auto& get_transforms() const + { + return Base::get_transforms(); + } + + CK_TILE_HOST_DEVICE static constexpr auto get_lower_dimension_hidden_idss() + { + return Base::get_lower_dimension_hidden_idss(); + } + + CK_TILE_HOST_DEVICE static constexpr auto get_upper_dimension_hidden_idss() + { + return Base::get_upper_dimension_hidden_idss(); + } + + CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_hidden_ids() + { + return Base::get_top_dimension_hidden_ids(); + } + + CK_TILE_HOST_DEVICE static constexpr bool is_static() + { + return Base::is_known_at_compile_time() && + ck_tile::is_known_at_compile_time::value; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { return is_static(); } + + CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides() + { + return Base::get_top_dimension_safe_vector_length_strides( + to_array(GuaranteedVectorLengths{}), + to_array(GuaranteedVectorStrides{})); + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("tensor_descriptor{"); + + // tensor_adaptor + Base::print(); + printf(", "); + + // element_space_size_ + printf("element_space_size_: "); + print(element_space_size_); + + printf("}"); + } + + // TODO make these private + ElementSpaceSize element_space_size_; +}; + +template +CK_TILE_HOST_DEVICE constexpr auto +make_tensor_descriptor_from_adaptor(const Adaptor& adaptor, + const ElementSpaceSize& element_space_size) +{ + constexpr index_t NDimHidden = Adaptor::get_num_of_hidden_dimension(); + + return tensor_descriptor, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + typename uniform_sequence_gen::type, + typename uniform_sequence_gen::type>{ + adaptor, element_space_size}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto +transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, + const NewTransforms& new_transforms, + NewLowerDimensionOldTopIdss, + NewUpperDimensionNewTopIdss) +{ + const auto element_space_size = old_tensor_desc.get_element_space_size(); + + const auto new_tensor_adaptor = transform_tensor_adaptor(old_tensor_desc, + new_transforms, + NewLowerDimensionOldTopIdss{}, + NewUpperDimensionNewTopIdss{}); + + constexpr index_t NDimHiddenOld = OldTensorDescriptor::get_num_of_hidden_dimension(); + constexpr index_t NDimHiddenNew = decltype(new_tensor_adaptor)::get_num_of_hidden_dimension(); + + using NewGuaranteedVectorLengths = typename sequence_merge< + typename OldTensorDescriptor::GuaranteedVectorLengths, + typename uniform_sequence_gen::type>::type; + + using NewGuaranteedVectorStrides = typename sequence_merge< + typename OldTensorDescriptor::GuaranteedVectorStrides, + typename uniform_sequence_gen::type>::type; + + return tensor_descriptor< + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + NewGuaranteedVectorLengths, + NewGuaranteedVectorStrides>{new_tensor_adaptor, element_space_size}; +} + +namespace detail { + +template +CK_TILE_HOST_DEVICE constexpr auto calculate_element_space_size_impl(const Lengths& lengths, + const Strides& strides, + number i, + AccOld acc_old) +{ + auto acc_new = acc_old + (lengths[i] - number<1>{}) * strides[i]; + + if constexpr(i.value < Lengths::size() - 1) + { + return calculate_element_space_size_impl(lengths, strides, i + number<1>{}, acc_new); + } + else + { + return acc_new; + } +} + +} // namespace detail + +/* + * These functions create naive tensor descriptor + */ + +// Lengths..., Strides... could be: +// 1) index_t, which is known at run-time, or +// 2) number<>, which is known at compile-time +// element_space_size could be: +// 1) long_index_t, or +// 2) long_number<> +template ::type = false> +CK_TILE_HOST_DEVICE constexpr auto +make_naive_tensor_descriptor(const tuple& lengths, + const tuple& strides, + number = number<-1>{}, + number = number<-1>{}) +{ + constexpr index_t N = sizeof...(Lengths); + + const auto transforms = make_tuple(make_embed_transform(lengths, strides)); + + constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{}); + + constexpr auto up_dim_hidden_idss = + make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{}); + + constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{}; + + const auto element_space_size = + detail::calculate_element_space_size_impl(lengths, strides, number<0>{}, long_number<1>{}); + + using GuaranteedVectorLengths = + typename sequence_merge::type, + sequence>::type; + + using GuaranteedVectorStrides = + typename sequence_merge::type, + sequence>::type; + + return tensor_descriptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t, + GuaranteedVectorLengths, + GuaranteedVectorStrides>{transforms, element_space_size}; +} + +// tensor descriptor with offset, the offset will not be added into element space size +// only have an information of the starting offset, and will impact on offset calculation +template ::type = false> +CK_TILE_HOST_DEVICE constexpr auto +make_naive_tensor_descriptor_with_offset(const tuple& lengths, + const tuple& strides, + const offset& offset, + number = number<-1>{}, + number = number<-1>{}) +{ + const auto desc_0 = [&]() { + const auto element_space_size = detail::calculate_element_space_size_impl( + lengths, strides, number<0>{}, long_number<1>{}); + + const auto transforms = make_tuple(make_offset_transform(element_space_size, offset)); + + constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{}); + + constexpr auto up_dim_hidden_idss = make_tuple(sequence<1>{}); + + constexpr auto visible_dim_hidden_ids = sequence<1>{}; + + using GuaranteedVectorLengths = + typename sequence_merge::type, + sequence>::type; + + using GuaranteedVectorStrides = + typename sequence_merge::type, + sequence>::type; + + return tensor_descriptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t, + GuaranteedVectorLengths, + GuaranteedVectorStrides>{transforms, element_space_size}; + }(); + + constexpr index_t N = sizeof...(Lengths); + + return transform_tensor_descriptor( + desc_0, + make_tuple(make_embed_transform(lengths, strides)), + make_tuple(sequence<0>{}), + make_tuple(typename arithmetic_sequence_gen<0, N, 1>::type{})); +} + +// Lengths... could be: +// 1) index_t, which is known at run-time, or +// 2) number<>, which is known at compile-time +// element_space_size could be: +// 1) long_index_t, or +// 2) long_number<> +template +CK_TILE_HOST_DEVICE constexpr auto +make_naive_tensor_descriptor_packed(const tuple& lengths, + number = number<-1>{}) +{ + constexpr index_t N = sizeof...(Lengths); + + const auto transforms = make_tuple(make_unmerge_transform(lengths)); + + constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{}); + + constexpr auto up_dim_hidden_idss = + make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{}); + + constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{}; + + const auto element_space_size = container_reduce(lengths, math::multiplies{}, long_number<1>{}); + + using GuaranteedVectorLengths = + typename sequence_merge::type, + sequence>::type; + + using GuaranteedVectorStrides = + typename sequence_merge::type, sequence<1>>::type; + + return tensor_descriptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t, + GuaranteedVectorLengths, + GuaranteedVectorStrides>{transforms, element_space_size}; +} + +template ::type = false> +CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor_packed_with_offset( + const tuple& lengths, + const offset& offset, + number = number<-1>{}) +{ + const auto desc_0 = [&]() { + const auto element_space_size = + container_reduce(lengths, math::multiplies{}, long_number<1>{}); + + const auto transforms = make_tuple(make_offset_transform(element_space_size, offset)); + + constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{}); + + constexpr auto up_dim_hidden_idss = make_tuple(sequence<1>{}); + + constexpr auto visible_dim_hidden_ids = sequence<1>{}; + + using GuaranteedVectorLengths = + typename sequence_merge::type, + sequence>::type; + + using GuaranteedVectorStrides = + typename sequence_merge::type, sequence<1>>::type; + + return tensor_descriptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t, + GuaranteedVectorLengths, + GuaranteedVectorStrides>{transforms, element_space_size}; + }(); + + constexpr index_t N = sizeof...(Lengths); + + return transform_tensor_descriptor( + desc_0, + make_tuple(make_unmerge_transform(lengths)), + make_tuple(sequence<0>{}), + make_tuple(typename arithmetic_sequence_gen<0, N, 1>::type{})); +} + +// Lengths... could be: +// 1) index_t, which is known at run-time, or +// 2) number<>, which is known at compile-time +// align could be: +// 1) index_t, or +// 2) number<> +template +CK_TILE_HOST_DEVICE constexpr auto +make_naive_tensor_descriptor_aligned(const tuple& lengths, Align align) +{ + constexpr auto I1 = number<1>{}; + + constexpr index_t N = sizeof...(Lengths); + + const auto stride_n_minus_2 = math::integer_least_multiple(lengths[number{}], align); + + auto strides = generate_tuple( + [&](auto i) { + if constexpr(i.value == N - 1) + { + return I1; + } + else if constexpr(i.value == N - 2) + { + return number{}; + } + else + { + return container_reduce(lengths, + math::multiplies{}, + number{}, + i + I1, + number{}, + I1); + } + }, + number{}); + + return make_naive_tensor_descriptor(lengths, strides); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp new file mode 100644 index 0000000000..3309b4b442 --- /dev/null +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -0,0 +1,273 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template +struct tensor_view +{ + using BufferView = remove_reference_t; + using DataType = typename BufferView::type; + using TensorDesc = remove_cvref_t; + using TensorIndex = array; + using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{})); + + CK_TILE_HOST_DEVICE constexpr tensor_view() = default; + + CK_TILE_HOST_DEVICE constexpr tensor_view(const BufferView& buffer_view, const TensorDesc& desc) + : buf_{buffer_view}, desc_{desc} + { + } + + CK_TILE_HOST_DEVICE constexpr auto& get_tensor_descriptor() const { return desc_; } + + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension() + { + return TensorDesc::get_num_of_top_dimension(); + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_buffer_view() const { return buf_; } + + CK_TILE_HOST_DEVICE constexpr auto& get_buffer_view() { return buf_; } + +#if 0 + CK_TILE_HOST_DEVICE constexpr DataType get_element(const TensorCoord& coord) const + { + return buf_.template get( + coord.get_offset(), + coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord)); + } + + CK_TILE_HOST_DEVICE constexpr void set_element(const TensorCoord& coord, const DataType& x) + { + buf_.template set( + coord.get_offset(), + coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), + x); + } +#endif + // X is vector of DataType. + // "coord" is coordinate of DataType, not X. "coord" should be aligned to X + template >::type, + typename scalar_type>::type>, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr remove_cvref_t + get_vectorized_elements(const TensorCoord& coord, + bool_constant = {}) const + { + return buf_.template get( + coord.get_offset(), + coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), + bool_constant{}); + } + + // X is vector of DataType. + // "coord" is coordinate of DataType, not X. "coord" should be aligned to X + template >::type, + typename scalar_type>::type>, + bool>::type = false> + CK_TILE_HOST_DEVICE void + get_vectorized_elements_raw(remove_cvref_t& dst, + const TensorCoord& coord, + bool_constant = {}) const + { + return buf_.template get_raw( + dst, + coord.get_offset(), + coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord)); + } + + template >::type, + typename scalar_type>::type>, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements(remove_cvref_t* smem, + const TensorCoord& coord) const + { + return buf_.template async_get(smem, coord.get_offset(), true /*not used*/); + } + + // X is vector of DataType. + // "coord" is coordinate of DataType, not X. "coord" should be aligned to X + template >::type, + typename scalar_type>::type>, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements( + const TensorCoord& coord, const X& x, bool_constant = {}) + { + buf_.template set( + coord.get_offset(), + coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), + x); + } + + template >::type, + typename scalar_type>::type>, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements_raw( + const TensorCoord& coord, const X& x, bool_constant = {}) + { + buf_.template set_raw( + coord.get_offset(), + coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), + x); + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("tensor_view{"); + + // buf_ + printf("buf_: "); + print(buf_); + printf(", "); + + // desc_ + printf("desc_: "); + print(desc_); + + printf("}"); + } + + // member + BufferView buf_; + TensorDesc desc_; +}; + +// placeholder type if we want to opt-out a tile view parameter +struct null_tensor_view +{ +}; + +template +CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p, + const tensor_descriptor& desc) +{ + auto buffer_view = make_buffer_view(p, desc.get_element_space_size()); + + return tensor_view{buffer_view, desc}; +} + +template ::type = false> +CK_TILE_HOST_DEVICE constexpr auto +make_naive_tensor_view(DataType* p, + const tuple& lengths, + const tuple& strides, + number = number<-1>{}, + number = number<-1>{}) +{ + auto desc = make_naive_tensor_descriptor(lengths, + strides, + number{}, + number{}); + + auto buffer_view = make_buffer_view(p, desc.get_element_space_size()); + + return tensor_view{buffer_view, desc}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto +make_naive_tensor_view_packed(DataType* p, + const tuple& lengths, + number = number<-1>{}) +{ + auto desc = + make_naive_tensor_descriptor_packed(lengths, number{}); + + auto buffer_view = make_buffer_view(p, desc.get_element_space_size()); + + return tensor_view{buffer_view, desc}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView& old_tensor_view, + const NewTransforms& new_transforms, + NewLowerDimensionOldVisibleIdss, + NewUpperDimensionNewVisibleIdss) +{ + auto new_desc = transform_tensor_descriptor(old_tensor_view.desc_, + new_transforms, + NewLowerDimensionOldVisibleIdss{}, + NewUpperDimensionNewVisibleIdss{}); + + return tensor_view>{ + old_tensor_view.buf_, new_desc}; +} + +template + typename DoPads> // sequence +CK_TILE_HOST_DEVICE constexpr auto +pad_tensor_view(const tensor_view& tensor_view, const TileLengths& tile_lengths, DoPads) +{ + constexpr index_t num_dim = DoPads::size(); + + static_assert(num_dim == TileLengths::size() && num_dim == tensor_view::get_num_of_dimension(), + "wrong! inconsistent # of dimensions"); + + // transforms + const auto transforms = generate_tuple( + [&](auto idim) { + const auto old_length = tensor_view.get_tensor_descriptor().get_length(idim); + + const auto tile_length = tile_lengths[idim]; + + const auto new_length = + math::integer_divide_ceil(old_length, tile_length) * tile_length; + + const auto pad_length = new_length - old_length; + + constexpr bool DoPad = DoPads::at(idim); + + const auto transform = + conditional_expr(make_right_pad_transform(old_length, pad_length), + make_pass_through_transform(old_length)); + + return transform; + }, + number{}); + + // lower dimension Id + const auto lower_dimss = + generate_tuple([&](auto idim) { return sequence{}; }, number{}); + + // upper dimension Id + const auto upper_dimss = lower_dimss; + + return transform_tensor_view(tensor_view, transforms, lower_dimss, upper_dimss); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tile_distribution.hpp b/include/ck_tile/core/tensor/tile_distribution.hpp new file mode 100644 index 0000000000..c891e9a608 --- /dev/null +++ b/include/ck_tile/core/tensor/tile_distribution.hpp @@ -0,0 +1,754 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/container/sequence.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/tensor/tensor_adaptor.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +// distributed span +template +struct tile_distributed_span +{ + using Impl = sequence; + + static constexpr auto impl_ = Impl{}; + + CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; } +}; + +// distributed index +template +struct tile_distributed_index +{ + using Impl = sequence; + + static constexpr auto impl_ = Impl{}; + + CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; } +}; + +namespace detail { + +template +CK_TILE_HOST_DEVICE constexpr auto make_tile_distributed_span(sequence) +{ + return tile_distributed_span{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto make_tile_distributed_index(sequence) +{ + return tile_distributed_index{}; +} + +} // namespace detail + +template // FIXME: this is for hold ad-hoc but useful info, + // should be more elegnat +struct tile_distribution +{ + using PsYs2XsAdaptor = remove_cvref_t; + using Ys2DDescriptor = remove_cvref_t; + using DstrEncode = remove_cvref_t; + using DstrDetail = remove_cvref_t; + + static_assert(PsYs2XsAdaptor::is_static() && Ys2DDescriptor::is_static(), + "wrong! should be static"); + + static constexpr index_t NDimX = PsYs2XsAdaptor::get_num_of_bottom_dimension(); + static constexpr index_t NDimY = Ys2DDescriptor::get_num_of_top_dimension(); + static constexpr index_t NDimP = PsYs2XsAdaptor::get_num_of_top_dimension() - NDimY; + static constexpr index_t NDimR = StaticTileDistributionEncoding_::NDimR; + + PsYs2XsAdaptor ps_ys_to_xs_; + Ys2DDescriptor ys_to_d_; + + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_x() { return NDimX; } + CK_TILE_HOST_DEVICE static constexpr index_t GetNumOfDimensionY() { return NDimY; } + CK_TILE_HOST_DEVICE static constexpr index_t GetNumOfDimensionP() { return NDimP; } + CK_TILE_HOST_DEVICE static constexpr index_t GetNumOfDimensionR() { return NDimR; } + + CK_TILE_HOST_DEVICE static constexpr auto get_lengths() + { +#if 0 + // FIXME: tensor_adaptor::GetBottomDimensionLengths is wrong. re-enable this after it's fixed + ps_ys_to_xs_.GetBottomDimensionLengths(); +#else + return generate_tuple( + [&](auto i) { + constexpr index_t x_length = + container_reduce(typename DstrEncode::HsLengthss{}[i], math::multiplies{}, 1); + + return number{}; + }, + number{}); +#endif + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_ps_ys_to_xs_adaptor() const + { + return ps_ys_to_xs_; + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_ys_to_d_descriptor() const { return ys_to_d_; } + + CK_TILE_HOST_DEVICE static constexpr auto get_static_tile_distribution_encoding() + { + return DstrEncode{}; + } + +#if 1 + // Calculate Replication index [R0, R1, ...] based on Partion index + // FIXME: very nasty implementation + template + CK_TILE_HOST_DEVICE auto calculate_rs_index_from_ps_index(const PartitionIndex& ps_idx) const + { + static_assert(PartitionIndex::size() == NDimP, "wrong!"); + + const auto ps_ys_idx = container_concat(ps_idx, array{0}); + + const auto dummy_adaptor_coord = make_tensor_adaptor_coordinate(ps_ys_to_xs_, ps_ys_idx); + + array rs_idx; + + static_for<0, NDimP, 1>{}([&](auto idim_p) { + constexpr index_t ndim_low = DstrEncode::ps_to_rhss_major_[idim_p].size(); + + static_for<0, ndim_low, 1>{}([&](auto i) { + constexpr index_t rh_major = DstrEncode::ps_to_rhss_major_[idim_p][i]; + constexpr index_t rh_minor = DstrEncode::ps_to_rhss_minor_[idim_p][i]; + + // 0-th rh_major is the replicate dimension + if constexpr(rh_major == 0) + { + constexpr index_t adaptor_hidden_id = + DstrDetail::rh_major_minor_to_adaptor_hidden_idss_[rh_major][rh_minor]; + + // fill in + rs_idx(rh_minor) = dummy_adaptor_coord.get_hidden_index()[adaptor_hidden_id]; + } + }); + }); + + return rs_idx; + } +#endif + + CK_TILE_HOST_DEVICE static constexpr auto get_distributed_spans() + { + constexpr auto distributed_spans_impl = DstrEncode::detail::distributed_spans_lengthss_; + constexpr auto ndims_spans_minor = DstrEncode::detail::ndims_distributed_spans_minor_; + + return generate_tuple( + [&](auto i) { + constexpr auto span_impl = distributed_spans_impl[i]; + constexpr index_t ndim_span_minor = ndims_spans_minor[i]; + + constexpr auto span = TO_SEQUENCE(span_impl, ndim_span_minor); + + return detail::make_tile_distributed_span(span); + }, + number{}); + } + + // FIXME: it's hacky to get Y index from Distributed-Index + template + CK_TILE_HOST_DEVICE static constexpr auto + get_y_indices_from_distributed_indices(DistributedIndices) + { + constexpr auto ys_idx_arr = [] { + array ys_idx; + + static_for<0, NDimY, 1>{}([&](auto i) { + constexpr index_t span_major = DstrEncode::detail::ys_to_span_major_[i]; + constexpr index_t span_minor = DstrEncode::detail::ys_to_span_minor_[i]; + + constexpr auto dstr_index = DistributedIndices{}[number{}]; + + ys_idx(i) = dstr_index.impl_[span_minor]; + }); + + return ys_idx; + }(); + + constexpr index_t ndim_y = NDimY; + + return TO_SEQUENCE(ys_idx_arr, ndim_y); + } + + CK_TILE_HOST_DEVICE static constexpr bool is_static() + { + return PsYs2XsAdaptor::is_static() && Ys2DDescriptor::is_static(); + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("tile_distribution{"); + // + printf("tile_distribution_encoding: "); + print(DstrEncode{}); + printf(", "); + // + printf("ps_ys_to_xs_: "); + print(ps_ys_to_xs_); + printf(", "); + // + printf("ys_to_d_: "); + print(ys_to_d_); + // + printf("}"); + } +}; + +namespace detail { + +template +CK_TILE_HOST_DEVICE constexpr auto make_sequential_index(index_t ibegin, index_t iend) +{ + array arr{0}; + + for(index_t i = 0; i < iend - ibegin; ++i) + { + arr(i) = ibegin + i; + } + + return arr; +} + +// this returns a constexpr encoding of tile_distribution +template +CK_TILE_HOST_DEVICE constexpr auto + make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_) +{ + using RsLengths = typename StaticTileDistributionEncoding_::RsLengths; + using HsLengthss = typename StaticTileDistributionEncoding_::HsLengthss; + using Ps2RHssMajor = typename StaticTileDistributionEncoding_::Ps2RHssMajor; + using Ps2RHssMinor = typename StaticTileDistributionEncoding_::Ps2RHssMinor; + using Ys2RHsMajor = typename StaticTileDistributionEncoding_::Ys2RHsMajor; + using Ys2RHsMinor = typename StaticTileDistributionEncoding_::Ys2RHsMinor; + + // FIXME: increase max value if fail + constexpr index_t kMaxNumTransforms = 20; + constexpr index_t kMaxMetaDataSize = 128; + constexpr index_t kMaxNumDim = 10; + + using Name = cood_transform_enum; + using MetaData = meta_data_buffer; + using NumDim = index_t; + using Dims = array; + using Lengths = array; + + // Tile Adaptor + // bottom dims [x0, x1, x2, ...] + // top dims [p0, p1, ..., y0, y1, ...] + constexpr index_t ndim_x = HsLengthss::size(); + + // Dim Ids: [idim_x_major, idim_x_minor] to [idim_hidden] + array, ndim_x + 1> rh_major_minor_to_hidden_ids; + array, ndim_x + 1> rh_major_minor_to_hidden_lengths; + + auto trans = array, kMaxNumTransforms>{}; + + index_t num_tran = 0; + index_t hidden_dim_cnt = ndim_x; + + // this is replicate transform + { + constexpr index_t ndim_r_minor = RsLengths::size(); + + constexpr auto r_minor_lengths = RsLengths{}; + + trans(num_tran++) = { + cood_transform_enum::replicate, + MetaData{to_array(r_minor_lengths)}, + NumDim{0}, + Dims{}, + NumDim{ndim_r_minor}, + make_sequential_index(hidden_dim_cnt, hidden_dim_cnt + ndim_r_minor)}; + + for(index_t i = 0; i < ndim_r_minor; ++i) + { + rh_major_minor_to_hidden_ids(0)(i) = hidden_dim_cnt; + rh_major_minor_to_hidden_lengths(0)(i) = r_minor_lengths[i]; + + hidden_dim_cnt++; + } + }; + + // these are Unmerge transforms for X dimesions + static_for<0, ndim_x, 1>{}([&trans, + &num_tran, + &hidden_dim_cnt, + &rh_major_minor_to_hidden_ids, + &rh_major_minor_to_hidden_lengths](auto idim_x) { + constexpr auto h_minor_lengths = tuple_element_t{}; + + constexpr index_t ndim_h_minor = h_minor_lengths.size(); + + trans(num_tran++) = { + cood_transform_enum::unmerge, + MetaData{to_array(h_minor_lengths)}, + NumDim{1}, + Dims{idim_x}, + NumDim{ndim_h_minor}, + make_sequential_index(hidden_dim_cnt, hidden_dim_cnt + ndim_h_minor)}; + + for(index_t i = 0; i < ndim_h_minor; ++i) + { + rh_major_minor_to_hidden_ids(idim_x + 1)(i) = hidden_dim_cnt; + rh_major_minor_to_hidden_lengths(idim_x + 1)(i) = h_minor_lengths[i]; + + hidden_dim_cnt++; + } + }); + + // transform: P dimensions + constexpr index_t ndim_p = Ps2RHssMajor::size(); + + Dims hidden_dim_id_ps; + + static_for<0, ndim_p, 1>{}([&](auto iDimP) { + // + index_t hidden_dim_id_p = hidden_dim_cnt++; + + hidden_dim_id_ps(iDimP) = hidden_dim_id_p; + + constexpr auto p2RHsMajor = Ps2RHssMajor{}[iDimP]; + constexpr auto p2RHsMinor = Ps2RHssMinor{}[iDimP]; + + static_assert(p2RHsMajor.size() == p2RHsMinor.size(), "wrong!"); + + constexpr index_t ndim_low = p2RHsMajor.size(); + + Dims low_dims; + Lengths low_lengths; + + for(index_t i = 0; i < ndim_low; ++i) + { + index_t rh_major = p2RHsMajor[i]; + index_t rh_minor = p2RHsMinor[i]; + low_dims(i) = rh_major_minor_to_hidden_ids[rh_major][rh_minor]; + low_lengths(i) = rh_major_minor_to_hidden_lengths[rh_major][rh_minor]; + } + + trans(num_tran++) = {cood_transform_enum::merge, + MetaData{to_array(low_lengths)}, + NumDim{ndim_low}, + low_dims, + NumDim{1}, + Dims{hidden_dim_id_p}}; + }); + + constexpr index_t ndim_bottom = ndim_x; + + constexpr auto bottom_dim_ids = make_sequential_index(0, ndim_bottom); + + constexpr auto ys_to_rhs_major = Ys2RHsMajor{}; + constexpr auto ys_to_rhs_minor = Ys2RHsMinor{}; + + constexpr index_t ndim_y = Ys2RHsMajor::size(); + constexpr index_t ndim_top = ndim_p + ndim_y; + + auto top_dim_ids = hidden_dim_id_ps; + + { + for(index_t i = 0; i < ndim_y; ++i) + { + index_t rh_major = ys_to_rhs_major[i]; + index_t rh_minor = ys_to_rhs_minor[i]; + top_dim_ids(ndim_p + i) = rh_major_minor_to_hidden_ids[rh_major][rh_minor]; + } + } + + // + const auto ps_ys_to_xs_adaptor_encoding = + make_tuple(trans, num_tran, bottom_dim_ids, ndim_bottom, top_dim_ids, ndim_top); + + // descriptor: [y0, y1, ...] to [d] + Lengths y_lengths; + index_t d_length = 1; + + for(index_t i = 0; i < ndim_y; ++i) + { + index_t rh_major = ys_to_rhs_major[i]; + index_t rh_minor = ys_to_rhs_minor[i]; + index_t y_length = rh_major_minor_to_hidden_lengths[rh_major][rh_minor]; + y_lengths(i) = y_length; + d_length *= y_length; + } + + auto tran = make_tuple(cood_transform_enum::unmerge, + MetaData{to_array(y_lengths)}, + NumDim{1}, + Dims{0}, + NumDim{ndim_y}, + make_sequential_index(1, ndim_y + 1)); + + const auto ys_to_d_adaptor_encoding = make_tuple( + make_tuple(tran), 1, Dims{0}, 1, make_sequential_index(1, ndim_y + 1), ndim_y); + + return make_tuple(ps_ys_to_xs_adaptor_encoding, + ys_to_d_adaptor_encoding, + d_length, + rh_major_minor_to_hidden_ids); +} + +// FIXME: this is nasty. Move it inside TileDistributionEncoding::detail +template // tuple, ...> +struct tile_distribution_detail +{ + static constexpr auto rh_major_minor_to_adaptor_hidden_idss_ = + to_array_of_array(RhMajorMinor2AdaptorHiddenIdss{}); +}; + +} // namespace detail + +// this returns a constexpr tile_distribution +template +CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistributionEncoding_) +{ + using DstrEncode = remove_cvref_t; + + constexpr auto adaptor_impl = + detail::make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_{}); + + constexpr auto ps_ys_to_xs_adaptor_impl = adaptor_impl.template at<0>(); + constexpr auto ys_to_d_adaptor_impl = adaptor_impl.template at<1>(); + constexpr index_t d_length = adaptor_impl.template at<2>(); + constexpr auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>(); + + constexpr auto ps_ys_to_xs_adaptor = + CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ps_ys_to_xs_adaptor_impl); + + constexpr auto ys_to_d_adaptor = CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ys_to_d_adaptor_impl); + + constexpr auto ys_to_d_descriptor = + make_tensor_descriptor_from_adaptor(ys_to_d_adaptor, d_length); + + // + constexpr index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_; + constexpr auto ndims_rhs_minor = DstrEncode::detail::ndims_rhs_minor_; + + constexpr auto rh_major_minor_to_hidden_ids = + TO_TUPLE_OF_SEQUENCE(rh_major_minor_to_hidden_ids_impl, ndim_rh_major, ndims_rhs_minor); + + return tile_distribution< + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + detail::tile_distribution_detail>>{ + ps_ys_to_xs_adaptor, ys_to_d_descriptor}; +} + +// this returns a static tile_distribution +template +CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_) +{ + using DstrEncode = remove_cvref_t; + + constexpr auto adaptor_impl = + detail::make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_{}); + + constexpr auto ps_ys_to_xs_adaptor_impl = adaptor_impl.template at<0>(); + constexpr auto ys_to_d_adaptor_impl = adaptor_impl.template at<1>(); + constexpr index_t d_length = adaptor_impl.template at<2>(); + constexpr auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>(); + + constexpr auto ps_ys_to_xs_adaptor = + CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(ps_ys_to_xs_adaptor_impl); + + constexpr auto ys_to_d_adaptor = + CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(ys_to_d_adaptor_impl); + + constexpr auto ys_to_d_descriptor = + make_tensor_descriptor_from_adaptor(ys_to_d_adaptor, number{}); + + // + constexpr index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_; + constexpr auto ndims_rhs_minor = DstrEncode::detail::ndims_rhs_minor_; + + constexpr auto rh_major_minor_to_hidden_ids = + TO_TUPLE_OF_SEQUENCE(rh_major_minor_to_hidden_ids_impl, ndim_rh_major, ndims_rhs_minor); + + return tile_distribution< + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + detail::tile_distribution_detail>>{ + ps_ys_to_xs_adaptor, ys_to_d_descriptor}; +} + +//*********************************************************************************** + +namespace detail { + +template +CK_TILE_HOST_DEVICE auto get_partition_index(Distribution) +{ + // only support warp-tile and block-tile + static_assert(Distribution::NDimP == 1 or Distribution::NDimP == 2, "wrong!"); + + if constexpr(Distribution::NDimP == 1) + { + return array{get_lane_id()}; + } + else if constexpr(Distribution::NDimP == 2) + { + return array{get_warp_id(), get_lane_id()}; + } +} + +template +struct reverse_slice_sequence_impl; + +template +struct reverse_slice_sequence_impl, + sequence, + sequence, + SliceSize> +{ + using old_scan = + reverse_slice_sequence_impl, sequence, sequence, SliceSize>; + + static constexpr auto slice_size = old_scan::remaining_slice_sizes::Front().value; + static constexpr auto slice_length = + std::conditional_t, number>::value; + + using dim_lengths = + typename sequence_merge, typename old_scan::dim_lengths>::type; + using dim_slices = + typename sequence_merge, typename old_scan::dim_slices>::type; + using remaining_slice_sizes = typename sequence_merge< + std::conditional_t, sequence>, + typename old_scan::remaining_slice_sizes>::type; + + // the first idx that sliced length not equal to original length + static constexpr index_t _flag = + slice_length != x && remaining_slice_sizes{}.Front().value == 1; + static constexpr index_t _split_flag = std::conditional_t, number<0>>::value; + static constexpr index_t _split_idx = + std::conditional_t<_split_flag, number, number<0>>::value; + + static constexpr index_t split_flag = _split_flag || old_scan::split_flag; + static constexpr index_t split_idx = std:: + conditional_t, number<_split_idx>>::value; +}; + +template +struct reverse_slice_sequence_impl, sequence, sequence, SliceSize> +{ + static constexpr auto slice_size = SliceSize; + static constexpr auto slice_length = + std::conditional_t, number>::value; + + using dim_lengths = sequence; + using dim_slices = sequence; + using remaining_slice_sizes = + std::conditional_t, sequence>; + + // the first idx that sliced length not equal to original length + static constexpr index_t _flag = + slice_length != x && remaining_slice_sizes{}.Front().value == 1; + static constexpr index_t split_flag = std::conditional_t, number<0>>::value; + static constexpr index_t split_idx = + std::conditional_t, number<0>>::value; +}; + +// clang-format off +// input a sequence(with optional mask), and the SliceSize : size per slice +// output the sequence each slice, and number of slices +// +// e.g. <2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0 +// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2 +// <4, 2, 4, 1, 6>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 3> : 48 slices , slice_idx: 2 +// <4, 2, 5, 1, 2>, 10 -> lengths:<1, 1, 5, 1, 2> , nums: <4, 2, 1, 1, 1> : 8 slices , slice_idx: 1 +// +// <4, 2, 8>, 64 -> lengths:<4, 2, 8> , nums: <1, 1, 1> : 1 slices , slice_idx: 0 +// <4, 2, 8>, 32 -> lengths:<2, 2, 8> , nums: <2, 1, 1> : 2 slices , slice_idx: 0 +// <4, 2, 8>, 16 -> lengths:<1, 2, 8> , nums: <4, 1, 1> : 4 slices , slice_idx: 0 +// <4, 2, 8>, 8 -> lengths:<1, 1, 8> , nums: <4, 2, 1> : 8 slices , slice_idx: 1 +// <4, 2, 8>, 4 -> lengths:<1, 1, 4> , nums: <4, 2, 2> : 16 slices , slice_idx: 2 +// <4, 2, 8>, 2 -> lengths:<1, 1, 2> , nums: <4, 2, 4> : 32 slices , slice_idx: 2 +// <4, 2, 8>, 1 -> lengths:<1, 1, 1> , nums: <4, 2, 8> : 64 slices , slice_idx: 2 +// +// <4, 2, 1, 4, 2> / 4 -> +// mask:<1, 1, 1, 0, 1>, -> lengths:<1, 2, 1, 4, 2> , nums: <4, 1, 1, 1, 1> : 8 slices , slice_idx: 0 +// +// return tuple, slice_index is at which index will start +// have split slices (right -> left) +// or the first index that sliced length is different from the original length +// clang-format on +template ::type> +constexpr auto reverse_slice_sequence(Seq, + number, + Mask = typename uniform_sequence_gen::type{}) +{ + static_assert(Seq::size() == Mask::size()); + using sliced_type = + reverse_slice_sequence_impl::type, + SliceSize>; + static_assert(sliced_type::remaining_slice_sizes::Front().value == 1, + "can not evenly divide this sequence, please check"); + return make_tuple(typename sliced_type::dim_lengths{}, + typename sliced_type::dim_slices{}, + number{}); +} + +// +// slice tensor from x_dim, result in split in y_dim, not p_dim. +// We don't support slice cross p_dim (aka, slice different threads) +// also, sliced along y_dim need be the first dim of current dim. +// Multiply Y dim before sliced dim does not make sense +// +// e.g +// X0 X1 +// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 32>, (0 means all length) +// Y P P Y P Y P Y +// => <1, 4, 32> - <1, 1, 4, 2, 4> -> OK +// |--> slice along this Y dim, is the first dim of X1, totally 4 slices +// +// X0 X1 +// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 8>, (0 means all length) +// Y P P Y P Y P Y +// => <1, 4, 32> - <1, 1, 1, 2, 4> -> OK +// |--> slice along this Y dim, the P dim is 1 in the left, so is OK +// totally 16 slices +// +// X0 X1 +// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 4>, (0 means all length) +// Y P P Y P Y P Y +// => <1, 4, 32> - <1, 1, 1, 1, 4> -> Fail +// |--> slice along this P dim, will split threads, not supported +// +// X0 X1 +// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 16>, (0 means all length) +// Y P P Y P Y P Y +// => <1, 4, 32> - <1, 1, 2, 2, 4> -> OK +// |--> slice along this Y dim, but this Y sim need to split into 2 +// subdime +// the P dim in the left is 1, means actually not crossing P +// +template +CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x( + Distribution, sequence x_slice_begins, sequence x_slice_ends) +{ + // NOTE: this function need to be called under constexpr context, + // due to https://wg21.link/p2280r0 we have to use non-reference type for distribution + using Encoding = decltype(Distribution::get_static_tile_distribution_encoding()); + + static_assert(sizeof...(XSliceBegins) == sizeof...(XSliceEnds)); + + constexpr auto x_slice_lengths = x_slice_ends - x_slice_begins; + + constexpr auto src_h_prefix_sum = Encoding::detail::get_h_dim_lengths_prefix_sum(); + constexpr auto src_y_info = Encoding::detail::get_sorted_y_info(); + constexpr auto src_y_dims = src_y_info[number<0>{}]; + constexpr auto src_y_maps = src_y_info[number<1>{}]; + constexpr auto src_y_prefix_sum = src_y_info[number<2>{}]; + + constexpr auto sliced_hlen_yidx_ylen = [&]() constexpr + { + auto y_slice_sorted_origins = make_zero_multi_index(); + auto y_slice_lengths = Encoding::detail::ys_lengths_; + + // This lambda will modify some value outside, so c++ will not treat return value as + // constexpr + // TODO: ugly + auto new_h_lengths = transform_tuples( + [&](auto h_len, auto id) { + constexpr auto sliced_h = + reverse_slice_sequence(h_len, number{}); + + constexpr auto sliced_h_lens = sliced_h[number<0>{}]; + constexpr auto sliced_h_index = sliced_h[number<2>{}]; + + // update y_slice_lengths + constexpr auto uniformed_h_index = sliced_h_index + number{}; + constexpr auto found_y_index = container_find(src_y_dims, uniformed_h_index); + + static_assert(found_y_index >= 0 && found_y_index < src_y_dims.size(), + "not sliced at y dim, please check"); + + static_for<0, sliced_h_index + 1, 1>{}([&](auto i) { + y_slice_lengths(src_y_maps[found_y_index - i]) = + sliced_h_lens[sliced_h_index - i]; + }); + // TODO: add validations not across p dim + + // NOTE: this y_origin is for all dims, not only current dim + // will later use pick to select target dim + constexpr auto y_origin = [&]() { + constexpr auto h_trans = make_merge_transform_v3_division_mod(h_len); + auto h_origin_ = make_zero_multi_index(); + h_trans.calculate_lower_index(h_origin_, sequence{}); + + auto y_origin_ = make_zero_multi_index(); + static_for<0, sliced_h_index + 1, 1>{}([&](auto i) { + y_origin_(found_y_index - i) = h_origin_[sliced_h_index - i]; + }); + return y_origin_; + }(); + + constexpr auto y_picks = typename arithmetic_sequence_gen::type{}; + + set_container_subset( + y_slice_sorted_origins, y_picks, get_container_subset(y_origin, y_picks)); + return sliced_h_lens; + }, + typename Encoding::HsLengthss{}, + typename arithmetic_sequence_gen<0, Encoding::HsLengthss::size(), 1>::type{}); + + auto y_slice_origins = container_reorder_given_old2new(y_slice_sorted_origins, src_y_maps); + + return make_tuple(new_h_lengths, y_slice_origins, y_slice_lengths); + } + (); + + constexpr auto sliced_h_lengths = sliced_hlen_yidx_ylen[number<0>{}]; + constexpr auto sliced_y_origins_array = sliced_hlen_yidx_ylen[number<1>{}]; + constexpr auto sliced_y_origins_size = sliced_y_origins_array.size(); + constexpr auto sliced_y_lengths_array = sliced_hlen_yidx_ylen[number<2>{}]; + constexpr auto sliced_y_lengths_size = sliced_y_lengths_array.size(); + + constexpr auto sliced_y_origins = TO_SEQUENCE(sliced_y_origins_array, sliced_y_origins_size); + constexpr auto sliced_y_lengths = TO_SEQUENCE(sliced_y_lengths_array, sliced_y_lengths_size); + + return make_tuple( + make_static_tile_distribution( + tile_distribution_encoding{}), + sliced_y_origins, + sliced_y_lengths); +} + +} // namespace detail +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tile_distribution_encoding.hpp b/include/ck_tile/core/tensor/tile_distribution_encoding.hpp new file mode 100644 index 0000000000..f2f3707e6c --- /dev/null +++ b/include/ck_tile/core/tensor/tile_distribution_encoding.hpp @@ -0,0 +1,761 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/tensor/tensor_adaptor.hpp" +#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/container/multi_index.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template + typename HsLengthss_, // tuple, ...> + typename Ps2RHssMajor_, // tuple, ...> + typename Ps2RHssMinor_, // tuple, ...> + typename Ys2RHsMajor_, // sequence<...> + typename Ys2RHsMinor_> // sequence<...> +struct tile_distribution_encoding +{ + using RsLengths = remove_cvref_t; + using HsLengthss = remove_cvref_t; + using Ps2RHssMajor = remove_cvref_t; + using Ps2RHssMinor = remove_cvref_t; + using Ys2RHsMajor = remove_cvref_t; + using Ys2RHsMinor = remove_cvref_t; + + static_assert(Ps2RHssMajor::size() == Ps2RHssMinor::size(), "wrong!"); + static_assert(Ys2RHsMajor::size() == Ys2RHsMinor::size(), "wrong!"); + + static constexpr index_t NDimX = HsLengthss::size(); + static constexpr index_t NDimP = Ps2RHssMajor::size(); + static constexpr index_t NDimY = Ys2RHsMajor::size(); + static constexpr index_t NDimR = RsLengths::size(); + + // FIXME: move into detail + static constexpr auto rs_lengths_ = RsLengths{}; + static constexpr auto hs_lengthss_ = HsLengthss{}; + static constexpr auto ps_to_rhss_major_ = Ps2RHssMajor{}; + static constexpr auto ps_to_rhss_minor_ = Ps2RHssMinor{}; + static constexpr auto ys_to_rhs_major_ = Ys2RHsMajor{}; + static constexpr auto ys_to_rhs_minor_ = Ys2RHsMinor{}; + + // redundant but useful info + // TODO: really bad code, should be over-hauled + struct detail + { + // ndim_rh_major_, ndim_span_mainor_ + static constexpr index_t ndim_rh_major_ = NDimX + 1; + static constexpr index_t ndim_span_major_ = NDimX; + + // ndims_rhs_minor_[ndim_rh_major_] + static constexpr auto ndims_rhs_minor_ = generate_array( + [](auto i) { + if constexpr(i.value == 0) + { + return rs_lengths_.size(); + } + else + { + return hs_lengthss_[i - number<1>{}].size(); + } + }, + number{}); + + // max_ndim_rh_minor_ + static constexpr index_t max_ndim_rh_minor_ = + container_reduce(ndims_rhs_minor_, math::maximize{}, 0); + + // rhs_lengthss_[ndim_rh_major_][max_ndim_rh_minor_] + static constexpr auto rhs_lengthss_ = + to_array_of_array(container_concat(make_tuple(rs_lengths_), hs_lengthss_)); + + // ys_lengths_ + static constexpr auto ys_lengths_ = [] { + array ys_lengths_tmp{-1}; + + for(index_t i = 0; i < NDimY; i++) + { + index_t rh_major = ys_to_rhs_major_[i]; + index_t rh_minor = ys_to_rhs_minor_[i]; + + ys_lengths_tmp(i) = rhs_lengthss_[rh_major][rh_minor]; + } + + return ys_lengths_tmp; + }(); + + // rhs_major_minor_to_ys_[ndim_rh_majpr_][max_ndim_rh_minor_] + static constexpr auto rhs_major_minor_to_ys_ = [] { + array, NDimX + 1> rhs_major_minor_to_ys_tmp{{-1}}; + + static_for<0, NDimY, 1>{}([&](auto i) { + constexpr index_t rh_major = ys_to_rhs_major_[i]; + constexpr index_t rh_minor = ys_to_rhs_minor_[i]; + + rhs_major_minor_to_ys_tmp(rh_major)(rh_minor) = i; + }); + + return rhs_major_minor_to_ys_tmp; + }(); + + // ndims_span_minor_[NDimY] + static constexpr auto ndims_span_minor_ = [] { + array ndims_span_minor{0}; + + for(index_t i = 0; i < NDimY; i++) + { + const index_t span_major = ys_to_rhs_major_[i] - 1; + + ndims_span_minor(span_major)++; + } + + return ndims_span_minor; + }(); + + // max_ndim_span_minor_ + static constexpr index_t max_ndim_span_minor_ = + container_reduce(ndims_span_minor_, math::maximize{}, 0); + + // rhs_major_minor_to_span_minor_ [ndim_rh_major_][max_ndim_rh_minor_] + static constexpr auto rhs_major_minor_to_span_minor_ = [] { + array, ndim_rh_major_> rhs_major_minor_to_span_minor{ + {-1}}; + + static_for<0, ndim_rh_major_, 1>{}([&](auto rh_major) { + constexpr index_t ndim_rh_minor = ndims_rhs_minor_[rh_major]; + + index_t cnt_ndim_span_minor = 0; + + static_for<0, ndim_rh_minor, 1>{}([&](auto rh_minor) { + constexpr index_t idim_y = rhs_major_minor_to_ys_[rh_major][rh_minor]; + + if(idim_y >= 0) + { + rhs_major_minor_to_span_minor(rh_major)(rh_minor) = cnt_ndim_span_minor; + + cnt_ndim_span_minor++; + } + }); + }); + + return rhs_major_minor_to_span_minor; + }(); + + // ys_to_span_major_[NDimY] + static constexpr auto ys_to_span_major_ = + generate_array([](auto i) { return ys_to_rhs_major_[i] - 1; }, number{}); + + // ys_to_span_minor_[NDimY] + static constexpr auto ys_to_span_minor_ = generate_array( + [](auto i) { + return rhs_major_minor_to_span_minor_[ys_to_rhs_major_[i]][ys_to_rhs_minor_[i]]; + }, + number{}); + + // distributed_spans_lengthss_[ndim_span_major_][max_ndim_span_minor_] + static constexpr auto distributed_spans_lengthss_ = [] { + array, ndim_span_major_> + distributed_spans_lengthss{{-1}}; + + static_for<0, NDimY, 1>{}([&](auto i) { + const index_t rh_major = ys_to_rhs_major_[i]; + const index_t rh_minor = ys_to_rhs_minor_[i]; + + const index_t h_length = hs_lengthss_[number{}][rh_minor]; + + const index_t span_major = rh_major - 1; + const index_t span_minor = rhs_major_minor_to_span_minor_[rh_major][rh_minor]; + + distributed_spans_lengthss(span_major)(span_minor) = h_length; + }); + + return distributed_spans_lengthss; + }(); + + // ndims_distributed_spans_minor_[ndim_span_major_] + static constexpr auto ndims_distributed_spans_minor_ = [] { + array ndims_distributed_spans_minor{0}; + + static_for<0, NDimY, 1>{}([&](auto i) { + const index_t span_major = ys_to_rhs_major_[i] - 1; + + ndims_distributed_spans_minor(span_major)++; + }); + + return ndims_distributed_spans_minor; + }(); + + // does_p_own_r_[NDimP][NDimR] + static constexpr auto does_p_own_r_ = [] { + if constexpr(NDimR > 0) + { + array, NDimP> does_p_own_r{{false}}; + + static_for<0, NDimP, 1>{}([&](auto idim_p) { + constexpr index_t ndim_low = ps_to_rhss_major_[idim_p].size(); + + static_for<0, ndim_low, 1>{}([&](auto idim_low) { + constexpr index_t rh_major = ps_to_rhss_major_[idim_p][idim_low]; + constexpr index_t rh_minor = ps_to_rhss_minor_[idim_p][idim_low]; + + if constexpr(rh_major == 0) + { + does_p_own_r(idim_p)(rh_minor) = true; + } + }); + }); + + return does_p_own_r; + } + else + { + return array, NDimP>{}; + } + }(); + + // ps_over_rs_derivative_[NDimP][NDimR] + static constexpr auto ps_over_rs_derivative_ = [] { + if constexpr(NDimR > 0) + { + array, NDimP> ps_over_rs_derivative{{0}}; + + static_for<0, NDimP, 1>{}([&](auto idim_p) { + constexpr index_t ndim_low = ps_to_rhss_major_[idim_p].size(); + + index_t p_over_rh_derivative = 1; + + static_for{}([&](auto idim_low) { + constexpr index_t rh_major = ps_to_rhss_major_[idim_p][idim_low]; + constexpr index_t rh_minor = ps_to_rhss_minor_[idim_p][idim_low]; + + constexpr index_t rh_length = rhs_lengthss_[rh_major][rh_minor]; + + if constexpr(rh_major == 0) + { + ps_over_rs_derivative(idim_p)(rh_minor) = p_over_rh_derivative; + } + + p_over_rh_derivative *= rh_length; + }); + }); + + return ps_over_rs_derivative; + } + else + { + return array, NDimP>{}; + } + }(); + + // e.g. tuple, seq<4, 1, 4, 2, 4>> --> seq<3, 5> --> seq<0, 3, 8> + CK_TILE_HOST_DEVICE static constexpr auto get_h_dim_lengths_prefix_sum() + { + // + // e.g. tuple, seq<4, 1, 4, 2, 4>> --> seq<3, 5> + constexpr auto uniformed_h_dim_lengths = generate_sequence_v2( + [&](auto i) { + constexpr index_t size = HsLengthss{}[i].size(); + return number{}; + }, + number{}); + + // <0, len_d0, len_d0+len_d1, ...> + // e.g. seq<3, 5> --> seq<0, 3, 8> + constexpr auto h_dim_prefix_sum = prefix_sum_sequence(uniformed_h_dim_lengths); + + return h_dim_prefix_sum; + } + + CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_y_to_h() + { + constexpr auto all_ys_2_rhss = transform_sequences( + [](auto major, auto minor) constexpr { + // <0, 0, len_d0, len_d0+len_d1, ...> + constexpr auto x_dim_prefix_sum = merge_sequences( + sequence<0>{} /*for R dims*/, get_h_dim_lengths_prefix_sum()); + return x_dim_prefix_sum.at(major) + minor; + }, + Ys2RHsMajor{}, + Ys2RHsMinor{}); + + return all_ys_2_rhss; + } + + // return tuple + template + CK_TILE_HOST_DEVICE static constexpr auto get_sorted_info(IdxSeq, PrefixSumSeq) + { + using sorted_idx = + sequence_unique_sort, math::equal>; + + constexpr auto sorted_dims = typename sorted_idx::type{}; + constexpr auto sorted_maps = typename sorted_idx::sorted2unsorted_map{}; + + constexpr auto sorted_histogram = + histogram_sorted_sequence(sorted_dims, PrefixSumSeq{}); + constexpr auto sorted_prefix_sum = prefix_sum_sequence(sorted_histogram); + + return make_tuple(sorted_dims, sorted_maps, sorted_prefix_sum); + } + + CK_TILE_HOST_DEVICE static constexpr auto get_sorted_y_info() + { + return get_sorted_info(get_uniformed_idx_y_to_h(), get_h_dim_lengths_prefix_sum()); + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("tile_distribution_encoding::detail{"); + // + printf("ndim_rh_major_: "); + print(ndim_rh_major_); + printf(", "); + // + printf("ndim_span_major_: "); + print(ndim_span_major_); + printf(", "); + // + printf("ndims_rhs_minor_: "); + print(ndims_rhs_minor_); + printf(", "); + // + printf("ndim_rh_major_: "); + print(ndim_rh_major_); + printf(", "); + // + printf("max_ndim_rh_minor_: "); + print(max_ndim_rh_minor_); + printf(", "); + // + printf("rhs_lengthss_: "); + print(rhs_lengthss_); + printf(", "); + // + printf("ys_lengths_: "); + print(ys_lengths_); + printf(", "); + // + printf("rhs_major_minor_to_ys_: "); + print(rhs_major_minor_to_ys_); + printf(", "); + // + printf("ndims_span_minor_: "); + print(ndims_span_minor_); + printf(", "); + // + printf("max_ndim_span_minor_: "); + print(max_ndim_span_minor_); + printf(", "); + // + printf("ys_to_span_major_: "); + print(ys_to_span_major_); + printf(", "); + // + printf("ys_to_span_minor_: "); + print(ys_to_span_minor_); + printf(", "); + // + printf("distributed_spans_lengthss_: "); + print(distributed_spans_lengthss_); + printf(", "); + // + printf("ndims_distributed_spans_minor_: "); + print(ndims_distributed_spans_minor_); + printf(", "); + // + printf("ps_over_rs_derivative_: "); + print(ps_over_rs_derivative_); + // + printf("}"); + } + }; + + CK_TILE_HOST_DEVICE void print() const + { + printf("tile_distribution_encoding{"); + // + printf("NDimX: %d, NDimP: %d, NDimY: %d, ", NDimX, NDimP, NDimY); + // + printf("rs_lengths_: "); + print(rs_lengths_); + printf(", "); + // + printf("hs_lengthss_: "); + print(hs_lengthss_); + printf(", "); + // + printf("ps_to_rhss_major_: "); + print(ps_to_rhss_major_); + printf(", "); + // + printf("ps_to_rhss_minor_: "); + print(ps_to_rhss_minor_); + printf(", "); + // + printf("ys_to_rhs_major_: "); + print(ys_to_rhs_major_); + printf(", "); + // + printf("ys_to_rhs_minor_: "); + print(ys_to_rhs_minor_); + printf(", "); + // + printf("detail: "); + print(detail{}); + // + printf("}"); + } +}; + +namespace detail { + +template +CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr) +{ + static_assert(OuterDstr::NDimX == InnerDstr::NDimX, "wrong!"); + + constexpr index_t NDimHMajor = OuterDstr::NDimX; + + using RsLengths = + sequence_merge_t; + + constexpr auto hs_lengthss = generate_tuple( + [&](auto i) { + return merge_sequences(typename OuterDstr::HsLengthss{}[i], + typename InnerDstr::HsLengthss{}[i]); + }, + number{}); + + // + constexpr auto rhs_major_2_ndim_outer_rhs_minor = [&]() { + array rhs_major_2_ndim_outer_rhs_minor_; + + // R dimension + rhs_major_2_ndim_outer_rhs_minor_(0) = OuterDstr::RsLengths::size(); + + // Hs dimensions + static_for<0, NDimHMajor, 1>{}([&](auto i) { + rhs_major_2_ndim_outer_rhs_minor_(i + 1) = typename OuterDstr::HsLengthss{}[i].size(); + }); + + return rhs_major_2_ndim_outer_rhs_minor_; + }(); + + // Ps2RHssMinor + constexpr auto updated_inner_ps_2_rhss_minor = generate_tuple( + [&](auto p) { + constexpr auto inner_p_2_rhss_major = typename InnerDstr::Ps2RHssMajor{}[p]; + constexpr auto inner_p_2_rhss_minor = typename InnerDstr::Ps2RHssMinor{}[p]; + + constexpr index_t ndim_tmp = inner_p_2_rhss_minor.size(); + + constexpr auto updated_inner_p_2_rhss_minor = [&]() { + array updated_inner_p_2_rhss_minor_; + + for(index_t i = 0; i < ndim_tmp; i++) + { + index_t rh_major = inner_p_2_rhss_major[i]; + + index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major]; + + updated_inner_p_2_rhss_minor_(i) = inner_p_2_rhss_minor[i] + ndim_outer_h_minor; + } + + return updated_inner_p_2_rhss_minor_; + }(); + + return TO_SEQUENCE(updated_inner_p_2_rhss_minor, ndim_tmp); + }, + number{}); + + // Ys2RHsMinor + constexpr auto updated_inner_ys_2_rhs_minor = [&]() { + constexpr auto inner_ys_2_rhs_major = typename InnerDstr::Ys2RHsMajor{}; + constexpr auto inner_ys_2_rhs_minor = typename InnerDstr::Ys2RHsMinor{}; + + constexpr index_t ndim_tmp = inner_ys_2_rhs_minor.size(); + + constexpr auto updated_inner_ys_2_rhs_minor_ = [&]() { + array updated_inner_ys_2_rhs_minor__; + + for(index_t i = 0; i < ndim_tmp; i++) + { + index_t rh_major = inner_ys_2_rhs_major[i]; + + index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major]; + + updated_inner_ys_2_rhs_minor__(i) = inner_ys_2_rhs_minor[i] + ndim_outer_h_minor; + } + + return updated_inner_ys_2_rhs_minor__; + }(); + + return TO_SEQUENCE(updated_inner_ys_2_rhs_minor_, ndim_tmp); + }(); + + // + constexpr auto ps_2_rhss_major = + container_concat(typename OuterDstr::Ps2RHssMajor{}, typename InnerDstr::Ps2RHssMajor{}); + + constexpr auto ps_2_rhss_minor = + container_concat(typename OuterDstr::Ps2RHssMinor{}, updated_inner_ps_2_rhss_minor); + + // + constexpr auto ys_2_rhs_major = + merge_sequences(typename OuterDstr::Ys2RHsMajor{}, typename InnerDstr::Ys2RHsMajor{}); + + constexpr auto ys_2_rhs_minor = + merge_sequences(typename OuterDstr::Ys2RHsMinor{}, updated_inner_ys_2_rhs_minor); + + return tile_distribution_encoding, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto +make_reduce_tile_distribution_encoding_impl(InDstr, sequence reduce_dim_xs_in) +{ + constexpr auto I1 = number<1>{}; + + // FIXME: increase if fail + constexpr index_t max_ndim_r_out = 20; + constexpr index_t max_ndim_y_out = 20; + + // + constexpr index_t ndim_p = InDstr::NDimP; + constexpr index_t ndim_x_in = InDstr::NDimX; + constexpr index_t ndim_y_in = InDstr::NDimY; + constexpr index_t ndim_rh_major_in = InDstr::NDimX + 1; + constexpr index_t ndim_x_out = ndim_x_in - sizeof...(InReduceDimXs); + constexpr index_t max_ndim_rh_minor_in = InDstr::detail::max_ndim_rh_minor_; + + // ndims_ps_low + constexpr auto ndims_ps_low = generate_array( + [&](auto i) { return InDstr::ps_to_rhss_major_[i].size(); }, number{}); + + // is_rh_major_in_for_reduce + array is_rh_major_in_for_reduce{false}; + + for(index_t i = 0; i < reduce_dim_xs_in.size(); i++) + { + index_t rh_major = reduce_dim_xs_in[i] + 1; + + is_rh_major_in_for_reduce(rh_major) = true; + } + + // is_y_in_for_reduce + array is_y_in_for_reduce{false}; + + for(index_t i = 0; i < ndim_y_in; i++) + { + index_t rh_major = InDstr::ys_to_rhs_major_[i]; + + if(is_rh_major_in_for_reduce[rh_major]) + { + is_y_in_for_reduce(i) = true; + } + } + + // is_rh_minor_in_for_y_reduce + array, ndim_rh_major_in> is_rh_minor_in_for_y_reduce{{false}}; + + static_for<0, ndim_y_in, 1>{}([&](auto i) { + index_t rh_major = InDstr::ys_to_rhs_major_[i]; + index_t rh_minor = InDstr::ys_to_rhs_minor_[i]; + + if(is_y_in_for_reduce[i]) + { + is_rh_minor_in_for_y_reduce(rh_major)(rh_minor) = true; + } + }); + + // in2out_rh_major + array in2out_rh_major{-1}; + index_t cnt_ndim_rh_major_out = 0; + + for(index_t i = 0; i < ndim_rh_major_in; i++) + { + if(is_rh_major_in_for_reduce[i]) + { + in2out_rh_major(i) = 0; + } + else + { + in2out_rh_major(i) = cnt_ndim_rh_major_out; + + cnt_ndim_rh_major_out++; + } + } + + // rs_lengths_out, in2out_rh_minor + array rs_lengths_out{-1}; + array, ndim_rh_major_in> in2out_rh_minor{{-1}}; + + // loop over input R dim + for(index_t i = 0; i < InDstr::rs_lengths_.size(); i++) + { + // rs_lengths_out + rs_lengths_out(i) = InDstr::rs_lengths_[i]; + + // in2out_rh_minor + in2out_rh_minor(0)(i) = i; + } + + // loop over input H Dim + index_t cnt_ndim_r_out = InDstr::rs_lengths_.size(); + + static_for<1, ndim_rh_major_in, 1>{}([&](auto rh_major_in) { + constexpr auto h_major_in = rh_major_in - I1; + + constexpr index_t ndim_rh_minor_in = InDstr::hs_lengthss_[h_major_in].size(); + + if(is_rh_major_in_for_reduce[rh_major_in]) + { + for(index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++) + { + if(not is_rh_minor_in_for_y_reduce[rh_major_in][rh_minor_in]) + { + // rs_lengths_out + rs_lengths_out(cnt_ndim_r_out) = InDstr::hs_lengthss_[h_major_in][rh_minor_in]; + + // in2out_rh_minor + in2out_rh_minor(rh_major_in)(rh_minor_in) = cnt_ndim_r_out; + + cnt_ndim_r_out++; + } + } + } + else + { + for(index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++) + { + // in2out_rh_minor + in2out_rh_minor(rh_major_in)(rh_minor_in) = rh_minor_in; + } + } + }); + + // ndim_r_out + const index_t ndim_r_out = cnt_ndim_r_out; + + // ndims_hs_minor_out, hs_lengthss_out + array ndims_hs_minor_out{-1}; + array, ndim_x_out> hs_lengthss_out{{-1}}; + + index_t cnt_ndim_x_out = 0; + + static_for<0, ndim_x_in, 1>{}([&](auto i) { + if(not is_rh_major_in_for_reduce[i + I1]) + { + // ndims_hs_minor_out + ndims_hs_minor_out(cnt_ndim_x_out) = InDstr::hs_lengthss_[i].size(); + + // hs_lengthss_out + static_for<0, InDstr::hs_lengthss_[i].size(), 1>{}( + [&](auto j) { hs_lengthss_out(cnt_ndim_x_out)(j) = InDstr::hs_lengthss_[i][j]; }); + + cnt_ndim_x_out++; + } + }); + + // ps_to_rhss_major_out, ps_to_rhss_minor_out + array, ndim_p> ps_to_rhss_major_out{{-1}}; + array, ndim_p> ps_to_rhss_minor_out{{-1}}; + + static_for<0, ndim_p, 1>{}([&](auto idim_p) { + static_for<0, InDstr::ps_to_rhss_major_[idim_p].size(), 1>{}([&](auto idim_low) { + index_t rh_major_in = InDstr::ps_to_rhss_major_[idim_p][idim_low]; + index_t rh_minor_in = InDstr::ps_to_rhss_minor_[idim_p][idim_low]; + + ps_to_rhss_major_out(idim_p)(idim_low) = in2out_rh_major[rh_major_in]; + ps_to_rhss_minor_out(idim_p)(idim_low) = in2out_rh_minor[rh_major_in][rh_minor_in]; + }); + }); + + // ys_to_rhs_major_out, ys_to_rhs_minor_out + array ys_to_rhs_major_out{-1}; + array ys_to_rhs_minor_out{-1}; + + index_t cnt_ndim_y_out = 0; + + static_for<0, ndim_y_in, 1>{}([&](auto i) { + if(not is_y_in_for_reduce[i]) + { + index_t rh_major_in = InDstr::ys_to_rhs_major_[i]; + index_t rh_minor_in = InDstr::ys_to_rhs_minor_[i]; + + ys_to_rhs_major_out(cnt_ndim_y_out) = in2out_rh_major[rh_major_in]; + ys_to_rhs_minor_out(cnt_ndim_y_out) = in2out_rh_minor[rh_major_in][rh_minor_in]; + + cnt_ndim_y_out++; + } + }); + + // ndim_y_out + const index_t ndim_y_out = cnt_ndim_y_out; + + // + return make_tuple(ndim_x_out, + ndim_p, + ndim_y_out, + ndim_r_out, + ndims_hs_minor_out, + ndims_ps_low, + rs_lengths_out, + hs_lengthss_out, + ps_to_rhss_major_out, + ps_to_rhss_minor_out, + ys_to_rhs_major_out, + ys_to_rhs_minor_out); +} + +template +CK_TILE_HOST_DEVICE constexpr auto +make_reduce_tile_distribution_encoding(InDstr, sequence reduce_dim_xs_in) +{ + constexpr auto impl = make_reduce_tile_distribution_encoding_impl(InDstr{}, reduce_dim_xs_in); + + constexpr index_t ndim_x = impl.template at<0>(); + constexpr index_t ndim_p = impl.template at<1>(); + constexpr index_t ndim_y = impl.template at<2>(); + constexpr index_t ndim_r = impl.template at<3>(); + constexpr auto ndims_hs_minor = impl.template at<4>(); + constexpr auto ndims_ps_low = impl.template at<5>(); + constexpr auto rs_lengths_impl = impl.template at<6>(); + constexpr auto hs_lengthss_impl = impl.template at<7>(); + constexpr auto ps_to_rhss_major_impl = impl.template at<8>(); + constexpr auto ps_to_rhss_minor_impl = impl.template at<9>(); + constexpr auto ys_to_rhs_major_impl = impl.template at<10>(); + constexpr auto ys_to_rhs_minor_impl = impl.template at<11>(); + + constexpr auto rs_lengths = TO_SEQUENCE(rs_lengths_impl, ndim_r); + constexpr auto hs_lengthss = TO_TUPLE_OF_SEQUENCE(hs_lengthss_impl, ndim_x, ndims_hs_minor); + constexpr auto ps_to_rhss_major = + TO_TUPLE_OF_SEQUENCE(ps_to_rhss_major_impl, ndim_p, ndims_ps_low); + constexpr auto ps_to_rhss_minor = + TO_TUPLE_OF_SEQUENCE(ps_to_rhss_minor_impl, ndim_p, ndims_ps_low); + constexpr auto ys_to_rhs_major = TO_SEQUENCE(ys_to_rhs_major_impl, ndim_y); + constexpr auto ys_to_rhs_minor = TO_SEQUENCE(ys_to_rhs_minor_impl, ndim_y); + + return tile_distribution_encoding, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t>{}; +} + +} // namespace detail +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tile_elementwise.hpp b/include/ck_tile/core/tensor/tile_elementwise.hpp new file mode 100644 index 0000000000..e2b1f0c385 --- /dev/null +++ b/include/ck_tile/core/tensor/tile_elementwise.hpp @@ -0,0 +1,191 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/container/sequence.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/tensor/tensor_adaptor.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +// TODO: support tensors with different distribution +template , NullTensor>>...>>> +CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc& inout_element_func, + InOutDstrTensors&... inout_dstr_tensors) +{ + // TODO: make sure all distributed tensors have same lengths and distribution + // static_assert(xxx); + + constexpr index_t thread_buffer_size = + type_pack_element<0, InOutDstrTensors...>::get_thread_buffer_size(); + + static_for<0, thread_buffer_size, 1>{}( + [&](auto i) { inout_element_func(inout_dstr_tensors.get_thread_buffer().at(i)...); }); +} + +template >...>>> +CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func, + const InDstrTensors&... in_dstr_tensors) +{ + using OutDataType = decltype(in_element_func(typename InDstrTensors::DataType{}...)); + + // TODO: make sure all distributed tensors have same lengths and distribution + // static_assert(xxx); + constexpr auto in_tile_dstr = type_pack_element<0, InDstrTensors...>::get_tile_distribution(); + + constexpr index_t thread_buffer_size = + type_pack_element<0, InDstrTensors...>::get_thread_buffer_size(); + + auto out_dstr_tensor = make_static_distributed_tensor(in_tile_dstr); + + static_for<0, thread_buffer_size, 1>{}([&](auto i) { + out_dstr_tensor.get_thread_buffer()(i) = + in_element_func(in_dstr_tensors.get_thread_buffer()[i]...); + }); + + return out_dstr_tensor; +} + +template +CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, const T& value) +{ + tile_elementwise_inout( + [&value](auto& x) { + x = type_convert>(value); + }, + dstr_tensor); +} + +template +CK_TILE_DEVICE void set_tile(NullTensor&, const T&) +{ +} + +// TODO: prefer to use per-dword value to set a tensor, in case compiler not doing well with +// sub-dword tensor... +template +CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number) +{ + constexpr index_t tensor_bytes = + DstrTensors::get_thread_buffer_size() * sizeof(typename DstrTensors::DataType); + if constexpr(v == 0 && tensor_bytes % 4 == 0) + { + using dvec_t = static_buffer_c; + auto& tensor = reinterpret_cast(dstr_tensor.get_thread_buffer()); + for(auto i = 0; i < tensor.size(); i++) + tensor.get(i) = v; + } + else + { + tile_elementwise_inout( + [](auto& x) { x = type_convert(v); }, + dstr_tensor); + } +} + +template +CK_TILE_DEVICE void set_tile(NullTensor&, number) +{ +} + +template +CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor) +{ + set_tile(dstr_tensor, 0); +} + +// TODO: this is ugly +template +CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InDstrTensors& in_dstr_tensors) +{ +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + // This API is designed to use the _pk_ serious of function + constexpr auto in_tile_dstr = InDstrTensors::get_tile_distribution(); + + constexpr index_t thread_buffer_size = InDstrTensors::get_thread_buffer_size(); + static_assert(thread_buffer_size % 4 == 0); + constexpr index_t thread_buffer_size_pk = thread_buffer_size / 4; + + auto out_dstr_tensor = make_static_distributed_tensor(in_tile_dstr); +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wuninitialized" + // __builtin_amdgcn_cvt_pk_fp8_f32() this builtin require the old value, and + // will generate a v_mov_b32 vxxx [old] before cvt, which result in unwanted ISA + // so we prepare an uninitialized variable purposely, and turn off the warning + int dummy_old; + static_for<0, thread_buffer_size_pk, 1>{}([&](auto i) { + uint32_t x = __builtin_amdgcn_cvt_pk_fp8_f32( + in_dstr_tensors.get_thread_buffer()[number<4 * i + 0>{}], + in_dstr_tensors.get_thread_buffer()[number<4 * i + 1>{}], + dummy_old, + false); // false -> WORD0 + + uint32_t y = __builtin_amdgcn_cvt_pk_fp8_f32( + in_dstr_tensors.get_thread_buffer()[number<4 * i + 2>{}], + in_dstr_tensors.get_thread_buffer()[number<4 * i + 3>{}], + dummy_old, + false); // false -> WORD0 + + constexpr int32_t m0 = 0x05040100; + using vec_t = typename vector_type::type; + + vec_t d = bit_cast(__builtin_amdgcn_perm(y, x, m0)); + out_dstr_tensor.get_thread_buffer().template set_as(number{}, d); + }); +#pragma clang diagnostic pop + + return out_dstr_tensor; +#else + // fallback + return tile_elementwise_in(type_convert, + in_dstr_tensors); +#endif +} + +template +CK_TILE_DEVICE auto cast_tile(const SrcDstrTensors& src_tensor) +{ + if constexpr((ck_tile::is_same_v || + ck_tile::is_same_v)&&ck_tile:: + is_same_v && + (SrcDstrTensors::get_thread_buffer_size() % 4 == 0)) + { + return cast_tile_pk_fp8x4(src_tensor); + } + else + return tile_elementwise_in(type_convert, + src_tensor); +} + +// no-op function for NullTensor arguments +template , NullTensor>...>>> +CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc&, MaybeNullTensor&&...) +{ +} + +// no-op function for NullTensor arguments +template , NullTensor>...>>> +CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc&, MaybeNullTensor&&...) +{ + return NullTensor{}; +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp new file mode 100644 index 0000000000..80e483e6ad --- /dev/null +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -0,0 +1,735 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/container/sequence.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/tensor/tensor_adaptor.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template +struct tile_window_with_static_distribution +{ + using BottomTensorView = remove_reference_t; + using WindowLengths = remove_cvref_t; + using TileDstr = remove_cvref_t; + + using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor; + using BottomTensorDesc = typename BottomTensorView::TensorDesc; + + using DataType = remove_cvref_t; + + static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension(); + static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension(); + + static constexpr index_t NDimP = TileDstr::GetNumOfDimensionP(); + static constexpr index_t NDimY = TileDstr::GetNumOfDimensionY(); + + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + + // TODO: check WindowLengths and StaticTileDistribution are consistent + + static_assert(ck_tile::is_known_at_compile_time::value, + "wrong! lengths should be static"); + static_assert(TileDstr::is_static(), "wrong!"); + + static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(), + "wrong! inconsistent # of diemsnions"); + + using AdaptorTopIndex = array; + using BottomTensorIndex = array; + + using WindowAdaptorCoord = + decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{})); + + using BottomTensorCoord = + decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{})); + + struct load_store_traits + { + private: + static constexpr auto get_vector_dim_y_scalar_per_vector() + { + const auto [ys_vector_lengths, ys_vector_strides] = + tile_window_with_static_distribution:: + get_window_adaptor_ys_safe_vector_length_strides(); + + index_t VectorDimY_ = 0; + index_t ScalarPerVector_ = 1; + + for(index_t i = 0; i < NDimY; ++i) + { + if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_) + { + ScalarPerVector_ = ys_vector_lengths[i]; + VectorDimY_ = i; + } + } + + return make_tuple(VectorDimY_, ScalarPerVector_); + } + + public: + static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>(); + static constexpr index_t ScalarPerVector = + get_vector_dim_y_scalar_per_vector().template at<1>(); + + using vector_type_t = vector_type_maker_t; + using vector_t = typename vector_type_t::type; + + private: + static constexpr auto scalars_per_access_ = [] { + constexpr auto scalars_per_access_arr = generate_array( + [&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number{}); + + /// TODO: add non-automatic storage argument support to macro TO_SEQUENCE() + constexpr auto NDimY_ = NDimY; + + return TO_SEQUENCE(scalars_per_access_arr, NDimY_); + }(); + + static constexpr auto get_space_filling_curve() + { + constexpr auto tile_dstr = TileDstr{}; + + constexpr auto thread_tensor_lengths_ys = + to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths()); + + // FIXME: need logic to judge dim access order + using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type; + + return space_filling_curve{}; + } + + public: + using SFC_Ys = decltype(get_space_filling_curve()); + + static constexpr index_t NumAccess = SFC_Ys::get_num_of_access(); + + static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0"); + static_assert(NumAccess % NumCoord == 0, "wrong! # of access is not divisible by NumCoord"); + }; + + static constexpr index_t NumAccessPerCoord = load_store_traits::NumAccess / NumCoord; + + CK_TILE_DEVICE constexpr tile_window_with_static_distribution() = default; + + CK_TILE_DEVICE constexpr tile_window_with_static_distribution( + const BottomTensorView& bottom_tensor_view, + const WindowLengths& window_lengths, + const BottomTensorIndex& window_origin, + const TileDstr& tile_distribution) + : bottom_tensor_view_{bottom_tensor_view}, + window_lengths_{window_lengths}, + window_origin_{window_origin}, + tile_dstr_{tile_distribution}, + pre_computed_coords_{} + { +#if 0 // debug + // TODO: this use more register for FA, but less register for GEMM + // need investigation + // only support warp-tile and block-tile + static_assert(NDimP == 1 or NDimP == 2, "wrong!"); + + WindowAdaptorCoord window_adaptor_thread_coord_tmp; + + if constexpr(NDimP == 1) + { + window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( + tile_distribution.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0}); + } + else if constexpr(NDimP == 2) + { + window_adaptor_thread_coord_tmp = + make_tensor_adaptor_coordinate(tile_distribution.get_ps_ys_to_xs_adaptor(), + AdaptorTopIndex{get_warp_id(), get_lane_id(), 0}); + } +#else + // TODO: this use less register for FA, but more register for GEMM + // need investigation + const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( + tile_distribution.get_ps_ys_to_xs_adaptor(), + container_concat(detail::get_partition_index(tile_distribution), + array{0})); +#endif + + BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = + window_origin + window_adaptor_thread_coord_tmp.get_bottom_index(); + + const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate( + bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); + + // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up + // future load/store() calls (might allocate more registers) + using Traits = load_store_traits; + using SFC_Ys = typename Traits::SFC_Ys; + + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp; + auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp; + + constexpr auto idx_diff_ys = + SFC_Ys::get_step_between(number<0>{}, number{}); + + constexpr auto idx_diff_ps_ys = container_concat(array{0}, idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + + pre_computed_coords_(iCoord) = + make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord); + }); + } + + CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; } + + CK_TILE_DEVICE static constexpr bool has_static_tile_distribution() + { + return TileDstr::is_static(); + } + + CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; } + + CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; } + + CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; } + + CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; } + + // move thread's window adaptor coordinate and bottom tensor coordinate + // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset] + CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate( + WindowAdaptorCoord& window_adaptor_thread_coord, + BottomTensorCoord& bottom_tensor_thread_coord, + const AdaptorTopIndex& idx_diff_adaptor_top) const + { + array idx_diff_adaptor_bottom; + + move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(), + window_adaptor_thread_coord, + idx_diff_adaptor_top, + idx_diff_adaptor_bottom); + + move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(), + bottom_tensor_thread_coord, + idx_diff_adaptor_bottom); + } + + // return vector dimension among [y0, y1, ...] + CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides() + { + // bottom tensor top dimension vector lengths and strides + const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] = + BottomTensorDesc::get_top_dimension_safe_vector_length_strides(); + + // window vector lengths/strides + const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths; + const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides; + + // window adaptor [p0, p1, ..., y0, y1, ...] + array window_adaptor_vector_lengths{ + -1}; + array window_adaptor_vector_strides{ + -1}; + + constexpr auto window_adaptor_bottom_dims = + WindowAdaptor::get_bottom_dimension_hidden_ids(); + + set_container_subset(window_adaptor_vector_lengths, + window_adaptor_bottom_dims, + window_adaptor_bottom_dim_vector_lengths); + set_container_subset(window_adaptor_vector_strides, + window_adaptor_bottom_dims, + window_adaptor_bottom_dim_vector_strides); + + const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] = + WindowAdaptor{}.get_top_dimension_safe_vector_length_strides( + window_adaptor_vector_lengths, window_adaptor_vector_strides); + + // [y0, y1, ...] + constexpr auto y_dims = typename arithmetic_sequence_gen::type{}; + + return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims), + get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims)); + } + + CK_TILE_DEVICE constexpr auto get_num_access() const { return load_store_traits::NumAccess; } + + template + CK_TILE_DEVICE auto load(bool_constant = {}) const + { + using Traits = load_store_traits; + + using vector_type_t = typename Traits::vector_type_t; + using vector_t = typename vector_type_t::type; + using SFC_Ys = typename Traits::SFC_Ys; + + constexpr auto tile_dstr = TileDstr{}; + + auto dst_tensor = make_static_distributed_tensor(tile_dstr); + + // loop over thread tensor space [y0, y1, ...] + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + /// TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + + // data index [y0, y1, ...] + constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + + // read from bottom tensor + const vector_t vec_value = + get_bottom_tensor_view().template get_vectorized_elements( + bottom_tensor_thread_coord, bool_constant{}); + + const vector_type_t vec{vec_value}; + + // write into distributed tensor + static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { + constexpr auto idx_ys = generate_array( + [&](auto jj) { + return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) + : idx_ys_start[jj]; + }, + number{}); + + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + + dst_tensor.get_thread_buffer().template at() = + vec.template AsType()[j]; + }); + + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto idx_diff_ps_ys = + container_concat(array{0}, idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + } + }); + }); + + return dst_tensor; + } + + template + CK_TILE_DEVICE void load_raw(DstTile& dst_tensor, + bool_constant = {}) const + { + using Traits = load_store_traits; + + using vector_type_t = typename Traits::vector_type_t; + using vector_t = typename vector_type_t::type; + using SFC_Ys = typename Traits::SFC_Ys; + static constexpr index_t YElementSize = + TileDstr{}.get_ys_to_d_descriptor().get_element_space_size(); + static_assert(YElementSize % Traits::ScalarPerVector == 0); + using vectorized_tbuf = StaticBuffer; + + constexpr auto tile_dstr = TileDstr{}; + + auto& dst_vec_tbuf = reinterpret_cast(dst_tensor.get_thread_buffer()); + + // loop over thread tensor space [y0, y1, ...] + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + /// TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + + // data index [y0, y1, ...] + constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start); + static_assert(d % Traits::ScalarPerVector == 0); + + get_bottom_tensor_view().template get_vectorized_elements_raw( + dst_vec_tbuf.template at(), + bottom_tensor_thread_coord, + bool_constant{}); + + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto idx_diff_ps_ys = + container_concat(array{0}, idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + } + }); + }); + } + + // TODO: currently async load only implemented in inline asm + template + CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile, + bool_constant = {}) const + { + using LdsTileWindow = remove_cvref_t; + // using LdsTensorView = typename LdsTileWindow::BottomTensorView; + using LdsDataType = typename LdsTileWindow::DataType; + // using LdsDescriptor = typename LdsTileWindow::BottomTensorDesc; + + // issues * warps * lanes + static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded + + const index_t size_per_buf = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<0>{}, number<0>{}, number<0>{})) * + sizeof(LdsDataType); + + const index_t size_per_wave = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<0>{}, number<1>{}, number<0>{})) * + sizeof(LdsDataType) - + size_per_buf; + + const index_t size_per_issue = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<1>{}, number<0>{}, number<0>{})) * + sizeof(LdsDataType) - + size_per_buf; + + const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); + m0_set_with_memory(m0_init_value); // This should be wave independent + + using Traits = load_store_traits; + + using vector_type_t = typename Traits::vector_type_t; + using vector_t = typename vector_type_t::type; + using SFC_Ys = typename Traits::SFC_Ys; + + LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_; + + // loop over thread tensor space [y0, y1, ...] + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + // TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + + // read from bottom tensor + get_bottom_tensor_view().template async_get_vectorized_elements( + smem, bottom_tensor_thread_coord); + + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto idx_diff_ps_ys = + container_concat(array{0}, idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + + m0_inc_with_memory(size_per_issue); + } + }); + }); + } + + template + CK_TILE_DEVICE void store(const static_distributed_tensor& dstr_tensor, + bool_constant = {}) const + { + using Traits = load_store_traits; + + using vector_type_t = typename Traits::vector_type_t; + using vector_t = typename vector_type_t::type; + using SFC_Ys = typename Traits::SFC_Ys; + + constexpr auto tile_dstr = TileDstr{}; + + // loop over thread tensor space [y0, y1, ...] + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + /// TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + + // data index [y0, y1, ...] + constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + + // read from distributed tensor + vector_type_t vec; + + static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { + constexpr auto idx_ys = generate_array( + [&](auto jj) { + return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) + : idx_ys_start[jj]; + }, + number{}); + + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + + vec.template AsType()(j) = + dstr_tensor.get_thread_buffer().template at(); + }); + + const vector_t vec_value = vec.template AsType().template at<0>(); + + // write into bottom tensor + get_bottom_tensor_view().template set_vectorized_elements( + bottom_tensor_thread_coord, vec_value, bool_constant{}); + + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto idx_diff_ps_ys = + container_concat(array{0}, idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + } + }); + }); + } + + CK_TILE_DEVICE void + store_raw(const static_distributed_tensor& dstr_tensor) const + { + using Traits = load_store_traits; + + using vector_type_t = typename Traits::vector_type_t; + using vector_t = typename vector_type_t::type; + using SFC_Ys = typename Traits::SFC_Ys; + + constexpr auto tile_dstr = TileDstr{}; + static constexpr bool oob_conditional_check = true; + + // loop over thread tensor space [y0, y1, ...] + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + /// TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + + // data index [y0, y1, ...] + constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + + // read from distributed tensor + vector_type_t vec; + + static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { + constexpr auto idx_ys = generate_array( + [&](auto jj) { + return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) + : idx_ys_start[jj]; + }, + number{}); + + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + + vec.template AsType()(j) = + dstr_tensor.get_thread_buffer().template at(); + }); + + const vector_t vec_value = vec.template AsType().template at<0>(); + + // write into bottom tensor + get_bottom_tensor_view() + .template set_vectorized_elements_raw( + bottom_tensor_thread_coord, vec_value); + + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto idx_diff_ps_ys = + container_concat(array{0}, idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + } + }); + }); + } + + // move thread's botom tensor coordiante + // [x0', x1', ... ] ==> [offset] + // also move window-origin + CK_TILE_DEVICE void move(const BottomTensorIndex& step) + { + window_origin_ += step; + + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(), + pre_computed_coords_(iCoord)(I1), + step); + }); + } + + // this is the bottom tensor view + // [x0', x1', ...] ==> [offset] + BottomTensorView bottom_tensor_view_; + + // + WindowLengths window_lengths_; + + // origin ([x0', x1', ...]) of window on bottom tensor + BottomTensorIndex window_origin_; + + // Tile tensor distribution, which contains: + // 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] + // 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d] + TileDstr tile_dstr_; + + // this contains: + // per-thread coordinate for window adaptor + // per-thread coordinate for bottom tensor + array, NumCoord> pre_computed_coords_; +}; + +// TODO: use strategy +template +CK_TILE_DEVICE constexpr auto +make_tile_window(const TensorView_& tensor_view, + const WindowLengths_& window_lengths, + const multi_index& origin, + const StaticTileDistribution_& tile_distribution, + number = {}) +{ + return tile_window_with_static_distribution, + remove_cvref_t, + remove_cvref_t, + NumCoord>{ + tensor_view, window_lengths, origin, tile_distribution}; +} + +template +CK_TILE_DEVICE void move_tile_window( + tile_window_with_static_distribution& window, + const typename tile_window_with_static_distribution::BottomTensorIndex& step) +{ + window.move(step); +} + +template +struct tile_window_with_static_lengths +{ + using BottomTensorView = remove_reference_t; + using WindowLengths = remove_cvref_t; + using BottomTensorDesc = typename BottomTensorView::TensorDesc; + using DataType = typename BottomTensorView::DataType; + + static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension(); + + static_assert(ck_tile::is_known_at_compile_time::value, + "wrong! lengths should be static"); + + using BottomTensorIndex = array; + + CK_TILE_DEVICE constexpr tile_window_with_static_lengths() = default; + + CK_TILE_DEVICE constexpr tile_window_with_static_lengths( + const BottomTensorView& bottom_tensor_view, + const WindowLengths& window_lengths, + const BottomTensorIndex& window_origin) + : bottom_tensor_view_{bottom_tensor_view}, + window_lengths_{window_lengths}, + window_origin_{window_origin} + { + } + + CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; } + + CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; } + + CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; } + + CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; } + + // move window-origin + CK_TILE_DEVICE void move(const BottomTensorIndex& step) { window_origin_ += step; } + + // this is the bottom tensor view + // [x0', x1', ...] ==> [offset] + BottomTensorView bottom_tensor_view_; + + // + WindowLengths window_lengths_; + + // origin ([x0', x1', ...]) of window on bottom tensor + BottomTensorIndex window_origin_; +}; + +template +CK_TILE_DEVICE constexpr auto +make_tile_window(const TensorView_& tensor_view, + const WindowLengths_& window_lengths, + const multi_index& origin) +{ + static_assert(ck_tile::is_known_at_compile_time::value, + "wrong! lengths should be static"); + + return tile_window_with_static_lengths, + remove_cvref_t>{ + tensor_view, window_lengths, origin}; +} + +template +CK_TILE_DEVICE void move_tile_window( + tile_window_with_static_lengths& window, + const typename tile_window_with_static_lengths::BottomTensorIndex& + step) +{ + window.move(step); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/utility/bit_cast.hpp b/include/ck_tile/core/utility/bit_cast.hpp new file mode 100644 index 0000000000..2cb91b7d47 --- /dev/null +++ b/include/ck_tile/core/utility/bit_cast.hpp @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" + +namespace ck_tile { + +template +CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X& x) +{ + static_assert(__has_builtin(__builtin_bit_cast), ""); + static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type"); + + return __builtin_bit_cast(Y, x); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/utility/functional.hpp b/include/ck_tile/core/utility/functional.hpp new file mode 100644 index 0000000000..7bbc61cef1 --- /dev/null +++ b/include/ck_tile/core/utility/functional.hpp @@ -0,0 +1,194 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/sequence.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include +#include + +namespace ck_tile { + +namespace detail { + +struct swallow +{ + template + CK_TILE_HOST_DEVICE constexpr swallow(Ts&&...) + { + } +}; + +template +struct static_for_impl; + +template +struct static_for_impl> +{ + template + CK_TILE_HOST_DEVICE constexpr void operator()(F f) const + { + swallow{(f(number{}), 0)...}; + } +}; + +} // namespace detail + +// F signature: F(number) +template +struct static_for +{ + CK_TILE_HOST_DEVICE constexpr static_for() + { + static_assert(Increment != 0 && (NEnd - NBegin) % Increment == 0, + "Wrong! should satisfy (NEnd - NBegin) % Increment == 0"); + static_assert((Increment > 0 && NBegin <= NEnd) || (Increment < 0 && NBegin >= NEnd), + "wrongs! should (Increment > 0 && NBegin <= NEnd) || (Increment < 0 && " + "NBegin >= NEnd)"); + } + + template + CK_TILE_HOST_DEVICE constexpr void operator()(F f) const + { + detail::static_for_impl::type>{}( + f); + } +}; + +struct identity +{ + template + CK_TILE_HOST_DEVICE constexpr T&& operator()(T&& arg) const noexcept + { + return std::forward(arg); + } +}; + +namespace detail { + +// RemainLengths: sequence<...> +// Orders: sequence<...> +template +struct static_ford_impl +{ + CK_TILE_HOST_DEVICE constexpr static_ford_impl() + { + static_assert(RemainLengths::size() > 0, "wrong! should not get here"); + } + + // F signature: F(sequence<...>) + // CurrentOrderedId: sequence<...> + template + CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentOrderedId) const + { + static_for<0, RemainLengths::front(), 1>{}([=](auto I) { + static_ford_impl{}( + f, CurrentOrderedId::push_back(I)); + }); + } +}; + +template +struct static_ford_impl, Orders> +{ + // F signature: F(sequence<...>) + // OrderedId: sequence<...> + template + CK_TILE_HOST_DEVICE constexpr void operator()(F f, OrderedId) const + { + // retrive unordered Id + f(OrderedId::reorder_old_to_new(Orders{})); + } +}; + +} // namespace detail + +// Lengths is sequence<...>, it is the length of each dimension for +// N-dimensional loop +// Orders is sequence<...>, it is the order of dimension in which static_ford +// will loop over each +// dimension +template ::type> +struct static_ford +{ + CK_TILE_HOST_DEVICE constexpr static_ford() + { + static_assert(Lengths::size() > 0, "wrong! Lengths is empty"); + static_assert(Lengths::size() == Orders::size(), "wrong! inconsistent size"); + } + + // F signature: F(sequence<...> multi_id) + // multi_id is the unordered multi-index + template + CK_TILE_HOST_DEVICE constexpr void operator()(F f) const + { + constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{}); + detail::static_ford_impl{}(f, sequence<>{}); + } +}; + +namespace detail { + +template +struct unpack_impl; + +template +struct unpack_impl> +{ + template + CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x) const + { +#if 0 + return std::forward(f)(std::forward(x).at(number{})...); +#else + return std::forward(f)(std::forward(x).template at()...); +#endif + } +}; + +template +struct unpack2_impl; + +// TODO: remove this, after properly implementing unpack that takes any number of containers +template +struct unpack2_impl, sequence> +{ + template + CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x, Y&& y) const + { +#if 0 + return std::forward(f)(std::forward(x).at(number{})..., + std::forward(y).at(number{})...); +#else + return std::forward(f)(std::forward(x).template at()..., + std::forward(y).template at()...); +#endif + } +}; + +} // namespace detail + +template +CK_TILE_HOST_DEVICE constexpr auto unpack(F&& f, X&& x) +{ + using X_ = remove_reference_t; + return detail::unpack_impl::type>{}( + std::forward(f), std::forward(x)); +} + +// TODO: properly implement unpack that takes any number of containers +template +CK_TILE_HOST_DEVICE constexpr auto unpack2(F&& f, X&& x, Y&& y) +{ + using X_ = remove_reference_t; + using Y_ = remove_reference_t; + return detail::unpack2_impl::type, + typename arithmetic_sequence_gen<0, Y_::size(), 1>::type>{}( + std::forward(f), std::forward(x), std::forward(y)); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/utility/limits.hpp b/include/ck_tile/core/utility/limits.hpp new file mode 100644 index 0000000000..9a3987c177 --- /dev/null +++ b/include/ck_tile/core/utility/limits.hpp @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include +#include + +namespace ck_tile { + +template +struct numeric_limits +{ + // minimum finite value, or minimum positive normalized value for float + CK_TILE_HOST_DEVICE static constexpr T min() { return std::numeric_limits::min(); } + + // minumum finite value + CK_TILE_HOST_DEVICE static constexpr T lowest() { return std::numeric_limits::lowest(); } + + // maximum finite value + CK_TILE_HOST_DEVICE static constexpr T max() { return std::numeric_limits::max(); } + + // difference between 1.0 and next value representable by float + CK_TILE_HOST_DEVICE static constexpr T epsilon() { return std::numeric_limits::epsilon(); } + + // maximum rounding error + CK_TILE_HOST_DEVICE static constexpr T round_error() + { + return std::numeric_limits::round_error(); + } + + // positive infinity value + CK_TILE_HOST_DEVICE static constexpr T infinity() { return std::numeric_limits::infinity(); } + + // quiet NaN + CK_TILE_HOST_DEVICE static constexpr T quiet_NaN() + { + return std::numeric_limits::quiet_NaN(); + } + + // signaling NaN + CK_TILE_HOST_DEVICE static constexpr T signaling_NaN() + { + return std::numeric_limits::signaling_NaN(); + } + + // smallest positive subnormal value + CK_TILE_HOST_DEVICE static constexpr T denorm_min() + { + return std::numeric_limits::denorm_min(); + } +}; + +template +struct numeric_utils; + +template <> +struct numeric_utils +{ + static constexpr int exp = 8; + static constexpr int mant = 23; + static constexpr int bias = 127; + static constexpr uint32_t nan_mask = 0x7F800000; + static constexpr uint32_t head_mask = 0xFF800000; + static constexpr uint32_t mant_mask = 0x7FFFFF; + static constexpr uint32_t exp_mask = 0xFF; + static constexpr uint32_t Inf = 0x7F800000; + static constexpr uint32_t NegInf = 0xFF800000; + static constexpr uint32_t NaN = 0x7F800001; + static constexpr uint32_t Neg0 = 0x80000000; + using bitwise_type = uint32_t; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/core/utility/magic_div.hpp b/include/ck_tile/core/utility/magic_div.hpp new file mode 100644 index 0000000000..1b7eb9c036 --- /dev/null +++ b/include/ck_tile/core/utility/magic_div.hpp @@ -0,0 +1,261 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" +#include "ck_tile/core/utility/type_traits.hpp" +#include + +namespace ck_tile { + +// magic number division +// Caution: +// 1. For uint32_t as dividend: magic number division implementation being used would produce +// correct result if the dividend is uint32_t and its value is within 31-bit value range. +// 2. For int32_t as dividendd: magic number division for int32_t dividened has not been +// implemented, the int32_t dividend would be bit-wise interpreted as uint32_t and magic number +// division implementation for uint32_t is then used. Therefore, dividend value need to be +// non-negative. +// TODO: +// 1. Implement magic number divison for int32_t +// 2. Implement magic number divison for unit32_t with 32-bit value range +struct magic_division32_bit_range +{ + // uint32_t + CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(uint32_t divisor) + { + // WARNING: magic division is only valid for division inside this range. + // assert(divisor >= 1 && divisor <= INT32_MAX) + + uint32_t shift_u32 = 0; + + while((1U << shift_u32) < divisor) + { + shift_u32++; + }; + + uint64_t tmp_u64 = ((1UL << shift_u32) - divisor) << 32; + uint32_t multiplier_u32 = tmp_u64 / divisor + 1; + + return make_tuple(multiplier_u32, shift_u32); + } + + // integral_constant + template > + CK_TILE_HOST_DEVICE static constexpr auto + calculate_magic_numbers(integral_constant) + { + constexpr auto tmp = calculate_magic_numbers(uint32_t{Divisor}); + + constexpr uint32_t multiplier = tmp[number<0>{}]; + constexpr uint32_t shift = tmp[number<1>{}]; + + return make_tuple(integral_constant{}, + integral_constant{}); + } + + // integral_constant + template + CK_TILE_HOST_DEVICE static constexpr auto + calculate_magic_numbers(integral_constant) + { + return calculate_magic_numbers(integral_constant{}); + } + + // magic division for uint32_t + CK_TILE_DEVICE static constexpr uint32_t + do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift) + { + uint32_t tmp = __umulhi(dividend, multiplier); + return (tmp + dividend) >> shift; + } + + CK_TILE_HOST static constexpr uint32_t + do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift) + { + uint32_t tmp = (static_cast(dividend) * multiplier) >> 32; + return (tmp + dividend) >> shift; + } + + // magic division for int32_t + // HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be + // non-negative for result to be correct + // TODO: figure out how to do magic number divison for int32_t as dividended + CK_TILE_DEVICE static constexpr int32_t + do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) + { + uint32_t dividend_u32 = bit_cast(dividend_i32); + uint32_t tmp = __umulhi(dividend_u32, multiplier); + return (tmp + dividend_u32) >> shift; + } + + CK_TILE_HOST static constexpr int32_t + do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) + { + uint32_t dividend_u32 = bit_cast(dividend_i32); + uint32_t tmp = (static_cast(dividend_u32) * multiplier) >> 32; + return (tmp + dividend_u32) >> shift; + } +}; + +// magic number division +// This version on works for divisor and dividended between [0, 1 << 16] +struct magic_division16_bit_range +{ + // uint32_t + CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(uint32_t divisor) + { + // WARNING: magic division is only valid for division inside this range. + // assert(divisor >= 1 && divisor <= (1U << 16)); + + uint32_t shift_u32 = 0; + + while((1U << shift_u32) < divisor) + { + shift_u32++; + }; + + uint32_t one = 1; + uint32_t multiplier_u32 = ((one << 16) * ((one << shift_u32) - divisor)) / divisor + 1; + + return make_tuple(multiplier_u32, shift_u32); + } + + // integral_constant + template + CK_TILE_HOST_DEVICE static constexpr auto + calculate_magic_numbers(integral_constant) + { + constexpr auto tmp = calculate_magic_numbers(uint32_t{Divisor}); + + constexpr uint32_t multiplier = tmp[number<0>{}]; + constexpr uint32_t shift = tmp[number<1>{}]; + + return make_tuple(integral_constant{}, + integral_constant{}); + } + + // integral_constant + template + CK_TILE_HOST_DEVICE static constexpr auto + calculate_magic_numbers(integral_constant) + { + return calculate_magic_numbers(integral_constant{}); + } + + // magic division for uint32_t + CK_TILE_DEVICE static constexpr uint32_t + do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift) + { + uint32_t tmp = (dividend * multiplier) >> 16; + return (tmp + dividend) >> shift; + } + + CK_TILE_HOST static constexpr uint32_t + do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift) + { + uint32_t tmp = (dividend * multiplier) >> 16; + return (tmp + dividend) >> shift; + } + + // magic division for int32_t + // HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be + // non-negative for result to be correct + // TODO: figure out how to do magic number divison for int32_t as dividended + CK_TILE_DEVICE static constexpr int32_t + do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) + { + uint32_t dividend_u32 = bit_cast(dividend_i32); + uint32_t tmp = (dividend_u32 * multiplier) >> 16; + return (tmp + dividend_u32) >> shift; + } + + CK_TILE_HOST static constexpr int32_t + do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) + { + uint32_t dividend_u32 = bit_cast(dividend_i32); + uint32_t tmp = (dividend_u32 * multiplier) >> 16; + return (tmp + dividend_u32) >> shift; + } +}; + +// use 32bit version +using magic_division = magic_division32_bit_range; + +struct mdiv +{ + // 1 dword -> 3 dword storage + uint32_t divisor; + uint32_t multiplier; + uint32_t shift; // TODO: 8 bit is enough + + // prefer construct on host + CK_TILE_HOST_DEVICE mdiv(uint32_t divisor_) : divisor(divisor_) + { + auto tmp = magic_division::calculate_magic_numbers(divisor_); + + multiplier = tmp[number<0>{}]; + shift = tmp[number<1>{}]; + } + + CK_TILE_HOST_DEVICE mdiv() : divisor(0), multiplier(0), shift(0) {} + + CK_TILE_HOST_DEVICE void update(uint32_t divisor_) + { + divisor = divisor_; + auto tmp = magic_division::calculate_magic_numbers(divisor_); + + multiplier = tmp[number<0>{}]; + shift = tmp[number<1>{}]; + } + + CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const + { + return magic_division::do_magic_division(dividend_, multiplier, shift); + } + + CK_TILE_HOST_DEVICE void + divmod(uint32_t dividend_, uint32_t& quotient_, uint32_t& remainder_) const + { + quotient_ = div(dividend_); + remainder_ = dividend_ - (quotient_ * divisor); + } + + CK_TILE_HOST_DEVICE uint32_t get() const { return divisor; } +}; + +struct mdiv2 +{ + // 1 dword -> 2 dword storage, divisor need compute from runtime + uint32_t multiplier; + uint32_t shift; // TODO: 8 bit is enough + + // prefer construct on host + CK_TILE_HOST_DEVICE mdiv2(uint32_t divisor_) + { + auto tmp = magic_division::calculate_magic_numbers(divisor_); + + multiplier = tmp[number<0>{}]; + shift = tmp[number<1>{}]; + } + + CK_TILE_HOST_DEVICE mdiv2() : multiplier(0), shift(0) {} + + CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const + { + return magic_division::do_magic_division(dividend_, multiplier, shift); + } + + CK_TILE_HOST_DEVICE void + divmod(uint32_t dividend_, uint32_t divisor_, uint32_t& quotient_, uint32_t& remainder_) const + { + quotient_ = div(dividend_); + remainder_ = dividend_ - (quotient_ * divisor_); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/core/utility/random.hpp b/include/ck_tile/core/utility/random.hpp new file mode 100644 index 0000000000..e7dc34e6d5 --- /dev/null +++ b/include/ck_tile/core/utility/random.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/half.hpp" +#include +#include +#include + +namespace ck_tile { + +// return 0 if data is not fp16 or fp32 +template +struct prand_generator_t +{ + CK_TILE_HOST_DEVICE uint32_t operator()(int id, T val, uint32_t seed = seed_) + { + std::ignore = id; + std::ignore = val; + std::ignore = seed; + return 0; + } +}; + +// version for fp32 +template +struct prand_generator_t +{ + CK_TILE_HOST_DEVICE uint32_t operator()(int id, float val, uint32_t seed = seed_) + { + uint32_t x = *(reinterpret_cast(&val)); + uint32_t drop_bits = uint32_t(x) & 0xFFFFu; + drop_bits ^= x >> 16; + drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5); + drop_bits *= 0x7000149; + // NOTE: If id is in 64 bit, we are only using lower 32 bit. + // So, it can have an effect of using same id for multiple elements when the id is + // very large! + uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed); + return rng; + } +}; + +// version for fp16 +template +struct prand_generator_t +{ + CK_TILE_HOST_DEVICE uint32_t operator()(int id, half_t val, uint32_t seed = seed_) + { + uint16_t x = *(reinterpret_cast(&val)); + uint32_t drop_bits = uint32_t(x) & 0xFFFFu; + drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5); + drop_bits *= 0x7000149; + // NOTE: If id is in 64 bit, we are only using lower 32 bit. + // So, it can have an effect of using same id for multiple elements when the id is + // very large! + uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed); + return rng; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/core/utility/to_sequence.hpp b/include/ck_tile/core/utility/to_sequence.hpp new file mode 100644 index 0000000000..1d2c73073d --- /dev/null +++ b/include/ck_tile/core/utility/to_sequence.hpp @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include "ck_tile/core/container/sequence.hpp" +// TODO: use c++20 nontype template with struct to implement this + +#if 1 +// clang happen to support this feature (__cpp_generic_lambdas >= 201707) in c++17 mode +#define TO_SEQUENCE(a, n) \ + _Pragma("clang diagnostic push") \ + _Pragma("clang diagnostic ignored \"-Wc++20-extensions\"")[a]( \ + ck_tile::sequence) \ + { \ + return ck_tile::sequence{})...>{}; \ + } \ + (make_index_sequence{}) _Pragma("clang diagnostic pop") + +#else +// Macro function +// convert constexpr array to sequence, both a/n need to be constexpr (can't be a rvalue like 2) +#define TO_SEQUENCE(a, n) \ + [a, n] { \ + static_assert(a.size() >= n, "wrong! out of bound"); \ + static_assert(n <= 10, "not implemented"); \ + if constexpr(n == 0) \ + { \ + return ck_tile::sequence<>{}; \ + } \ + else if constexpr(n == 1) \ + { \ + return ck_tile::sequence{}; \ + } \ + else if constexpr(n == 2) \ + { \ + return ck_tile::sequence{}; \ + } \ + else if constexpr(n == 3) \ + { \ + return ck_tile::sequence{}; \ + } \ + else if constexpr(n == 4) \ + { \ + return ck_tile::sequence{}; \ + } \ + else if constexpr(n == 5) \ + { \ + return ck_tile::sequence{}; \ + } \ + else if constexpr(n == 6) \ + { \ + return ck_tile::sequence{}; \ + } \ + else if constexpr(n == 7) \ + { \ + return ck_tile::sequence{}; \ + } \ + else if constexpr(n == 8) \ + { \ + return ck_tile::sequence{}; \ + } \ + else if constexpr(n == 9) \ + { \ + return ck_tile::sequence{}; \ + } \ + else if constexpr(n == 10) \ + { \ + return ck_tile:: \ + sequence{}; \ + } \ + }() +#endif diff --git a/include/ck_tile/core/utility/type_convert.hpp b/include/ck_tile/core/utility/type_convert.hpp new file mode 100644 index 0000000000..4bc3393fd9 --- /dev/null +++ b/include/ck_tile/core/utility/type_convert.hpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/numeric/bfloat16.hpp" +#include "ck_tile/core/numeric/float8.hpp" + +namespace ck_tile { + +// Convert X to Y, both X and Y are non-const data types. +template || std::is_const_v), bool> = false> +CK_TILE_HOST_DEVICE constexpr Y type_convert(X x) +{ + static_assert(!std::is_reference_v && !std::is_reference_v); + + return static_cast(x); +} + +// Convert X to Y, either X or Y is a const data type. +template || std::is_const_v, bool> = false> +CK_TILE_HOST_DEVICE constexpr Y type_convert(X x) +{ + static_assert(!std::is_reference_v && !std::is_reference_v); + + using non_const_y = std::remove_const_t; + using non_const_x = std::remove_const_t; + return static_cast(type_convert(x)); +} + +#define CK_TILE_TYPE_CONVERT(dtype_, stype_) \ + template <> \ + inline CK_TILE_HOST_DEVICE constexpr dtype_ type_convert(stype_ x) \ + { \ + return stype_##_to_##dtype_(x); \ + } + +CK_TILE_TYPE_CONVERT(float, fp16_t) +CK_TILE_TYPE_CONVERT(float, bf16_t) +CK_TILE_TYPE_CONVERT(float, fp8_t) +CK_TILE_TYPE_CONVERT(float, bf8_t) + +CK_TILE_TYPE_CONVERT(fp16_t, float) +CK_TILE_TYPE_CONVERT(bf16_t, float) +CK_TILE_TYPE_CONVERT(fp8_t, float) +CK_TILE_TYPE_CONVERT(bf8_t, float) + +} // namespace ck_tile diff --git a/include/ck_tile/core/utility/type_traits.hpp b/include/ck_tile/core/utility/type_traits.hpp new file mode 100644 index 0000000000..9e1a7aa4c9 --- /dev/null +++ b/include/ck_tile/core/utility/type_traits.hpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include +#include + +namespace ck_tile { + +// remove_cvref_t +template +using remove_reference_t = typename std::remove_reference::type; + +template +using remove_cv_t = typename std::remove_cv::type; + +template +using remove_cvref_t = remove_cv_t>; + +template +using remove_pointer_t = typename std::remove_pointer::type; + +namespace impl { +template +struct is_static_impl +{ + static constexpr bool value = std::is_arithmetic::v ? false : T::is_static(); +}; +} // namespace impl + +template +using is_static = impl::is_static_impl>; + +template +inline constexpr bool is_static_v = is_static::value; + +// TODO: deprecate this +template +using is_known_at_compile_time = is_static; +// TODO: if evaluating a rvalue, e.g. a const integer +// , this helper will also return false, which is not good(?) +// do we need something like is_constexpr()? + +} // namespace ck_tile diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp new file mode 100644 index 0000000000..1bbb4b9539 --- /dev/null +++ b/include/ck_tile/host.hpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/host/arg_parser.hpp" +#include "ck_tile/host/check_err.hpp" +#include "ck_tile/host/device_memory.hpp" +#include "ck_tile/host/fill.hpp" +#include "ck_tile/host/hip_check_error.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/host/ranges.hpp" +#include "ck_tile/host/reference/reference_batched_elementwise.hpp" +#include "ck_tile/host/reference/reference_batched_gemm.hpp" +#include "ck_tile/host/reference/reference_batched_masking.hpp" +#include "ck_tile/host/reference/reference_batched_softmax.hpp" +#include "ck_tile/host/reference/reference_gemm.hpp" +#include "ck_tile/host/reference/reference_im2col.hpp" +#include "ck_tile/host/reference/reference_reduce.hpp" +#include "ck_tile/host/reference/reference_softmax.hpp" +#include "ck_tile/host/stream_config.hpp" + diff --git a/include/ck_tile/host/arg_parser.hpp b/include/ck_tile/host/arg_parser.hpp new file mode 100644 index 0000000000..5f8a78b4c9 --- /dev/null +++ b/include/ck_tile/host/arg_parser.hpp @@ -0,0 +1,184 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace ck_tile { +/* + * a host side utility, arg parser for + * -[key0]=[value0] -[key1]=[value1] ... + */ +class ArgParser +{ + public: + class Arg + { + public: + std::string name; + std::string value; + std::string help_text; + }; + + ArgParser() {} + ArgParser& insert(const std::string& _name, + const std::string& _default_value, + const std::string& _help_text) + { + Arg in; + in.name = _name; + in.value = _default_value; + in.help_text = _help_text; + + if(input_map.count(_name) != 0) + { + printf("arg:%s already exist\n", _name.c_str()); + } + else + { + input_map[_name] = in; + keys.push_back(_name); + } + return *this; + } + void print() + { + printf("args:\n"); + for(auto& key : keys) + { + auto value = input_map[key]; + std::vector help_text_lines; + size_t pos = 0; + for(size_t next_pos = value.help_text.find('\n', pos); next_pos != std::string::npos;) + { + help_text_lines.push_back(std::string(value.help_text.begin() + pos, + value.help_text.begin() + next_pos++)); + pos = next_pos; + next_pos = value.help_text.find('\n', pos); + } + help_text_lines.push_back( + std::string(value.help_text.begin() + pos, value.help_text.end())); + + std::string default_value = std::string("(default:") + value.value + std::string(")"); + + std::cout << std::setw(2) << std::setw(12 - value.name.length()) << "-" << key + << std::setw(4) << " " << help_text_lines[0] << " " << default_value + << std::endl; + + for(auto help_next_line = std::next(help_text_lines.begin()); + help_next_line != help_text_lines.end(); + ++help_next_line) + { + std::cout << std::setw(17) << " " << *help_next_line << std::endl; + } + } + } + bool parse(int argc, char* argv[], int start_index = 1) + { + if(argc < start_index) + { + printf("not enough args\n"); + return false; + } + for(int i = start_index; i < argc; i++) + { + char* cur_arg = argv[i]; + if(cur_arg[0] != '-') + { + printf("illegal input\n"); + print(); + return false; + } + else + { + std::string text(cur_arg + 1); + if(text == "?") + { + print(); + return false; + } + auto pos = text.find('='); + if(pos == std::string::npos) + { + printf("arg should be [key]=[value] pair, here:%s\n", text.c_str()); + return false; + } + if(pos >= (text.size() - 1)) + { + printf("cant find value after \"=\", here:%s\n", text.c_str()); + return false; + } + auto key = text.substr(0, pos); + auto value = text.substr(pos + 1); + if(input_map.count(key) == 0) + { + printf("no such arg:%s\n", key.c_str()); + return false; + } + input_map[key].value = value; + } + } + return true; + } + + std::string get_str(const std::string& name) const + { + std::string value = input_map.at(name).value; + return value; + } + + int get_int(const std::string& name) const + { + int value = atoi(input_map.at(name).value.c_str()); + return value; + } + + uint32_t get_uint32(const std::string& name) const + { + uint32_t value = strtoul(input_map.at(name).value.c_str(), nullptr, 10); + return value; + } + + uint64_t get_uint64(const std::string& name) const + { + uint64_t value = strtoull(input_map.at(name).value.c_str(), nullptr, 10); + return value; + } + + bool get_bool(const std::string& name) const + { + auto v = input_map.at(name).value; + if(v.compare("t") == 0 || v.compare("true") == 0) + return true; + if(v.compare("f") == 0 || v.compare("false") == 0) + return false; + int value = atoi(v.c_str()); + return value == 0 ? false : true; + } + + float get_float(const std::string& name) const + { + double value = atof(input_map.at(name).value.c_str()); + return static_cast(value); + } + + double get_double(const std::string& name) const + { + double value = atof(input_map.at(name).value.c_str()); + return value; + } + + private: + std::unordered_map input_map; + std::vector keys; +}; +} // namespace ck_tile diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp new file mode 100644 index 0000000000..5d39548602 --- /dev/null +++ b/include/ck_tile/host/check_err.hpp @@ -0,0 +1,375 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/ranges.hpp" + +namespace ck_tile { + +template +std::ostream& operator<<(std::ostream& os, const std::vector& v) +{ + using size_type = typename std::vector::size_type; + + os << "["; + for(size_type idx = 0; idx < v.size(); ++idx) + { + if(0 < idx) + { + os << ", "; + } + os << v[idx]; + } + return os << "]"; +} + +template +typename std::enable_if< + std::is_same_v, ranges::range_value_t> && + std::is_floating_point_v> && + !std::is_same_v, half_t>, + bool>::type +check_err(const Range& out, + const RefRange& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-5, + double atol = 3e-6, + bool allow_infinity_ref = false) +{ + if(out.size() != ref.size()) + { + std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl; + return false; + } + + const auto is_infinity_error = [=](auto o, auto r) { + const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); + const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); + + return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); + }; + + bool res{true}; + int err_count = 0; + double err = 0; + double max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + const double o = *std::next(std::begin(out), i); + const double r = *std::next(std::begin(ref), i); + err = std::abs(o - r); + if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r)) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i + << "] != ref[" << i << "]: " << o << " != " << r << std::endl; + } + res = false; + } + } + if(!res) + { + const float error_percent = + static_cast(err_count) / static_cast(out.size()) * 100.f; + std::cerr << "max err: " << max_err; + std::cerr << ", number of errors: " << err_count; + std::cerr << ", " << error_percent << "% wrong values" << std::endl; + } + return res; +} + +template +typename std::enable_if< + std::is_same_v, ranges::range_value_t> && + std::is_same_v, bhalf_t>, + bool>::type +check_err(const Range& out, + const RefRange& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-3, + double atol = 1e-3, + bool allow_infinity_ref = false) +{ + if(out.size() != ref.size()) + { + std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl; + return false; + } + + const auto is_infinity_error = [=](auto o, auto r) { + const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); + const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); + + return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); + }; + + bool res{true}; + int err_count = 0; + double err = 0; + // TODO: This is a hack. We should have proper specialization for bhalf_t data type. + double max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + const double o = type_convert(*std::next(std::begin(out), i)); + const double r = type_convert(*std::next(std::begin(ref), i)); + err = std::abs(o - r); + if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r)) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i + << "] != ref[" << i << "]: " << o << " != " << r << std::endl; + } + res = false; + } + } + if(!res) + { + const float error_percent = + static_cast(err_count) / static_cast(out.size()) * 100.f; + std::cerr << "max err: " << max_err; + std::cerr << ", number of errors: " << err_count; + std::cerr << ", " << error_percent << "% wrong values" << std::endl; + } + return res; +} + +template +typename std::enable_if< + std::is_same_v, ranges::range_value_t> && + std::is_same_v, half_t>, + bool>::type +check_err(const Range& out, + const RefRange& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-3, + double atol = 1e-3, + bool allow_infinity_ref = false) +{ + if(out.size() != ref.size()) + { + std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl; + return false; + } + + const auto is_infinity_error = [=](auto o, auto r) { + const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); + const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); + + return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); + }; + + bool res{true}; + int err_count = 0; + double err = 0; + double max_err = std::numeric_limits>::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + const double o = type_convert(*std::next(std::begin(out), i)); + const double r = type_convert(*std::next(std::begin(ref), i)); + err = std::abs(o - r); + if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r)) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i + << "] != ref[" << i << "]: " << o << " != " << r << std::endl; + } + res = false; + } + } + if(!res) + { + const float error_percent = + static_cast(err_count) / static_cast(out.size()) * 100.f; + std::cerr << "max err: " << max_err; + std::cerr << ", number of errors: " << err_count; + std::cerr << ", " << error_percent << "% wrong values" << std::endl; + } + return res; +} + +template +std::enable_if_t<(std::is_same_v, ranges::range_value_t> && + std::is_integral_v> && + !std::is_same_v, bhalf_t>) +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + || std::is_same_v, int4_t> +#endif + , + bool> +check_err(const Range& out, + const RefRange& ref, + const std::string& msg = "Error: Incorrect results!", + double = 0, + double atol = 0) +{ + if(out.size() != ref.size()) + { + std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + int64_t err = 0; + int64_t max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + const int64_t o = *std::next(std::begin(out), i); + const int64_t r = *std::next(std::begin(ref), i); + err = std::abs(o - r); + + if(err > atol) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r + << std::endl; + } + res = false; + } + } + if(!res) + { + const float error_percent = + static_cast(err_count) / static_cast(out.size()) * 100.f; + std::cerr << "max err: " << max_err; + std::cerr << ", number of errors: " << err_count; + std::cerr << ", " << error_percent << "% wrong values" << std::endl; + } + return res; +} + +template +std::enable_if_t<(std::is_same_v, ranges::range_value_t> && + std::is_same_v, fp8_t>), + bool> +check_err(const Range& out, + const RefRange& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-3, + double atol = 1e-3, + bool allow_infinity_ref = false) +{ + if(out.size() != ref.size()) + { + std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl; + return false; + } + + const auto is_infinity_error = [=](auto o, auto r) { + const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); + const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); + + return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); + }; + + bool res{true}; + int err_count = 0; + double err = 0; + double max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + const double o = type_convert(*std::next(std::begin(out), i)); + const double r = type_convert(*std::next(std::begin(ref), i)); + err = std::abs(o - r); + if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r)) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i + << "] != ref[" << i << "]: " << o << " != " << r << std::endl; + } + res = false; + } + } + if(!res) + { + std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; + } + return res; +} + +template +std::enable_if_t<(std::is_same_v, ranges::range_value_t> && + std::is_same_v, bf8_t>), + bool> +check_err(const Range& out, + const RefRange& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-3, + double atol = 1e-3, + bool allow_infinity_ref = false) +{ + if(out.size() != ref.size()) + { + std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl; + return false; + } + + const auto is_infinity_error = [=](auto o, auto r) { + const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); + const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); + + return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); + }; + + bool res{true}; + int err_count = 0; + double err = 0; + double max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + const double o = type_convert(*std::next(std::begin(out), i)); + const double r = type_convert(*std::next(std::begin(ref), i)); + err = std::abs(o - r); + if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r)) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i + << "] != ref[" << i << "]: " << o << " != " << r << std::endl; + } + res = false; + } + } + if(!res) + { + std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; + } + return res; +} + +} // namespace ck_tile diff --git a/include/ck_tile/host/device_memory.hpp b/include/ck_tile/host/device_memory.hpp new file mode 100644 index 0000000000..91463a06a9 --- /dev/null +++ b/include/ck_tile/host/device_memory.hpp @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include "ck_tile/host/hip_check_error.hpp" + +namespace ck_tile { +template +__global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size) +{ + for(uint64_t i = threadIdx.x; i < buffer_element_size; i += blockDim.x) + { + p[i] = x; + } +} + +/** + * @brief Container for storing data in GPU device memory + * + */ +struct DeviceMem +{ + DeviceMem() : mpDeviceBuf(nullptr), mMemSize(0) {} + DeviceMem(std::size_t mem_size) : mMemSize(mem_size) + { + HIP_CHECK_ERROR(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); + } + void Realloc(std::size_t mem_size) + { + if(mpDeviceBuf) + { + HIP_CHECK_ERROR(hipFree(mpDeviceBuf)); + } + mMemSize = mem_size; + HIP_CHECK_ERROR(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); + } + void* GetDeviceBuffer() const { return mpDeviceBuf; } + std::size_t GetBufferSize() const { return mMemSize; } + void ToDevice(const void* p) const + { + if(mpDeviceBuf) + { + HIP_CHECK_ERROR( + hipMemcpy(mpDeviceBuf, const_cast(p), mMemSize, hipMemcpyHostToDevice)); + } + else + { + throw std::runtime_error("ToDevice with an empty pointer"); + } + } + void ToDevice(const void* p, const std::size_t cpySize) const + { + HIP_CHECK_ERROR( + hipMemcpy(mpDeviceBuf, const_cast(p), cpySize, hipMemcpyHostToDevice)); + } + void FromDevice(void* p) const + { + if(mpDeviceBuf) + { + HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost)); + } + else + { + throw std::runtime_error("FromDevice with an empty pointer"); + } + } + void FromDevice(void* p, const std::size_t cpySize) const + { + HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost)); + } + void SetZero() const + { + if(mpDeviceBuf) + { + HIP_CHECK_ERROR(hipMemset(mpDeviceBuf, 0, mMemSize)); + } + } + template + void SetValue(T x) const + { + if(mMemSize % sizeof(T) != 0) + { + throw std::runtime_error("wrong! not entire DeviceMem will be set"); + } + + // TODO: call a gpu kernel to set the value (?) + set_buffer_value<<<1, 1024>>>(static_cast(mpDeviceBuf), x, mMemSize / sizeof(T)); + } + ~DeviceMem() + { + if(mpDeviceBuf) + { + try + { + HIP_CHECK_ERROR(hipFree(mpDeviceBuf)); + } + catch(std::runtime_error& re) + { + std::cerr << re.what() << std::endl; + } + } + } + + void* mpDeviceBuf; + std::size_t mMemSize; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/host/fill.hpp b/include/ck_tile/host/fill.hpp new file mode 100644 index 0000000000..f490bbdeba --- /dev/null +++ b/include/ck_tile/host/fill.hpp @@ -0,0 +1,232 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct FillUniformDistribution +{ + float a_{-5.f}; + float b_{5.f}; + std::optional seed_{11939}; + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}()); + std::uniform_real_distribution dis(a_, b_); + std::generate(first, last, [&dis, &gen]() { return ck_tile::type_convert(dis(gen)); }); + } + + template + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + +template +struct FillNormalDistribution +{ + float mean_{0.f}; + float variance_{1.f}; + std::optional seed_{11939}; + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}()); + std::normal_distribution dis(mean_, std::sqrt(variance_)); + std::generate(first, last, [&dis, &gen]() { return ck_tile::type_convert(dis(gen)); }); + } + + template + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + +// Normally FillUniformDistributionIntegerValue should use std::uniform_int_distribution as below. +// However this produces segfaults in std::mt19937 which look like inifite loop. +// template +// struct FillUniformDistributionIntegerValue +// { +// int a_{-5}; +// int b_{5}; +// +// template +// void operator()(ForwardIter first, ForwardIter last) const +// { +// std::mt19937 gen(11939); +// std::uniform_int_distribution dis(a_, b_); +// std::generate( +// first, last, [&dis, &gen]() { return ck_tile::type_convert(dis(gen)); }); +// } +// }; + +// Workaround for uniform_int_distribution not working as expected. See note above.< +template +struct FillUniformDistributionIntegerValue +{ + float a_{-5.f}; + float b_{5.f}; + std::optional seed_{11939}; + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}()); + std::uniform_real_distribution dis(a_, b_); + std::generate( + first, last, [&dis, &gen]() { return ck_tile::type_convert(std::round(dis(gen))); }); + } + + template + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + +template +struct FillNormalDistributionIntegerValue +{ + float mean_{0.f}; + float variance_{1.f}; + std::optional seed_{11939}; + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}()); + std::normal_distribution dis(mean_, std::sqrt(variance_)); + std::generate( + first, last, [&dis, &gen]() { return ck_tile::type_convert(std::round(dis(gen))); }); + } + + template + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + +template +struct FillMonotonicSeq +{ + T init_value_{0}; + T step_{1}; + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::generate(first, last, [=, n = init_value_]() mutable { + auto tmp = n; + n += step_; + return tmp; + }); + } + + template + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + +template +struct FillConstant +{ + T value_{0}; + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::fill(first, last, value_); + } + + template + auto operator()(ForwardRange&& range) const -> std::void_t< + decltype(std::declval()(std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + +template +struct FillTrigValue +{ + template + struct LinearTrigGen + { + int i{0}; + auto operator()() + { + float v = 0; + if constexpr(UseCos_) + { + v = cos(i); + } + else + { + v = sin(i); + } + if constexpr(UseAbs_) + v = abs(v); + i++; + return ck_tile::type_convert(v); + } + }; + template + void operator()(ForwardIter first, ForwardIter last) const + { + LinearTrigGen gen; + std::generate(first, last, gen); + } + + template + auto operator()(ForwardRange&& range) const -> std::void_t< + decltype(std::declval()(std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/host/hip_check_error.hpp b/include/ck_tile/host/hip_check_error.hpp new file mode 100644 index 0000000000..d19b2e3cb2 --- /dev/null +++ b/include/ck_tile/host/hip_check_error.hpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +namespace ck_tile { +// To be removed, which really does not tell the location of failed HIP functional call +inline void hip_check_error(hipError_t x) +{ + if(x != hipSuccess) + { + std::ostringstream ss; + ss << "HIP runtime error: " << hipGetErrorString(x) << ". " << __FILE__ << ": " << __LINE__ + << "in function: " << __func__; + throw std::runtime_error(ss.str()); + } +} +} // namespace ck_tile + +#define HIP_CHECK_ERROR(retval_or_funcall) \ + do \ + { \ + hipError_t _tmpVal = retval_or_funcall; \ + if(_tmpVal != hipSuccess) \ + { \ + std::ostringstream ostr; \ + ostr << "HIP Function Failed (" << __FILE__ << "," << __LINE__ << ") " \ + << hipGetErrorString(_tmpVal); \ + throw std::runtime_error(ostr.str()); \ + } \ + } while(0) diff --git a/include/ck_tile/host/host_tensor.hpp b/include/ck_tile/host/host_tensor.hpp new file mode 100644 index 0000000000..81bf3fb515 --- /dev/null +++ b/include/ck_tile/host/host_tensor.hpp @@ -0,0 +1,495 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/ranges.hpp" + +namespace ck_tile { + +template +std::ostream& LogRange(std::ostream& os, + Range&& range, + std::string delim, + int precision = std::cout.precision(), + int width = 0) +{ + bool first = true; + for(auto&& v : range) + { + if(first) + first = false; + else + os << delim; + os << std::setw(width) << std::setprecision(precision) << v; + } + return os; +} + +template +std::ostream& LogRangeAsType(std::ostream& os, + Range&& range, + std::string delim, + int precision = std::cout.precision(), + int width = 0) +{ + bool first = true; + for(auto&& v : range) + { + if(first) + first = false; + else + os << delim; + os << std::setw(width) << std::setprecision(precision) << static_cast(v); + } + return os; +} + +template +auto call_f_unpack_args_impl(F f, T args, std::index_sequence) +{ + return f(std::get(args)...); +} + +template +auto call_f_unpack_args(F f, T args) +{ + constexpr std::size_t N = std::tuple_size{}; + + return call_f_unpack_args_impl(f, args, std::make_index_sequence{}); +} + +template +auto construct_f_unpack_args_impl(T args, std::index_sequence) +{ + return F(std::get(args)...); +} + +template +auto construct_f_unpack_args(F, T args) +{ + constexpr std::size_t N = std::tuple_size{}; + + return construct_f_unpack_args_impl(args, std::make_index_sequence{}); +} + +struct HostTensorDescriptor +{ + HostTensorDescriptor() = default; + + void CalculateStrides(); + + template >> + HostTensorDescriptor(const std::initializer_list& lens) : mLens(lens.begin(), lens.end()) + { + this->CalculateStrides(); + } + + template , std::size_t>>> + HostTensorDescriptor(const Lengths& lens) : mLens(lens.begin(), lens.end()) + { + this->CalculateStrides(); + } + + template && + std::is_convertible_v>> + HostTensorDescriptor(const std::initializer_list& lens, + const std::initializer_list& strides) + : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end()) + { + } + + template , std::size_t> && + std::is_convertible_v, std::size_t>>> + HostTensorDescriptor(const Lengths& lens, const Strides& strides) + : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end()) + { + } + + std::size_t get_num_of_dimension() const; + std::size_t get_element_size() const; + std::size_t get_element_space_size() const; + + const std::vector& get_lengths() const; + const std::vector& GetStrides() const; + + template + std::size_t GetOffsetFromMultiIndex(Is... is) const + { + assert(sizeof...(Is) == this->get_num_of_dimension()); + std::initializer_list iss{static_cast(is)...}; + return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0}); + } + + std::size_t GetOffsetFromMultiIndex(std::vector iss) const + { + return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0}); + } + + friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc); + + private: + std::vector mLens; + std::vector mStrides; +}; + +template +HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old(const HostTensorDescriptor& a, + const New2Old& new2old) +{ + std::vector new_lengths(a.get_num_of_dimension()); + std::vector new_strides(a.get_num_of_dimension()); + + for(std::size_t i = 0; i < a.get_num_of_dimension(); i++) + { + new_lengths[i] = a.get_lengths()[new2old[i]]; + new_strides[i] = a.GetStrides()[new2old[i]]; + } + + return HostTensorDescriptor(new_lengths, new_strides); +} + +struct joinable_thread : std::thread +{ + template + joinable_thread(Xs&&... xs) : std::thread(std::forward(xs)...) + { + } + + joinable_thread(joinable_thread&&) = default; + joinable_thread& operator=(joinable_thread&&) = default; + + ~joinable_thread() + { + if(this->joinable()) + this->join(); + } +}; + +template +struct ParallelTensorFunctor +{ + F mF; + static constexpr std::size_t NDIM = sizeof...(Xs); + std::array mLens; + std::array mStrides; + std::size_t mN1d; + + ParallelTensorFunctor(F f, Xs... xs) : mF(f), mLens({static_cast(xs)...}) + { + mStrides.back() = 1; + std::partial_sum(mLens.rbegin(), + mLens.rend() - 1, + mStrides.rbegin() + 1, + std::multiplies()); + mN1d = mStrides[0] * mLens[0]; + } + + std::array GetNdIndices(std::size_t i) const + { + std::array indices; + + for(std::size_t idim = 0; idim < NDIM; ++idim) + { + indices[idim] = i / mStrides[idim]; + i -= indices[idim] * mStrides[idim]; + } + + return indices; + } + + void operator()(std::size_t num_thread = 1) const + { + std::size_t work_per_thread = (mN1d + num_thread - 1) / num_thread; + + std::vector threads(num_thread); + + for(std::size_t it = 0; it < num_thread; ++it) + { + std::size_t iw_begin = it * work_per_thread; + std::size_t iw_end = std::min((it + 1) * work_per_thread, mN1d); + + auto f = [this, iw_begin, iw_end] { + for(std::size_t iw = iw_begin; iw < iw_end; ++iw) + { + call_f_unpack_args(this->mF, this->GetNdIndices(iw)); + } + }; + threads[it] = joinable_thread(f); + } + } +}; + +template +auto make_ParallelTensorFunctor(F f, Xs... xs) +{ + return ParallelTensorFunctor(f, xs...); +} + +template +struct HostTensor +{ + using Descriptor = HostTensorDescriptor; + using Data = std::vector; + + template + HostTensor(std::initializer_list lens) : mDesc(lens), mData(mDesc.get_element_space_size()) + { + } + + template + HostTensor(std::initializer_list lens, std::initializer_list strides) + : mDesc(lens, strides), mData(mDesc.get_element_space_size()) + { + } + + template + HostTensor(const Lengths& lens) : mDesc(lens), mData(mDesc.get_element_space_size()) + { + } + + template + HostTensor(const Lengths& lens, const Strides& strides) + : mDesc(lens, strides), mData(get_element_space_size()) + { + } + + HostTensor(const Descriptor& desc) : mDesc(desc), mData(mDesc.get_element_space_size()) {} + + template + HostTensor CopyAsType() const + { + HostTensor ret(mDesc); + std::transform(mData.cbegin(), mData.cend(), ret.mData.begin(), [](auto value) { + return ck_tile::type_convert(value); + }); + return ret; + } + + HostTensor() = delete; + HostTensor(const HostTensor&) = default; + HostTensor(HostTensor&&) = default; + + ~HostTensor() = default; + + HostTensor& operator=(const HostTensor&) = default; + HostTensor& operator=(HostTensor&&) = default; + + template + explicit HostTensor(const HostTensor& other) : HostTensor(other.template CopyAsType()) + { + } + + decltype(auto) get_lengths() const { return mDesc.get_lengths(); } + + decltype(auto) GetStrides() const { return mDesc.GetStrides(); } + + std::size_t get_num_of_dimension() const { return mDesc.get_num_of_dimension(); } + + std::size_t get_element_size() const { return mDesc.get_element_size(); } + + std::size_t get_element_space_size() const { return mDesc.get_element_space_size(); } + + std::size_t get_element_space_size_in_bytes() const + { + return sizeof(T) * get_element_space_size(); + } + + // void SetZero() { ck_tile::ranges::fill(mData, 0); } + void SetZero() { std::fill(mData.begin(), mData.end(), 0); } + + template + void ForEach_impl(F&& f, std::vector& idx, size_t rank) + { + if(rank == mDesc.get_num_of_dimension()) + { + f(*this, idx); + return; + } + // else + for(size_t i = 0; i < mDesc.get_lengths()[rank]; i++) + { + idx[rank] = i; + ForEach_impl(std::forward(f), idx, rank + 1); + } + } + + template + void ForEach(F&& f) + { + std::vector idx(mDesc.get_num_of_dimension(), 0); + ForEach_impl(std::forward(f), idx, size_t(0)); + } + + template + void ForEach_impl(const F&& f, std::vector& idx, size_t rank) const + { + if(rank == mDesc.get_num_of_dimension()) + { + f(*this, idx); + return; + } + // else + for(size_t i = 0; i < mDesc.get_lengths()[rank]; i++) + { + idx[rank] = i; + ForEach_impl(std::forward(f), idx, rank + 1); + } + } + + template + void ForEach(const F&& f) const + { + std::vector idx(mDesc.get_num_of_dimension(), 0); + ForEach_impl(std::forward(f), idx, size_t(0)); + } + + template + void GenerateTensorValue(G g, std::size_t num_thread = 1) + { + switch(mDesc.get_num_of_dimension()) + { + case 1: { + auto f = [&](auto i) { (*this)(i) = g(i); }; + make_ParallelTensorFunctor(f, mDesc.get_lengths()[0])(num_thread); + break; + } + case 2: { + auto f = [&](auto i0, auto i1) { (*this)(i0, i1) = g(i0, i1); }; + make_ParallelTensorFunctor(f, mDesc.get_lengths()[0], mDesc.get_lengths()[1])( + num_thread); + break; + } + case 3: { + auto f = [&](auto i0, auto i1, auto i2) { (*this)(i0, i1, i2) = g(i0, i1, i2); }; + make_ParallelTensorFunctor(f, + mDesc.get_lengths()[0], + mDesc.get_lengths()[1], + mDesc.get_lengths()[2])(num_thread); + break; + } + case 4: { + auto f = [&](auto i0, auto i1, auto i2, auto i3) { + (*this)(i0, i1, i2, i3) = g(i0, i1, i2, i3); + }; + make_ParallelTensorFunctor(f, + mDesc.get_lengths()[0], + mDesc.get_lengths()[1], + mDesc.get_lengths()[2], + mDesc.get_lengths()[3])(num_thread); + break; + } + case 5: { + auto f = [&](auto i0, auto i1, auto i2, auto i3, auto i4) { + (*this)(i0, i1, i2, i3, i4) = g(i0, i1, i2, i3, i4); + }; + make_ParallelTensorFunctor(f, + mDesc.get_lengths()[0], + mDesc.get_lengths()[1], + mDesc.get_lengths()[2], + mDesc.get_lengths()[3], + mDesc.get_lengths()[4])(num_thread); + break; + } + case 6: { + auto f = [&](auto i0, auto i1, auto i2, auto i3, auto i4, auto i5) { + (*this)(i0, i1, i2, i3, i4, i5) = g(i0, i1, i2, i3, i4, i5); + }; + make_ParallelTensorFunctor(f, + mDesc.get_lengths()[0], + mDesc.get_lengths()[1], + mDesc.get_lengths()[2], + mDesc.get_lengths()[3], + mDesc.get_lengths()[4], + mDesc.get_lengths()[5])(num_thread); + break; + } + default: throw std::runtime_error("unspported dimension"); + } + } + + template + std::size_t GetOffsetFromMultiIndex(Is... is) const + { + return mDesc.GetOffsetFromMultiIndex(is...); + } + + template + T& operator()(Is... is) + { + return mData[mDesc.GetOffsetFromMultiIndex(is...)]; + } + + template + const T& operator()(Is... is) const + { + return mData[mDesc.GetOffsetFromMultiIndex(is...)]; + } + + T& operator()(std::vector idx) + { + return mData[mDesc.GetOffsetFromMultiIndex(idx)]; + } + + const T& operator()(std::vector idx) const + { + return mData[mDesc.GetOffsetFromMultiIndex(idx)]; + } + + typename Data::iterator begin() { return mData.begin(); } + + typename Data::iterator end() { return mData.end(); } + + typename Data::pointer data() { return mData.data(); } + + typename Data::const_iterator begin() const { return mData.begin(); } + + typename Data::const_iterator end() const { return mData.end(); } + + typename Data::const_pointer data() const { return mData.data(); } + + typename Data::size_type size() const { return mData.size(); } + + template + auto AsSpan() const + { + constexpr std::size_t FromSize = sizeof(T); + constexpr std::size_t ToSize = sizeof(U); + + using Element = std::add_const_t>; + return ck_tile::span{reinterpret_cast(data()), + size() * FromSize / ToSize}; + } + + template + auto AsSpan() + { + constexpr std::size_t FromSize = sizeof(T); + constexpr std::size_t ToSize = sizeof(U); + + using Element = std::remove_reference_t; + return ck_tile::span{reinterpret_cast(data()), + size() * FromSize / ToSize}; + } + + Descriptor mDesc; + Data mData; +}; +} // namespace ck_tile diff --git a/include/ck_tile/host/kernel_launch.hpp b/include/ck_tile/host/kernel_launch.hpp new file mode 100644 index 0000000000..30c49e3c1b --- /dev/null +++ b/include/ck_tile/host/kernel_launch.hpp @@ -0,0 +1,166 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/host/stream_config.hpp" +#include "ck_tile/host/hip_check_error.hpp" +#include +#include + +namespace ck_tile { +template +#if CK_TILE_USE_LAUNCH_BOUNDS +__launch_bounds__(MaxThreadPerBlock, MinBlockPerCu) +#endif + __global__ void kentry(Kernel f, Args... args) +{ + f(args...); +} + +template +float launch_and_time_kernel(const stream_config& s, + F kernel, + dim3 grid_dim, + dim3 block_dim, + std::size_t lds_byte, + Args... args) +{ +#if CK_TILE_TIME_KERNEL + if(s.time_kernel_) + { + // warm up + for(int i = 0; i < s.cold_niters_; ++i) + { + kernel<<>>(args...); + hip_check_error(hipGetLastError()); + } + + const int nrepeat = s.nrepeat_; + hipEvent_t start, stop; + + HIP_CHECK_ERROR(hipEventCreate(&start)); + HIP_CHECK_ERROR(hipEventCreate(&stop)); + + HIP_CHECK_ERROR(hipDeviceSynchronize()); + HIP_CHECK_ERROR(hipEventRecord(start, s.stream_id_)); + + for(int i = 0; i < nrepeat; ++i) + { + kernel<<>>(args...); + hip_check_error(hipGetLastError()); + } + + HIP_CHECK_ERROR(hipEventRecord(stop, s.stream_id_)); + HIP_CHECK_ERROR(hipEventSynchronize(stop)); + + float total_time = 0; + + HIP_CHECK_ERROR(hipEventElapsedTime(&total_time, start, stop)); + + return total_time / nrepeat; + } + else + { + kernel<<>>(args...); + hip_check_error(hipGetLastError()); + return 0; + } +#else + kernel<<>>(args...); + hip_check_error(hipGetLastError()); + return 0; +#endif +} + +template +float launch_and_time_kernel_with_preprocess(const stream_config& s, + PreProcessFunc preprocess, + F kernel, + dim3 grid_dim, + dim3 block_dim, + std::size_t lds_byte, + Args... args) +{ +#if CK_TILE_TIME_KERNEL + if(s.time_kernel_) + { +#if DEBUG_LOG + printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", + __func__, + grid_dim.x, + grid_dim.y, + grid_dim.z, + block_dim.x, + block_dim.y, + block_dim.z); + + printf("Warm up 1 time\n"); +#endif + // warm up + preprocess(); + kernel<<>>(args...); + hip_check_error(hipGetLastError()); + + const int nrepeat = 10; +#if DEBUG_LOG + printf("Start running %d times...\n", nrepeat); +#endif + hipEvent_t start, stop; + + HIP_CHECK_ERROR(hipEventCreate(&start)); + HIP_CHECK_ERROR(hipEventCreate(&stop)); + + HIP_CHECK_ERROR(hipDeviceSynchronize()); + HIP_CHECK_ERROR(hipEventRecord(start, s.stream_id_)); + + for(int i = 0; i < nrepeat; ++i) + { + preprocess(); + kernel<<>>(args...); + hip_check_error(hipGetLastError()); + } + + HIP_CHECK_ERROR(hipEventRecord(stop, s.stream_id_)); + HIP_CHECK_ERROR(hipEventSynchronize(stop)); + + float total_time = 0; + + HIP_CHECK_ERROR(hipEventElapsedTime(&total_time, start, stop)); + + return total_time / nrepeat; + } + else + { + preprocess(); + kernel<<>>(args...); + hip_check_error(hipGetLastError()); + + return 0; + } +#else + kernel<<>>(args...); + hip_check_error(hipGetLastError()); + + return 0; +#endif +} + +template +float launch_kernel(const stream_config& s, + KernelImpl kernel_impl, + dim3 grid_dim, + dim3 block_dim, + std::size_t dynamic_smem_byte, + Args... args) +{ + const auto kernel = kentry; + + return launch_and_time_kernel( + s, kernel, grid_dim, block_dim, dynamic_smem_byte, kernel_impl, args...); +} +} // namespace ck_tile diff --git a/include/ck_tile/host/ranges.hpp b/include/ck_tile/host/ranges.hpp new file mode 100644 index 0000000000..b1b8197044 --- /dev/null +++ b/include/ck_tile/host/ranges.hpp @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +// ranges implementation are not intented to be used by user +// TODO: do we need this? +namespace ck_tile { +namespace ranges { + +template +using iter_value_t = typename std::iterator_traits>::value_type; + +template +using iter_reference_t = decltype(*std::declval()); + +template +using iter_difference_t = typename std::iterator_traits>::difference_type; + +//......................... + +template +using iterator_t = decltype(std::begin(std::declval())); + +template +using sentinel_t = decltype(std::end(std::declval())); + +template +using range_size_t = decltype(std::size(std::declval())); + +template +using range_difference_t = ck_tile::iter_difference_t>; + +template +using range_value_t = iter_value_t>; + +template +using range_reference_t = iter_reference_t>; + +template +struct is_range : std::false_type +{ +}; + +template +struct is_range< + T, + std::void_t())), decltype(std::end(std::declval()))>> + : std::true_type +{ +}; + +template +inline constexpr bool is_range_v = is_range::value; + +template +struct is_sized_range : std::false_type +{ +}; + +template +struct is_sized_range()))>> + : std::bool_constant> +{ +}; +} // namespace ranges +} // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_batched_elementwise.hpp b/include/ck_tile/host/reference/reference_batched_elementwise.hpp new file mode 100644 index 0000000000..9b69f940ec --- /dev/null +++ b/include/ck_tile/host/reference/reference_batched_elementwise.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include + +namespace ck_tile { + +template > +void reference_batched_elementwise(const HostTensor& a_b_m_n, + const HostTensor& b_b_m_n, + HostTensor& c_b_m_n, + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const BinaryElementOp& binary_element_op = {}) +{ + const ck_tile::index_t N = c_b_m_n.mDesc.get_lengths()[2]; + + const bool broadcast_a_dim_b = (a_b_m_n.get_lengths()[0] == 1); + const bool broadcast_a_dim_m = (a_b_m_n.get_lengths()[1] == 1); + const bool broadcast_a_dim_n = (a_b_m_n.get_lengths()[2] == 1); + + const bool broadcast_b_dim_b = (b_b_m_n.get_lengths()[0] == 1); + const bool broadcast_b_dim_m = (b_b_m_n.get_lengths()[1] == 1); + const bool broadcast_b_dim_n = (b_b_m_n.get_lengths()[2] == 1); + + auto f = [&](auto batch, auto m) { + for(ck_tile::index_t n = 0; n < N; ++n) + { + AccDataType v_a{}; + { + ck_tile::index_t i_b = (broadcast_a_dim_b ? 0 : batch); + ck_tile::index_t i_m = (broadcast_a_dim_m ? 0 : m); + ck_tile::index_t i_n = (broadcast_a_dim_n ? 0 : n); + + v_a = ck_tile::type_convert(a_element_op(a_b_m_n(i_b, i_m, i_n))); + } + + AccDataType v_b{}; + { + ck_tile::index_t i_b = (broadcast_b_dim_b ? 0 : batch); + ck_tile::index_t i_m = (broadcast_b_dim_m ? 0 : m); + ck_tile::index_t i_n = (broadcast_b_dim_n ? 0 : n); + + v_b = ck_tile::type_convert(b_element_op(b_b_m_n(i_b, i_m, i_n))); + } + + c_b_m_n(batch, m, n) = ck_tile::type_convert(binary_element_op(v_a, v_b)); + } + }; + + make_ParallelTensorFunctor(f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])( + std::thread::hardware_concurrency()); +} +} // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_batched_gemm.hpp b/include/ck_tile/host/reference/reference_batched_gemm.hpp new file mode 100644 index 0000000000..8e8f713537 --- /dev/null +++ b/include/ck_tile/host/reference/reference_batched_gemm.hpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include + +namespace ck_tile { + +template +void reference_batched_gemm(const HostTensor& a_b_m_k, + const HostTensor& b_b_n_k, + HostTensor& c_b_m_n, + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const ACCElementOp& acc_element_op = {}) +{ + const int N = b_b_n_k.mDesc.get_lengths()[1]; + const int K = b_b_n_k.mDesc.get_lengths()[2]; + + auto f = [&](auto batch, auto m) { + for(int n = 0; n < N; ++n) + { + AccDataType v_acc = 0; + + for(int k = 0; k < K; ++k) + { + ADataType v_a = a_element_op(a_b_m_k(batch, m, k)); + BDataType v_b = b_element_op(b_b_n_k(batch, n, k)); + + v_acc += ck_tile::type_convert(v_a) * + ck_tile::type_convert(v_b); + } + + c_b_m_n(batch, m, n) = ck_tile::type_convert(acc_element_op(v_acc)); + } + }; + + make_ParallelTensorFunctor(f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])( + std::thread::hardware_concurrency()); +} +} // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_batched_masking.hpp b/include/ck_tile/host/reference/reference_batched_masking.hpp new file mode 100644 index 0000000000..c8457273fb --- /dev/null +++ b/include/ck_tile/host/reference/reference_batched_masking.hpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include + +namespace ck_tile { + +template +void reference_batched_masking(HostTensor& c_b_m_n, const MaskingType& mask) +{ + const int M = c_b_m_n.mDesc.get_lengths()[1]; + const int N = c_b_m_n.mDesc.get_lengths()[2]; + + auto f = [&](auto batch) { + for(int n = 0; n < N; ++n) + { + for(int m = 0; m < M; ++m) + { + if(mask.IsOutOfBound(m, n)) + c_b_m_n(batch, m, n) = -ck_tile::numeric_limits::infinity(); + } + } + }; + + make_ParallelTensorFunctor(f, + c_b_m_n.mDesc.get_lengths()[0])(std::thread::hardware_concurrency()); +} +} // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_batched_softmax.hpp b/include/ck_tile/host/reference/reference_batched_softmax.hpp new file mode 100644 index 0000000000..55eeb8d479 --- /dev/null +++ b/include/ck_tile/host/reference/reference_batched_softmax.hpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include + +namespace ck_tile { + +template +void reference_batched_softmax( + const HostTensor& a_b_m_n, + HostTensor& b_b_m_n, + std::optional>> lse_b_m = std::nullopt) +{ + const int N = a_b_m_n.mDesc.get_lengths()[2]; + + auto f = [&](auto batch, auto m) { + CompDataType v_max = -ck_tile::numeric_limits::infinity(); + + // max + for(int n = 0; n < N; ++n) + { + const CompDataType v_a = ck_tile::type_convert(a_b_m_n(batch, m, n)); + + v_max = v_max < v_a ? v_a : v_max; + } + + CompDataType v_exp_sum = 0; + // validate v_max if all the elements within a row are -INF + if(std::isinf(v_max) && v_max < 0) + { + v_max = ck_tile::type_convert(0.f); + } + + // sum + for(int n = 0; n < N; ++n) + { + const CompDataType v_a = ck_tile::type_convert(a_b_m_n(batch, m, n)); + + v_exp_sum += ck_tile::exp(v_a - v_max); + } + + // if sum is zero(masked), or nan/inf(other computation error), don't do divide + CompDataType inv_sum = (v_exp_sum == 0.f ? 1.f : 1.f / v_exp_sum); + + // elementwise + for(int n = 0; n < N; ++n) + { + const CompDataType v_a = ck_tile::type_convert(a_b_m_n(batch, m, n)); + + b_b_m_n(batch, m, n) = + ck_tile::type_convert(ck_tile::exp(v_a - v_max) * inv_sum); + } + // lse + if(lse_b_m) + { + lse_b_m->get()(batch, m) = v_max + ck_tile::log(v_exp_sum); + } + }; + + make_ParallelTensorFunctor(f, b_b_m_n.mDesc.get_lengths()[0], b_b_m_n.mDesc.get_lengths()[1])( + std::thread::hardware_concurrency()); +} +} // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp new file mode 100644 index 0000000000..8afd0391f9 --- /dev/null +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include + +namespace ck_tile { + +template +void reference_gemm(const HostTensor& a_m_k, + const HostTensor& b_n_k, + HostTensor& c_m_n, + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const ACCElementOp& acc_element_op = {}) +{ + const int N = b_n_k.mDesc.get_lengths()[0]; + const int K = b_n_k.mDesc.get_lengths()[1]; + + auto f = [&](auto m) { + for(int n = 0; n < N; ++n) + { + AccDataType v_acc = 0; + + for(int k = 0; k < K; ++k) + { + ADataType v_a = a_element_op(a_m_k(m, k)); + BDataType v_b = b_element_op(b_n_k(n, k)); + + v_acc += ck_tile::type_convert(v_a) * + ck_tile::type_convert(v_b); + } + + c_m_n(m, n) = ck_tile::type_convert(acc_element_op(v_acc)); + } + }; + + make_ParallelTensorFunctor(f, + c_m_n.mDesc.get_lengths()[0])(std::thread::hardware_concurrency()); +} +} // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_im2col.hpp b/include/ck_tile/host/reference/reference_im2col.hpp new file mode 100644 index 0000000000..a0ba2135be --- /dev/null +++ b/include/ck_tile/host/reference/reference_im2col.hpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include + +namespace ck_tile { + +template +void reference_im2col(HostTensor& in_mtx_host_ref, + const HostTensor& in_host, + int /*N*/, + int /*K*/, + int C, + int /*Y*/, + int X, + int Hi, + int Wi, + int Ho, + int Wo, + int ConvStrideH, + int ConvStrideW, + int ConvDilationH, + int ConvDilationW, + int InLeftPadH, + int InLeftPadW, + int /*InRightPadH*/, + int /*InRightPadW*/) +{ + int GemmM = in_mtx_host_ref.get_lengths()[0]; + int GemmK = in_mtx_host_ref.get_lengths()[1]; + + for(int gemm_m = 0; gemm_m < GemmM; ++gemm_m) + { + int mtmp = gemm_m; + int n = mtmp / (Ho * Wo); + mtmp -= n * Ho * Wo; + int ho = mtmp / Wo; + int wo = mtmp - ho * Wo; + + for(int gemm_k = 0; gemm_k < GemmK; ++gemm_k) + { + int ktmp = gemm_k; + int y = ktmp / (X * C); + ktmp -= y * X * C; + int x = ktmp / C; + int c = ktmp - x * C; + + int hi = y * ConvDilationH + ho * ConvStrideH - InLeftPadH; + int wi = x * ConvDilationW + wo * ConvStrideW - InLeftPadW; + + bool inbound = (hi >= 0 && hi < Hi && wi >= 0 && wi < Wi); + + in_mtx_host_ref(gemm_m, gemm_k) = inbound ? in_host(n, hi, wi, c) : 0; + } + } +} +} // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_reduce.hpp b/include/ck_tile/host/reference/reference_reduce.hpp new file mode 100644 index 0000000000..90bb2b7c33 --- /dev/null +++ b/include/ck_tile/host/reference/reference_reduce.hpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include + +namespace ck_tile { + +template +void reference_reduce(const HostTensor& a_m_n, HostTensor& b_m) +{ + auto f = [&](auto m) { + const int N = a_m_n.mDesc.get_lengths()[1]; + + AccDataType v_acc = 0; + + for(int n = 0; n < N; ++n) + { + const ADataType v_a = a_m_n(m, n); + + v_acc += v_a; + } + + b_m(m) = ck_tile::type_convert(v_acc); + }; + + make_ParallelTensorFunctor(f, b_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency()); +} +} // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_softmax.hpp b/include/ck_tile/host/reference/reference_softmax.hpp new file mode 100644 index 0000000000..356b2587b7 --- /dev/null +++ b/include/ck_tile/host/reference/reference_softmax.hpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include + +namespace ck_tile { + +template +void reference_softmax(const HostTensor& a_m_n, HostTensor& b_m_n) +{ + auto f = [&](auto m) { + const int N = a_m_n.mDesc.get_lengths()[1]; + + AccDataType v_max = ck_tile::NumericLimits::Lowest(); + + // max + for(int n = 0; n < N; ++n) + { + const ADataType v_a = a_m_n(m, n); + + v_max = v_max < v_a ? v_a : v_max; + } + + AccDataType v_exp_sum = 0; + + // sum + for(int n = 0; n < N; ++n) + { + const ADataType v_a = a_m_n(m, n); + + v_exp_sum += ck_tile::exp(v_a - v_max); + } + + // elementwise + for(int n = 0; n < N; ++n) + { + const ADataType v_a = a_m_n(m, n); + + b_m_n(m, n) = ck_tile::exp(v_a - v_max) / v_exp_sum; + } + }; + + make_ParallelTensorFunctor(f, + b_m_n.mDesc.get_lengths()[0])(std::thread::hardware_concurrency()); +} +} // namespace ck_tile diff --git a/include/ck_tile/host/stream_config.hpp b/include/ck_tile/host/stream_config.hpp new file mode 100644 index 0000000000..d29c6f0fa1 --- /dev/null +++ b/include/ck_tile/host/stream_config.hpp @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +namespace ck_tile { +struct stream_config +{ + hipStream_t stream_id_ = nullptr; + bool time_kernel_ = false; + int log_level_ = 0; + int cold_niters_ = 3; + int nrepeat_ = 10; +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/common.hpp b/include/ck_tile/ops/common.hpp new file mode 100644 index 0000000000..9fc1c0d0c1 --- /dev/null +++ b/include/ck_tile/ops/common.hpp @@ -0,0 +1,7 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/common/tensor_layout.hpp" + diff --git a/include/ck_tile/ops/common/tensor_layout.hpp b/include/ck_tile/ops/common/tensor_layout.hpp new file mode 100644 index 0000000000..bb905e6ab9 --- /dev/null +++ b/include/ck_tile/ops/common/tensor_layout.hpp @@ -0,0 +1,412 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +// TODO: this folder does not match the single namespace rule. need to refactor in the future +namespace ck_tile { +namespace tensor_layout { + +struct BaseTensorLayout +{ +}; + +namespace gemm { + +struct RowMajor : public BaseTensorLayout +{ + static constexpr const char* name = "RowMajor"; +}; + +struct ColumnMajor : public BaseTensorLayout +{ + static constexpr const char* name = "ColumnMajor"; +}; +} // namespace gemm + +namespace convolution { + +// input tensor +// packed NCW/NCHW/NCDHW +struct NCW : public BaseTensorLayout +{ + static constexpr const char* name = "NCW"; +}; + +struct NCHW : public BaseTensorLayout +{ + static constexpr const char* name = "NCHW"; +}; + +struct NCDHW : public BaseTensorLayout +{ + static constexpr const char* name = "NCDHW"; +}; + +// packed GNCW/GNCHW/GNCDHW +struct GNCW : public BaseTensorLayout +{ + static constexpr const char* name = "GNCW"; +}; + +struct GNCHW : public BaseTensorLayout +{ + static constexpr const char* name = "GNCHW"; +}; + +struct GNCDHW : public BaseTensorLayout +{ + static constexpr const char* name = "GNCDHW"; +}; + +// input tensor +// packed NWC/NHWC/NDHWC +struct NWC : public BaseTensorLayout +{ + static constexpr const char* name = "NWC"; +}; + +struct NHWC : public BaseTensorLayout +{ + static constexpr const char* name = "NHWC"; +}; + +struct NDHWC : public BaseTensorLayout +{ + static constexpr const char* name = "NDHWC"; +}; + +// input tensor +// packed GNWC/GNHWC/GNDHWC +struct GNWC : public BaseTensorLayout +{ + static constexpr const char* name = "GNWC"; +}; + +struct GNHWC : public BaseTensorLayout +{ + static constexpr const char* name = "GNHWC"; +}; + +struct GNDHWC : public BaseTensorLayout +{ + static constexpr const char* name = "GNDHWC"; +}; + +// for input bias +struct GC : public BaseTensorLayout +{ + static constexpr const char* name = "GC"; +}; + +// input tensor +// packed NWGC/NHWGC/NDHWGC +struct NWGC : public BaseTensorLayout +{ + static constexpr const char* name = "NWGC"; +}; + +struct NHWGC : public BaseTensorLayout +{ + static constexpr const char* name = "NHWGC"; +}; + +struct NDHWGC : public BaseTensorLayout +{ + static constexpr const char* name = "NDHWGC"; +}; + +// input tensor +// strided layout +struct G_NW_C : public BaseTensorLayout +{ + static constexpr const char* name = "G_NW_C"; +}; + +struct G_NHW_C : public BaseTensorLayout +{ + static constexpr const char* name = "G_NHW_C"; +}; + +struct G_NDHW_C : public BaseTensorLayout +{ + static constexpr const char* name = "G_NDHW_C"; +}; + +// for input bias +struct G_C : public BaseTensorLayout +{ + static constexpr const char* name = "G_C"; +}; + +// weight tensor +// packed KCX/KCYX/KCZYX +struct KCX : public BaseTensorLayout +{ + static constexpr const char* name = "KCX"; +}; + +struct KCYX : public BaseTensorLayout +{ + static constexpr const char* name = "KCYX"; +}; + +struct KCZYX : public BaseTensorLayout +{ + static constexpr const char* name = "KCZYX"; +}; + +// weight tensor +// packed KCX/KCYX/KCZYX +struct GKCX : public BaseTensorLayout +{ + static constexpr const char* name = "GKCX"; +}; + +struct GKCYX : public BaseTensorLayout +{ + static constexpr const char* name = "GKCYX"; +}; + +struct GKCZYX : public BaseTensorLayout +{ + static constexpr const char* name = "GKCZYX"; +}; + +// weight tensor +// packed KXC/KYXC/KZYXC +struct KXC : public BaseTensorLayout +{ + static constexpr const char* name = "KXC"; +}; + +struct KYXC : public BaseTensorLayout +{ + static constexpr const char* name = "KYXC"; +}; + +struct KZYXC : public BaseTensorLayout +{ + static constexpr const char* name = "KZYXC"; +}; + +// weight tensor +// packed GKXC/GKYXC/GKZYXC +struct GKXC : public BaseTensorLayout +{ + static constexpr const char* name = "GKXC"; +}; + +struct GKYXC : public BaseTensorLayout +{ + static constexpr const char* name = "GKYXC"; +}; + +struct GKZYXC : public BaseTensorLayout +{ + static constexpr const char* name = "GKZYXC"; +}; + +// weight tensor +// packed KXGC/KYXGC/KZYXGC +struct KXGC : public BaseTensorLayout +{ + static constexpr const char* name = "KXGC"; +}; + +struct KYXGC : public BaseTensorLayout +{ + static constexpr const char* name = "KYXGC"; +}; + +struct KZYXGC : public BaseTensorLayout +{ + static constexpr const char* name = "KZYXGC"; +}; + +// weight tensor +// strided +struct G_K_X_C : public BaseTensorLayout +{ + static constexpr const char* name = "G_K_X_C"; +}; + +struct G_K_YX_C : public BaseTensorLayout +{ + static constexpr const char* name = "G_K_YX_C"; +}; + +struct G_K_ZYX_C : public BaseTensorLayout +{ + static constexpr const char* name = "G_K_ZYX_C"; +}; + +// output tensor +// packed NKW/NKHW/NKDHW +struct NKW : public BaseTensorLayout +{ + static constexpr const char* name = "NKW"; +}; + +struct NKHW : public BaseTensorLayout +{ + static constexpr const char* name = "NKHW"; +}; + +struct NKDHW : public BaseTensorLayout +{ + static constexpr const char* name = "NKDHW"; +}; + +// output tensor +// packed GNKW/GNKHW/GNKDHW +struct GNKW : public BaseTensorLayout +{ + static constexpr const char* name = "GNKW"; +}; + +struct GNKHW : public BaseTensorLayout +{ + static constexpr const char* name = "GNKHW"; +}; + +struct GNKDHW : public BaseTensorLayout +{ + static constexpr const char* name = "GNKDHW"; +}; + +// output tensor +// packed NWK/NHWK/NDHWK +struct NWK : public BaseTensorLayout +{ + static constexpr const char* name = "NWK"; +}; + +struct NHWK : public BaseTensorLayout +{ + static constexpr const char* name = "NHWK"; +}; + +struct NDHWK : public BaseTensorLayout +{ + static constexpr const char* name = "NDHWK"; +}; + +// output tensor +// packed GNWK/GNHWK/GNDHWK +struct GNWK : public BaseTensorLayout +{ + static constexpr const char* name = "GNWK"; +}; + +struct GNHWK : public BaseTensorLayout +{ + static constexpr const char* name = "GNHWK"; +}; + +struct GNDHWK : public BaseTensorLayout +{ + static constexpr const char* name = "GNDHWK"; +}; + +// output tensor +// packed NWGK/NHWGK/NDHWGK +struct NWGK : public BaseTensorLayout +{ + static constexpr const char* name = "NWGK"; +}; + +struct NHWGK : public BaseTensorLayout +{ + static constexpr const char* name = "NHWGK"; +}; + +struct NDHWGK : public BaseTensorLayout +{ + static constexpr const char* name = "NDHWGK"; +}; + +// output tensor +// strided layout +struct G_NW_K : public BaseTensorLayout +{ + static constexpr const char* name = "G_NW_K"; +}; + +struct G_NHW_K : public BaseTensorLayout +{ + static constexpr const char* name = "G_NHW_K"; +}; + +struct G_NDHW_K : public BaseTensorLayout +{ + static constexpr const char* name = "G_NDHW_K"; +}; + +// for output bias +struct G_K : public BaseTensorLayout +{ + static constexpr const char* name = "G_K"; +}; + +// K-reduced output tensor (packed) +struct GNW : public BaseTensorLayout +{ + static constexpr const char* name = "GNW"; +}; + +struct GNHW : public BaseTensorLayout +{ + static constexpr const char* name = "GNHW"; +}; + +struct GNDHW : public BaseTensorLayout +{ + static constexpr const char* name = "GNDHW"; +}; + +// K-reduced output tensor (packed) +struct NWG : public BaseTensorLayout +{ + static constexpr const char* name = "NWG"; +}; + +struct NHWG : public BaseTensorLayout +{ + static constexpr const char* name = "NHWG"; +}; + +struct NDHWG : public BaseTensorLayout +{ + static constexpr const char* name = "NDHWG"; +}; + +// K-reduced output tensor (strided) +struct G_NW : public BaseTensorLayout +{ + static constexpr const char* name = "G_NW"; +}; + +struct G_NHW : public BaseTensorLayout +{ + static constexpr const char* name = "G_NHW"; +}; + +struct G_NDHW : public BaseTensorLayout +{ + static constexpr const char* name = "G_NDHW"; +}; + +} // namespace convolution + +template < + typename Layout, + typename std::enable_if::value, bool>::type = false> +std::ostream& operator<<(std::ostream& os, const Layout&) +{ + os << Layout::name; + return os; +} + +} // namespace tensor_layout +} // namespace ck_tile diff --git a/include/ck_tile/ops/epilogue.hpp b/include/ck_tile/ops/epilogue.hpp new file mode 100644 index 0000000000..497f6d1504 --- /dev/null +++ b/include/ck_tile/ops/epilogue.hpp @@ -0,0 +1,7 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" + diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp new file mode 100644 index 0000000000..5dc49c3b0e --- /dev/null +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// this epilogue just store out a M*N matrix, row major + +template +struct Default2DEpilogueProblem +{ + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + static constexpr bool kPadM = kPadM_; + static constexpr bool kPadN = kPadN_; +}; + +template +struct Default2DEpilogue +{ + using Problem = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; } + + // TODO: this function assume store out vector size is the same as OAccTile last dimension size + // how do we fix this ? + template + CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile) + { + + // TODO: this is ugly + if constexpr(kPadM || kPadN) + { + store_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tile)); + buffer_store_fence(); + } + else + { + store_tile(o_dram_window_tmp, cast_tile(o_acc_tile)); + } + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp new file mode 100644 index 0000000000..6813a7a971 --- /dev/null +++ b/include/ck_tile/ops/fmha.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/fmha/block/block_masking.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" +#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" + diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp new file mode 100644 index 0000000000..db880daadb --- /dev/null +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -0,0 +1,227 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// clang-format off +/* generic Attention Mask Coordinate + use x(horizontal axis), y(vertical axis) to describe mask. + top-left corner is origin + + x=1/y=5(top-left) x=4/y=5(botm-r) x=6/y=5 x=8/y=5(no mask) + 1 * * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1 + 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1 + 1 1 1 * * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 + 1 1 1 1 * * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 + 1 1 1 1 1 * * * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 + l=7,-1/r=0(tl) l=7,-1/r=0(br) + + x=1/y=2 x=4/y=2 x=6/y=2 x=8/y=2 + 1 * * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1 + 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1 + * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 + * * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 + * * * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 * * * 1 1 1 1 1 + l=1/r=0(tl) l=1/r=3(tl) l=1/r=5(tl) l=1/r=7(tl) + l=4/r=0(br) l=4/r=2(br) l=4/r=4(br) + + x=4/y=-1 x=6/y=-1 x=8/y=-1 + * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1 + * * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 + * * * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1 + * * * * * 1 1 * * * * * * 1 1 1 * * * * * 1 1 1 + * * * * * * 1 1 * * * * * * 1 1 * * * * * * 1 1 + + x=-2/y=5 x=1/y=5(top-left) x=0/y=5(botm-r) + * * * * * * * * 1 * * * * * * * + * * * * * * * * 1 1 * * 1 * * * + * * * * * * * * 1 1 1 * 1 1 * * + 1 * * * * * * * 1 1 1 1 1 1 1 * + 1 1 * * * * * * 1 1 1 1 1 1 1 1 + + Validations: + x + y > 1 (x + y >= 2) + + Note: + y = seq_q, x = 1 -> top-left + y = seq_q, x = seq_k - seq_q + 1 -> bottom-right + y < seq_q, x < seq_k -> local-attn + y = seq_q, x = seq_k -> no mask + +*/ +namespace impl { + template struct MaskName; + template<> struct MaskName { static constexpr const char * name = "mn"; }; + template<> struct MaskName { static constexpr const char * name = "mn"; }; + template<> struct MaskName { static constexpr const char * name = "mc"; }; + template<> struct MaskName { static constexpr const char * name = "mg"; }; +} +// clang-format on + +template +struct GenericAttentionMask +{ + static constexpr bool IsMasking = IsMasking_; // false will disable masking + static constexpr bool IsLocal = IsLocal_; // if true, upper/lower area could have mask, + // else only upper-right could have mask + + static constexpr const char* name = impl::MaskName::name; + + CK_TILE_HOST_DEVICE GenericAttentionMask(index_t y_total_, index_t x_total_) + : GenericAttentionMask(0, 0, y_total_, x_total_) + { + } + + CK_TILE_HOST_DEVICE + GenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_) + : y(y_), x(x_), y_total(y_total_), x_total(x_total_) + { + } + template + CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord) + : y(mask_coord.at(number<0>{})), + x(mask_coord.at(number<1>{})), + y_total(mask_coord.at(number<2>{})), + x_total(mask_coord.at(number<3>{})) + { + } + + // to get the loop length along X axis, return index:[start, end), end-start=length + // use this if need loop over X axis tile by tile (like k-seqlen loopover) + // TODO: x_end still could be negative, so end-start could be negative(need check) + template + CK_TILE_HOST_DEVICE constexpr auto + GetTileRangeAlongX(index_t i_y, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, x_total); + } + else + { + // get the tile start/end range assum we loop over along X tile by tile + index_t x_start = [&]() { + if constexpr(IsLocal) + { + index_t tmp = math::max(-y + i_y + 1, 0); + return (tmp / XTile) * XTile; // round to tile aligned + } + else + { + return 0; + } + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t x_end = [&]() { + index_t tmp = math::min(i_y + YTile - 1 + x, x_total); + return ((tmp + XTile - 1) / XTile) * XTile; + }(); + + return ck_tile::make_tuple(x_start, x_end); + } + } + + // per-pixel check if out-of-bound, if true, need mask a value(like -INF) + CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const + { + if constexpr(!IsMasking) + { + return i_x >= x_total; + } + else + { + // no need to do min/max here, since i_x will never be < 0 or >= x_total + index_t x_start = -y + i_y + 1; + index_t x_end = math::min(i_y + x, x_total); + + if constexpr(IsLocal) + { + return i_x < x_start || i_x >= x_end; + } + else + { + return i_x >= x_end; + } + } + } + + // if current tile is at the edge, means need per-pixel mask check. + // otherwise no need to check per-pixel + // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX() + // can be used as a fast-path to decide if do per-pixel check or not + template + CK_TILE_HOST_DEVICE constexpr auto + IsEdgeTile(index_t i_tile_top, index_t i_tile_left, number, number) const + { + if constexpr(IsLocal) + { + // check top-right corner > x or left-borrom corner < x + index_t i_tile_right = i_tile_left + TileWidth; + index_t i_tile_bottom = i_tile_top + TileHeight; + index_t x_end = math::min(i_tile_top + x, x_total); + + bool top_right_edge = i_tile_right > (i_tile_top + x); + bool bottom_left_edge = i_tile_bottom > (i_tile_left + y); + bool is_partial_out_of_bound = i_tile_right > x_end; // only consider right-pad for now + + return top_right_edge || bottom_left_edge || is_partial_out_of_bound; + } + else + { + // only need to check top-right corner > x + index_t i_tile_right = i_tile_left + TileWidth; + index_t x_end = math::min(i_tile_top + x, x_total); + + bool top_right_edge = i_tile_right > x_end; + return top_right_edge; + } + } + + private: + index_t y, x; + index_t y_total, x_total; +}; + +// TODO: prefer use this function in host code +// can convert from the FA style left/right to our generic coordinate +// if left_size < 0 && right_size = 0, it is normal causal mask +// local is left_size >=0 or right_size >=0 +CK_TILE_HOST_DEVICE constexpr auto +make_generic_attention_mask_coordinates_from_lr_window(index_t left_size, + index_t right_size, + index_t y_total, + index_t x_total, + bool is_top_left = true) +{ + index_t x = 0, y = 0; + + if(is_top_left) + { + if(left_size < 0) + left_size = y_total - 1; + if(right_size < 0) + right_size = x_total - 1; + + x = 1 + right_size; + y = left_size + 1; + } + else + { + if(left_size < 0) + left_size = x_total - 1; + if(right_size < 0) + right_size = y_total - 1; + + x = x_total - y_total + 1 + right_size; + y = y_total - x_total + 1 + left_size; + } + + return ck_tile::make_tuple(y, x, y_total, x_total); +} +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp new file mode 100644 index 0000000000..9685bd2da5 --- /dev/null +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -0,0 +1,698 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include +#include + +// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] * K[seqlen_k, hdim_q] +// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] +// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] +// P[seqlen_q, seqlen_k] = Softmax(S[seqlen_q, seqlen_k]) +// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] * V[hdim_v, seqlen_k] + +namespace ck_tile { + +template +struct FmhaFwdKernel +{ + using TilePartitioner = ck_tile::remove_cvref_t; + using FmhaPipeline = ck_tile::remove_cvref_t; + using EpiloguePipeline = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); + static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; + + using QDataType = ck_tile::remove_cvref_t; + using KDataType = ck_tile::remove_cvref_t; + using VDataType = ck_tile::remove_cvref_t; + using BiasDataType = ck_tile::remove_cvref_t; + using LSEDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; + static constexpr bool kIsFp8 = FmhaPipeline::kIsFp8; + + using VLayout = ck_tile::remove_cvref_t; + + static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; + static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr bool kHasBias = FmhaPipeline::kHasBias; + static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr bool kHasMask = FmhaMask::IsMasking; + + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp32"; }; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + template <> struct t2s { static constexpr const char * name = "fp8"; }; + template <> struct t2s { static constexpr const char * name = "bf8"; }; + // clang-format on + + __host__ static std::string GetName() + { + // sync with generate.py + // clang-format off + using bfs = typename FmhaPipeline::BlockFmhaShape; + using gbr = typename bfs::Gemm0BlockWarps; + using gwt = typename bfs::Gemm0WarpTile; + #define _SS_ std::string + #define _TS_ std::to_string + auto pn = [&] () { + std::string n; + if (kPadSeqLenQ) n += "s"; + if (kPadSeqLenK) n += "sk"; + if (kPadHeadDimQ) n += "d"; + if (kPadHeadDimV) n += "dv"; + return n.empty() ? n : std::string("p") + n; }(); + return + _SS_("fmha_fwd_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s::name) + + "_" + (kIsGroupMode ? "group" : "batch") + "_" + + "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" + + "r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" + + "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + + "v" + (ck_tile::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + + (kHasBias ? "_bias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ); + #undef _SS_ + #undef _TS_ + // clang-format on + } + + template // to avoid duplicated base class prblem, introduce an template + // arg + struct FmhaFwdEmptyKargs + { + }; + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. + struct FmhaFwdCommonKargs + { + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + void* o_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + + // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k + // if this param is larger than 1, indicate MQA/GQA case + ck_tile::index_t nhead_ratio_qk; + float scale; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_o; + + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_o; + }; + + struct FmhaFwdCommonBiasKargs + { + const void* bias_ptr = nullptr; + ck_tile::index_t stride_bias = 0; + ck_tile::index_t nhead_stride_bias = 0; + }; + + struct FmhaFwdBatchModeBiasKargs : FmhaFwdCommonBiasKargs + { + ck_tile::index_t batch_stride_bias = 0; + }; + + struct FmhaFwdMaskKargs + { + ck_tile::index_t mask_y, mask_x; + }; + + struct FmhaFwdFP8Kargs + { + float descale_qk; // q*k + float descale_sv; // s*v + // float * o_amax_ptr; + }; + + struct FmhaFwdCommonLSEKargs + { + void* lse_ptr = nullptr; + ck_tile::index_t nhead_stride_lse = 0; + }; + + struct FmhaFwdBatchModeLSEKargs : FmhaFwdCommonLSEKargs + { + ck_tile::index_t batch_stride_lse = 0; + }; + + struct FmhaFwdBatchModeKargs + : FmhaFwdCommonKargs, + std::conditional_t>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> + { + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_o; + }; + + struct FmhaFwdGroupModeKargs + : FmhaFwdCommonKargs, + std::conditional_t>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> + { + const int32_t* seqstart_q_ptr; + const int32_t* seqstart_k_ptr; + const int32_t* seqlen_k_ptr; + }; + + using Kargs = std::conditional_t; + + template + __host__ static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* lse_ptr, + void* o_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t nhead_ratio_qk, + float scale, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_lse, + ck_tile::index_t batch_stride_o, + ck_tile::index_t mask_y, + ck_tile::index_t mask_x, + float descale_qk, + float descale_sv) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + o_ptr, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + nhead_ratio_qk, +#if CK_FMHA_FWD_FAST_EXP2 + static_cast(scale * ck_tile::log2e_v<>), +#else + scale, +#endif + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for mask + {}, // placeholder for lse + {}, // placeholder for fp8 args + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_o}; + + if constexpr(kHasBias) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + kargs.batch_stride_bias = batch_stride_bias; + } + if constexpr(kHasMask) + { + kargs.mask_y = mask_y; + kargs.mask_x = mask_x; + } + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.nhead_stride_lse = nhead_stride_lse; + kargs.batch_stride_lse = batch_stride_lse; + } + if constexpr(kIsFp8) + { + kargs.descale_qk = descale_qk; + kargs.descale_sv = descale_sv; + } + + return kargs; + } + + template + __host__ static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* lse_ptr, + void* o_ptr, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t nhead_ratio_qk, + float scale, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t mask_y, + ck_tile::index_t mask_x, + float descale_qk, + float descale_sv) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + o_ptr, + -1, // seqlen will be updated by another pointer + -1, // + hdim_q, + hdim_v, + nhead_ratio_qk, +#if CK_FMHA_FWD_FAST_EXP2 + static_cast(scale * ck_tile::log2e_v<>), +#else + scale, +#endif + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for mask + {}, // placeholder for lse + {}, // placeholder for fp8 args + reinterpret_cast(seqstart_q_ptr), + reinterpret_cast(seqstart_k_ptr), + reinterpret_cast(seqlen_k_ptr)}; + + if constexpr(kHasBias) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + } + if constexpr(kHasMask) + { + kargs.mask_y = mask_y; + kargs.mask_x = mask_x; + } + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.nhead_stride_lse = nhead_stride_lse; + } + if constexpr(kIsFp8) + { + kargs.descale_qk = descale_qk; + kargs.descale_sv = descale_sv; + } + + return kargs; + } + + __host__ static constexpr auto GridSize(ck_tile::index_t batch_size_, + ck_tile::index_t nhead_, + ck_tile::index_t seqlen_q_, + ck_tile::index_t hdim_v_) + { + return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_); + } + + __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + // divide problem + const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = + TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v); + + const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); + const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); + + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; + + if constexpr(kIsGroupMode) + { + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + if constexpr(ck_tile::is_same_v) + { + batch_offset_v = key_start * kargs.stride_v; + } + else + { + batch_offset_v = key_start; + } + if constexpr(kHasBias) + { + batch_offset_bias = query_start * kargs.stride_bias + key_start; + } + else + { + batch_offset_bias = key_start; + } + if constexpr(kStoreLSE) + { + batch_offset_lse = query_start; + } + batch_offset_o = query_start * kargs.stride_o; + + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_q <= i_m0) + { + return; + } + + if(kargs.seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; + } + else + { + const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; + kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + } + } + else + { + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + if constexpr(kHasBias) + { + batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; + } + if constexpr(kStoreLSE) + { + batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; + } + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + } + + // for simplicity, batch stride we just modify the pointer + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; + const KDataType* k_ptr = + reinterpret_cast(kargs.k_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + + batch_offset_k; + const VDataType* v_ptr = + reinterpret_cast(kargs.v_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + + batch_offset_v; + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; + + // Q/K/V DRAM and DRAM window + const auto q_dram = [&]() { + const auto q_dram_naive = make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + if constexpr(FmhaPipeline::kQLoadOnce) + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + const auto k_dram = [&]() { + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + const auto v_dram = [&]() { + if constexpr(ck_tile::is_same_v) + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_v), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + const auto v_dram_transposed = + transform_tensor_view(v_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(kargs.seqlen_k)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return pad_tensor_view( + v_dram_transposed, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.hdim_v, kargs.seqlen_k), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + auto q_dram_window = make_tile_window( + q_dram, + [&]() { + if constexpr(FmhaPipeline::kQLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {i_m0, 0}); + + auto k_dram_window = make_tile_window( + k_dram, make_tuple(number{}, number{}), {0, 0}); + + auto v_dram_window = + make_tile_window(v_dram, + make_tuple(number{}, number{}), + {i_n1, 0}); + /// FIXME: Before C++20, capturing structured binding variables is not supported. Remove + /// following copy capture of the 'i_nhead' + /// if compiled in C++20 + const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto bias_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(kHasBias) + { + const BiasDataType* bias_ptr = + reinterpret_cast(kargs.bias_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_bias + + batch_offset_bias; + + const auto bias_dram = [&]() { + const auto bias_dram_naive = make_naive_tensor_view( + bias_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_bias, 1), + number{}, + number<1>{}); + + return pad_tensor_view(bias_dram_naive, + bias_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(bias_dram_window_lengths); + } + }(); + + // lse + auto lse_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto lse_dram_window_lengths = make_tuple(number{}); + if constexpr(kStoreLSE) + { + LSEDataType* lse_ptr = + reinterpret_cast(kargs.lse_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse; + + const auto lse_dram = [&]() { + const auto lse_dram_naive = make_naive_tensor_view( + lse_ptr, + make_tuple(kargs.seqlen_q), + make_tuple(1), + number<1>{}, + number<1>{}); + + return pad_tensor_view( + lse_dram_naive, lse_dram_window_lengths, sequence{}); + }(); + + return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); + } + else + { + return make_null_tile_window(lse_dram_window_lengths); + } + }(); + + FmhaMask mask = [&]() { + if constexpr(kHasMask) + return FmhaMask{kargs.mask_y, kargs.mask_x, kargs.seqlen_q, kargs.seqlen_k}; + else + return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; + }(); + + auto o_acc_tile = [&]() { + if constexpr(kIsFp8) + { + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + lse_dram_window, + mask, + kargs.scale, + kargs.descale_qk, + kargs.descale_sv, + smem_ptr); + } + else + { + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + lse_dram_window, + mask, + kargs.scale, + smem_ptr); + } + }(); + + // O DRAM and O DRAM window + auto o_dram = [&]() { + const auto o_dram_naive = make_naive_tensor_view( + o_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_o, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + o_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto o_dram_window = + make_tile_window(o_dram, + make_tuple(number{}, number{}), + {i_m0, i_n1}); + + EpiloguePipeline{}(o_dram_window, o_acc_tile); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp new file mode 100644 index 0000000000..52f458c72e --- /dev/null +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct FmhaFwdTilePartitioner +{ + using BlockFmhaShape = ck_tile::remove_cvref_t; + + static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0; + static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0; + static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0; + static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1; + static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1; + + __host__ static constexpr auto GridSize(ck_tile::index_t batch_size_, + ck_tile::index_t nhead_, + ck_tile::index_t seqlen_q_, + ck_tile::index_t hdim_v_) + { + // TODO: this may need tuning + return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) * + ck_tile::integer_divide_ceil(hdim_v_, kN1), + nhead_, + batch_size_); + } + + CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v) + { + // const index_t num_tile_m0 = seqlen_q / kM0; + const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1); + + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp new file mode 100644 index 0000000000..5a1f1b0520 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct BlockFmhaPipelineProblem +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using BlockFmhaShape = remove_cvref_t; + using FmhaMask = remove_cvref_t; + using Traits = remove_cvref_t; + + static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); + static constexpr bool kIsGroupMode = kIsGroupMode_; + + // attributes from traits + static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr bool kHasBias = Traits::kHasBias; + static constexpr bool kStoreLSE = Traits::kStoreLSE; + static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; + static constexpr bool kIsFp8 = + (is_same_v || is_same_v)&&( + is_same_v || + is_same_v)&&(is_same_v || + is_same_v)&&is_same_v && + is_same_v; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp new file mode 100644 index 0000000000..8b8ac4a35b --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -0,0 +1,581 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// This pipeline is qkv all located in LDS +template +struct BlockFmhaPipelineQRKSVS +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once + static_assert(kQLoadOnce == Policy::QLoadOnce); + static constexpr bool kIsFp8 = Problem::kIsFp8; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kHasBias = Problem::kHasBias; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = []() { + if constexpr(ck_tile::is_same_v) + return kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + else + return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); + }(); + + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + if constexpr(kK0BlockLength <= 32) + { + return 2; + } + else if constexpr(kK0BlockLength <= 64) + { + return 3; + } + else if constexpr(kK0BlockLength <= 128) + { + if constexpr(kHasBias) + return 1; + else + return 2; + } + else if constexpr(kK0BlockLength <= 256) + { + return 1; + } + } + }(); + + static constexpr const char* name = "qr"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const KElementFunction& k_element_func, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile + const LSEElementFunction& lse_element_func, + FmhaMask mask, + float scale, + void* smem_ptr) const + { + static_assert( + is_same_v> && + is_same_v> && + is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // K tile in LDS + KDataType* k_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQ())); + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + auto q_dram_window = make_tile_window( + q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQDramTileDistribution()); + + auto q = load_tile(q_dram_window); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = decltype(cast_tile(s_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + // init Oacc, M, L + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + set_tile(m, -NumericLimits::Infinity()); + clear_tile(l); + + const auto q_origin = q_dram_window.get_window_origin(); + const auto [seqlen_k_start, seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + + const auto num_total_loop = math::integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking) + { + if(num_total_loop <= 0) + { + if constexpr(kStoreLSE) + { + auto lse = + make_static_distributed_tensor(m.get_tile_distribution()); + + set_tile(lse, -NumericLimits::Infinity()); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // Note: here occ are all cleard, return it + // Note: q loaded but no fence, ignore it. + return o_acc; + } + } + + auto k_dram_block_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}); + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_window = make_tile_window( + bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N + Policy::template MakeBiasDramTileDistribution()); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_k_start}, // TODO: hdim split? + Policy::template MakeVDramTileDistribution()); + + auto q_tile = tile_elementwise_in(q_element_func, q); + + // prefetch K tile + index_t i_total_loops = 0; + constexpr index_t k0_loops = kK0BlockLength / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + static_assert(2 <= k0_loops); + static_assert(1 <= k1_loops); + do + { + // STAGE 1, QK gemm + auto k_dram_window = make_tile_window( + k_dram_block_window.get_bottom_tensor_view(), + k_dram_block_window.get_window_lengths(), + k_dram_block_window.get_window_origin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + + auto k_block_tile = load_tile(k_dram_window); + { + move_tile_window(k_dram_window, {0, kK0}); + clear_tile(s_acc); // initialize C + store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); + k_block_tile = load_tile(k_dram_window); + } + + if constexpr(kHasBias) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + if constexpr(kHasBias) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + + if constexpr(k0_loops > 2) + { + static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { + block_sync_lds(); + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, i_k0 * kK0>{}, + sequence{}), + k_lds_window); + block_sync_lds(); + move_tile_window(k_dram_window, {0, kK0}); + + store_tile( + k_lds_window, + tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1 + k_block_tile = load_tile(k_dram_window); // global read i + 2 + }); + } + + const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile + { // tail + block_sync_lds(); + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 2) * kK0>{}, + sequence{}), + k_lds_window); + block_sync_lds(); + + store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); + block_sync_lds(); + + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{}), + k_lds_window); + } + + // STAGE 2, scale, add bias, mask, softmax + if constexpr(kHasBias) + { + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_FMHA_FWD_FAST_EXP2 + x = scale * x + type_convert(bias_element_func(y)); +#else + x = scale * x + math::log2e_v * + type_convert(bias_element_func(y)); +#endif + }, + s_acc, + bias_tile); + } + else + { +#if !CK_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); +#endif + } + move_tile_window(bias_dram_window, {0, kN0}); + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + if(need_perpixel_check) + { + set_tile_if( + s_acc, -NumericLimits::Infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + const auto s = cast_tile(s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + sequence<1>{}, + f_max, + -NumericLimits::Infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s.get_tile_distribution()); // Pcompute{j} + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration + if constexpr(kHasBias || FmhaMask::IsMasking) + { + return raw_m == -NumericLimits::Infinity() + ? type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + auto row_max = scale * get_validated_m(m[i_idx]); +#endif + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_FMHA_FWD_FAST_EXP2 + if constexpr(kHasBias) + { + p_compute(i_j_idx) = math::exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = math::exp2(scale * s[i_j_idx] - row_max); + } +#else + p_compute(i_j_idx) = math::exp(s[i_j_idx] - get_validated_m(m[i_idx])); +#endif + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + const auto tmp = [&]() { + if constexpr(kHasBias) + { + return math::exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale * get_validated_m(m[i_idx]); + return math::exp2(scale * m_old[i_idx] - row_max); + } + }(); +#else + const auto tmp = math::exp(m_old[i_idx] - get_validated_m(m[i_idx])); +#endif + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // FIXME: this use different equation from FA v2 paper, + // but produce correc result. + // Is the equation wrong? + o_acc(i_j_idx) *= tmp; + }); + }); + + block_sync_lds(); + if constexpr(ck_tile::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_prefetch); + store_tile( + v_lds_window, + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch + } + move_tile_window(v_dram_window, {0, kK1}); + + const auto p = cast_tile(p_compute); + + // STAGE 3, KV gemm + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + const auto v = load_tile(v_dram_window); // load next v + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile( + p, sequence<0, i_k1 * kK1>{}, sequence{}), + v_lds_window); + block_sync_lds(); + if constexpr(ck_tile::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v); + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, + v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, v)); // store next v + } + move_tile_window(v_dram_window, {0, kK1}); + }); + } + // move K tile windows + move_tile_window(k_dram_block_window, {kN0, 0}); + // tail + { + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), + v_lds_window); + block_sync_lds(); + } + } while(++i_total_loops < num_total_loop); + + // store lse + if constexpr(kStoreLSE) + { + auto lse = make_static_distributed_tensor(m.get_tile_distribution()); + + constexpr auto lse_spans = decltype(lse)::get_distributed_spans(); + sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + if constexpr(kHasBias) + { + lse(i_idx) = m_[i_idx] / C_LOG2E + math::log(l_[i_idx]); + } + else + { + lse(i_idx) = m_[i_idx] * scale / C_LOG2E + math::log(l_[i_idx]); + } +#else + lse(i_idx) = m_[i_idx] + math::log(l_[i_idx]); +#endif + }); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + return o_acc; + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile + FmhaMask mask, + float scale, + void* smem_ptr) const + { + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + bias_dram_block_window_tmp, + identity{}, + lse_dram_block_window_tmp, + identity{}, + mask, + scale, + smem_ptr); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp new file mode 100644 index 0000000000..5aed576784 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -0,0 +1,676 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" + +namespace ck_tile { + +// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future) +template +struct BlockFmhaPipelineQRKSVSAsync +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once + static_assert(kQLoadOnce == Policy::QLoadOnce); + static constexpr bool kIsFp8 = Problem::kIsFp8; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x) + // only need special care about seq_k padding (oob need set -INF of p instead of zero) + static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true && + Problem::kPadHeadDimV == true); + static constexpr bool kPadSeqLenQ = true; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x) + static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) + static constexpr bool kHasBias = Problem::kHasBias; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = []() { + if constexpr(ck_tile::is_same_v) + return Policy::template GetAlignmentV(); + else + return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); + }(); + static constexpr index_t kAlignmentO = Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + +#if CK_FMHA_FWD_FAST_EXP2 + static constexpr auto R_LOG2E = 1.0 / math::log2e_v; +#endif + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + if constexpr(kK0BlockLength <= 32) + { + if constexpr(kPadSeqLenK && kHasBias && FmhaMask::IsMasking) + return 1; + else + return 2; + } + else if constexpr(kK0BlockLength <= 64) + { + if constexpr(kPadSeqLenK && kHasBias) + return 2; + else + return 3; + } + else if constexpr(kK0BlockLength <= 128) + { + if constexpr(kPadSeqLenK && kHasBias) + return 1; + else + return 2; + } + else if constexpr(kK0BlockLength <= 256) + { + return 1; + } + } + }(); + + static constexpr const char* name = "qr_async"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const KElementFunction& /*k_element_func*/, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile + const LSEElementFunction& lse_element_func, + FmhaMask mask, + float scale, + void* smem_ptr) const + { + static_assert( + is_same_v> && + is_same_v> && + is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + constexpr auto LdsSeq = Policy::template GetLdsBufferSequence(); + + // K tile in LDS + auto k_lds_ptr = reinterpret_cast(smem_ptr); + auto k_lds_store = generate_tuple( + [&](auto i_buf) { + return make_tile_window( + make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor(i_buf)), + Policy::template MakeKLdsStoreBlockDescriptor(i_buf).get_lengths(), + {0, 0, 0}); + }, + number{}); + +#if K_LDS_LOAD_USE_OFFSET_TRANSFORM + auto k_lds_load = generate_tuple( + [&](auto i_buf) { + return make_tile_window( + make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor(i_buf)), + Policy::template MakeKLdsLoadBlockDescriptor(i_buf).get_lengths(), + {0, 0}); + }, + number{}); +#else + auto k_lds_Load_view = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor()); + + auto k_lds_load = + make_tile_window(k_lds_Load_view, + Policy::template MakeKLdsLoadBlockDescriptor().get_lengths(), + {0, 0}); +#endif + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + auto q_dram_window = make_tile_window( + q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQDramTileDistribution()); + + // TODO: we use async Copy for K, which is inline asm + // a side effect is we have to use inline asm for q as well + auto q = decltype(load_tile(q_dram_window)){}; + set_tile(q, number<0>{}); // use per-dword clear to avoid scratch + load_tile_raw(q, q_dram_window); + __builtin_amdgcn_sched_barrier(0); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = decltype(cast_tile(s_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + // init Oacc, M, L + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + set_tile(m, -NumericLimits::Infinity()); + clear_tile(l); + + __builtin_amdgcn_sched_barrier(0); + const auto q_origin = q_dram_window.get_window_origin(); + const auto [seqlen_k_start, seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + + const auto num_total_loop = math::integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking) + { + if(num_total_loop <= 0) + { + if constexpr(kStoreLSE) + { + auto lse = + make_static_distributed_tensor(m.get_tile_distribution()); + + set_tile(lse, -NumericLimits::Infinity()); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // Note: here occ are all cleard, return it + // Note: q loaded but no fence, ignore it. + return o_acc; + } + } + + auto k_dram_block_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}); + + auto k_dram_window = make_tile_window( + k_dram_block_window.get_bottom_tensor_view(), + k_dram_block_window.get_window_lengths(), + k_dram_block_window.get_window_origin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_window = make_tile_window( + bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N + Policy::template MakeBiasDramTileDistribution()); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_k_start}, // TODO: hdim split? + Policy::template MakeVDramTileDistribution()); + + // prefetch K tile + async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window); + move_tile_window(k_dram_window, {0, kK0}); + __builtin_amdgcn_sched_barrier(0); + + buffer_load_fence(k_dram_window.get_num_access(), q.get_thread_buffer()); + (void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32 + // auto q_tile = q; // tile_elementwise_in(q_element_func, q); + + index_t i_total_loops = 0; + constexpr index_t k0_loops = kK0BlockLength / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + static_assert(1 <= k0_loops); + static_assert(1 <= k1_loops); + // main loop + do + { + // STAGE 1, QK gemm + clear_tile(s_acc); // initialize C + if constexpr(k0_loops > 1) + { + static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { + async_load_tile_raw(k_lds_store(number{})>{}), + k_dram_window); + if constexpr(i_k0 < k0_loops - 1) + move_tile_window(k_dram_window, {0, kK0}); + + async_load_fence(k_dram_window.get_num_access()); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + gemm_0(s_acc, + get_slice_tile( + q, sequence<0, i_k0 * kK0>{}, sequence{}), +#if K_LDS_LOAD_USE_OFFSET_TRANSFORM + k_lds_load[number{})>{}]); + +#else + get_slice_tile(k_lds_load, + sequence<(LdsSeq.at(number{})) * kN0, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); +#endif + }); + } + + // TODO: this to fix a bug when loop smaller than 2, + // the following fence/barrier will be scheduled inside 1st loop + if constexpr(k0_loops <= 2) + __builtin_amdgcn_sched_barrier(0); + + async_load_fence(); + __builtin_amdgcn_s_barrier(); + + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + auto v_buf = load_tile(v_dram_window, bool_constant{}); + __builtin_amdgcn_sched_barrier(0); + { // tail + gemm_0(s_acc, + get_slice_tile( + q, sequence<0, (k0_loops - 1) * kK0>{}, sequence{}), +#if K_LDS_LOAD_USE_OFFSET_TRANSFORM + k_lds_load[number{})>{}]); + +#else + get_slice_tile( + k_lds_load, + sequence<(LdsSeq.at(number{})) * kN0, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); +#endif + } + __builtin_amdgcn_sched_barrier(1); + + // STAGE 2, scale, add bias, mask, softmax + if constexpr(kHasBias) + { + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_FMHA_FWD_FAST_EXP2 + x = scale * x + type_convert(bias_element_func(y)); +#else + x = scale * x + math::log2e_v * + type_convert(bias_element_func(y)); +#endif + }, + s_acc, + bias_tile); + } + else + { +#if !CK_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); +#endif + } + move_tile_window(bias_dram_window, {0, kN0}); + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + if(need_perpixel_check) + { + set_tile_if( + s_acc, -NumericLimits::Infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + const auto s = cast_tile(s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + sequence<1>{}, + f_max, + -NumericLimits::Infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s.get_tile_distribution()); // Pcompute{j} + + __builtin_amdgcn_sched_barrier(0x7F); + // store & prefetch next v, after the max reduction + if constexpr(ck_tile::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_buf); + + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + + store_tile( + v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + } + else + { + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_buf)); // store the prefetch + } + + if constexpr(k1_loops > 1) + { + move_tile_window( + v_dram_window, + {0, kK1}); // will have scratch if move this right after load_tile(v_dram)... + v_buf = load_tile(v_dram_window, bool_constant{}); // load next v_buf + } + __builtin_amdgcn_sched_barrier(0); + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration + if constexpr(kHasBias || FmhaMask::IsMasking) + { + return raw_m == -NumericLimits::Infinity() + ? type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + auto row_max = scale * get_validated_m(m[i_idx]); +#endif + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_FMHA_FWD_FAST_EXP2 + if constexpr(kHasBias) + { + p_compute(i_j_idx) = math::exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = math::exp2(scale * s[i_j_idx] - row_max); + } +#else + p_compute(i_j_idx) = math::exp(s[i_j_idx] - get_validated_m(m[i_idx])); +#endif + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + const auto tmp = [&]() { + if constexpr(kHasBias) + { + return math::exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale * get_validated_m(m[i_idx]); + return math::exp2(scale * m_old[i_idx] - row_max); + } + }(); +#else + const auto tmp = math::exp(m_old[i_idx] - get_validated_m(m[i_idx])); +#endif + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // FIXME: this use different equation from FA v2 paper, + // but produce correc result. + // Is the equation wrong? + o_acc(i_j_idx) *= tmp; + }); + }); + + const auto p = cast_tile(p_compute); + + // STAGE 3, KV gemm + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1) + { + v_buf = load_tile(v_dram_window, bool_constant{}); // load next v_buf + } + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile( + p, sequence<0, i_k1 * kK1>{}, sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); + + if constexpr(ck_tile::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_buf); + auto v_lds_window_tmp = get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, + v_shuffle_tmp)); // store the prefetch + } + else + { + auto v_lds_window_tmp = get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_buf)); // store next v_buf + } + if constexpr(i_k1 < k1_loops - 1) + move_tile_window(v_dram_window, {0, kK1}); + }); + } + i_total_loops++; + if(i_total_loops < num_total_loop) + { + // move K tile windows + move_tile_window(k_dram_block_window, {kN0, 0}); + k_dram_window = + make_tile_window(k_dram_block_window.get_bottom_tensor_view(), + k_dram_block_window.get_window_lengths(), + k_dram_block_window.get_window_origin(), + Policy::template MakeKDramTileDistribution()); + + if constexpr(k1_loops >= 2 && + LdsSeq.at(number<0>{}) == LdsSeq.at(number{})) + __builtin_amdgcn_s_barrier(); + async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window); + move_tile_window(k_dram_window, {0, kK0}); + } + // tail + { + block_sync_lds(); + gemm_1( + o_acc, + get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); + } + } while(i_total_loops < num_total_loop); + + // store lse + if constexpr(kStoreLSE) + { + auto lse = make_static_distributed_tensor(m.get_tile_distribution()); + + constexpr auto lse_spans = decltype(lse)::get_distributed_spans(); + sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + if constexpr(kHasBias) + { + lse(i_idx) = m_[i_idx] * R_LOG2E + math::log(l_[i_idx]); + } + else + { + lse(i_idx) = m_[i_idx] * scale * R_LOG2E + math::log(l_[i_idx]); + } +#else + lse(i_idx) = m_[i_idx] + math::log(l_[i_idx]); +#endif + }); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + return o_acc; + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile + FmhaMask mask, + float scale, + void* smem_ptr) const + { + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + bias_dram_block_window_tmp, + identity{}, + lse_dram_block_window_tmp, + identity{}, + mask, + scale, + smem_ptr); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp new file mode 100644 index 0000000000..dda6e1dc2f --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// This pipeline is qkv all located in LDS +using BlockFmhaPipelineQRKSVSAsyncDefaultPolicy = + BlockFmhaPipelineQXKSVSCustomPolicy; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp new file mode 100644 index 0000000000..c3ae38dae6 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// This pipeline is qkv all located in LDS +using BlockFmhaPipelineQRKSVSDefaultPolicy = + BlockFmhaPipelineQXKSVSCustomPolicy; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp new file mode 100644 index 0000000000..20eb15adc0 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp @@ -0,0 +1,507 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// This pipeline is qkv all located in LDS +template +struct BlockFmhaPipelineQRKSVSFp8 +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once + static_assert(kQLoadOnce == Policy::QLoadOnce); + static constexpr bool kIsFp8 = Problem::kIsFp8; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kHasBias = Problem::kHasBias; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = []() { + if constexpr(ck_tile::is_same_v) + return kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + else + return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); + }(); + + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + if constexpr(kK0BlockLength <= 32) + { + return 2; + } + else if constexpr(kK0BlockLength <= 64) + { + return 3; + } + else if constexpr(kK0BlockLength <= 128) + { + if constexpr(kHasBias) + return 1; + else + return 2; + } + else if constexpr(kK0BlockLength <= 256) + { + return 1; + } + } + }(); + + static constexpr const char* name = "qr_fp8"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + LSEDramBlockWindowTmp& /*lse_dram_window_tmp*/, // not supported + FmhaMask mask, + float scale, + float descale_qk, + float descale_sv, + void* smem_ptr) const + { + static_assert( + is_same_v> && + is_same_v> && + is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // K tile in LDS + KDataType* k_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQ())); + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + auto q_dram_window = make_tile_window( + q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQDramTileDistribution()); + + auto q = load_tile(q_dram_window); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = decltype(cast_tile(s_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + // init Oacc, M, L + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + set_tile(m, -NumericLimits::Infinity()); + clear_tile(l); + + const auto q_origin = q_dram_window.get_window_origin(); + const auto [seqlen_k_start, seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + + const auto num_total_loop = math::integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking) + { + if(num_total_loop <= 0) + { + // Note: here occ are all cleard, return it + // Note: q loaded but no fence, ignore it. + return o_acc; + } + } + + auto k_dram_block_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}); + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_window = make_tile_window( + bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N + Policy::template MakeBiasDramTileDistribution()); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_k_start}, // TODO: hdim split? + Policy::template MakeVDramTileDistribution()); + + // auto q_tile = tile_elementwise_in(q_element_func, q); + auto q_tile = q; + + // prefetch K tile + index_t i_total_loops = 0; + constexpr index_t k0_loops = kK0BlockLength / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + static_assert(2 <= k0_loops); + static_assert(1 <= k1_loops); + + scale = scale * descale_qk; + do + { + // STAGE 1, QK gemm + auto k_dram_window = make_tile_window( + k_dram_block_window.get_bottom_tensor_view(), + k_dram_block_window.get_window_lengths(), + k_dram_block_window.get_window_origin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + + auto k_block_tile = load_tile(k_dram_window); + { + move_tile_window(k_dram_window, {0, kK0}); + clear_tile(s_acc); // initialize C + store_tile(k_lds_window, k_block_tile); + k_block_tile = load_tile(k_dram_window); + } + + if constexpr(kHasBias) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + if constexpr(kHasBias) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + + if constexpr(k0_loops > 2) + { + static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { + block_sync_lds(); + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, i_k0 * kK0>{}, + sequence{}), + k_lds_window); + block_sync_lds(); + move_tile_window(k_dram_window, {0, kK0}); + + store_tile(k_lds_window, + k_block_tile); // LDS write i + 1 + k_block_tile = load_tile(k_dram_window); // global read i + 2 + }); + } + + const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile + { // tail + block_sync_lds(); + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 2) * kK0>{}, + sequence{}), + k_lds_window); + block_sync_lds(); + + store_tile(k_lds_window, k_block_tile); + block_sync_lds(); + + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{}), + k_lds_window); + } + + // STAGE 2, scale, add bias, mask, softmax + if constexpr(kHasBias) + { + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_FMHA_FWD_FAST_EXP2 + x = scale * x + type_convert((y)); +#else + x = scale * x + + math::log2e_v * type_convert((y)); +#endif + }, + s_acc, + bias_tile); + } + else + { +#if !CK_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); +#endif + } + move_tile_window(bias_dram_window, {0, kN0}); + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + if(need_perpixel_check) + { + set_tile_if( + s_acc, -NumericLimits::Infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + const auto s = cast_tile(s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + sequence<1>{}, + f_max, + -NumericLimits::Infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s.get_tile_distribution()); // Pcompute{j} + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration + if constexpr(kHasBias || FmhaMask::IsMasking) + { + return raw_m == -NumericLimits::Infinity() + ? type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + auto row_max = scale * get_validated_m(m[i_idx]); +#endif + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_FMHA_FWD_FAST_EXP2 + if constexpr(kHasBias) + { + p_compute(i_j_idx) = math::exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = math::exp2(scale * s[i_j_idx] - row_max); + } +#else + p_compute(i_j_idx) = math::exp(s[i_j_idx] - get_validated_m(m[i_idx])); +#endif + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + const auto tmp = [&]() { + if constexpr(kHasBias) + { + return math::exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale * get_validated_m(m[i_idx]); + return math::exp2(scale * m_old[i_idx] - row_max); + } + }(); +#else + const auto tmp = math::exp(m_old[i_idx] - get_validated_m(m[i_idx])); +#endif + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // FIXME: this use different equation from FA v2 paper, + // but produce correc result. + // Is the equation wrong? + o_acc(i_j_idx) *= tmp; + }); + }); + + block_sync_lds(); + if constexpr(ck_tile::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_prefetch); + store_tile(v_lds_window, + v_shuffle_tmp); // store the prefetch + } + else + { + store_tile(v_lds_window, + v_prefetch); // store the prefetch + } + move_tile_window(v_dram_window, {0, kK1}); + + const auto p = cast_tile(p_compute); + + // STAGE 3, KV gemm + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + const auto v = load_tile(v_dram_window); // load next v + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile( + p, sequence<0, i_k1 * kK1>{}, sequence{}), + v_lds_window); + block_sync_lds(); + if constexpr(ck_tile::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v); + store_tile(v_lds_window, v_shuffle_tmp); + } + else + { + store_tile(v_lds_window, v); + } + move_tile_window(v_dram_window, {0, kK1}); + }); + } + // move K tile windows + move_tile_window(k_dram_block_window, {kN0, 0}); + // tail + { + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), + v_lds_window); + block_sync_lds(); + } + } while(++i_total_loops < num_total_loop); + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + tmp = tmp * descale_sv; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + return o_acc; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp new file mode 100644 index 0000000000..9a9a24e266 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -0,0 +1,573 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// This pipeline is qkv all located in LDS +template +struct BlockFmhaPipelineQSKSVS +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = false; + static_assert(kQLoadOnce == Policy::QLoadOnce); + static constexpr bool kIsFp8 = Problem::kIsFp8; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kHasBias = Problem::kHasBias; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + if constexpr(kK0BlockLength <= 32) + { + return 2; + } + else if constexpr(kK0BlockLength <= 64) + { + return 3; + } + else if constexpr(kK0BlockLength <= 128) + { + if constexpr(kHasBias) + return 1; + else + return 2; + } + else if constexpr(kK0BlockLength <= 256) + { + return 1; + } + } + }(); + + static constexpr const char* name = "qs"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ() + { + return Policy::template GetSmemSizeQ(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const KElementFunction& k_element_func, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile + const LSEElementFunction& lse_element_func, + FmhaMask mask, + float scale, + void* smem_ptr) const + { + static_assert( + is_same_v> && + is_same_v> && + is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // Q tile in LDS + auto q_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeQLdsBlockDescriptor()); + auto q_lds_window = + make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); + + // K tile in LDS + KDataType* k_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQ())); + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = decltype(cast_tile(s_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + // init Oacc, M, L + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + set_tile(m, -NumericLimits::Infinity()); + clear_tile(l); + + const auto q_origin = q_dram_block_window_tmp.get_window_origin(); + const auto [seqlen_k_start, seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + + const auto num_total_loop = math::integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking) + { + if(num_total_loop <= 0) + { + if constexpr(kStoreLSE) + { + auto lse = + make_static_distributed_tensor(m.get_tile_distribution()); + + set_tile(lse, -NumericLimits::Infinity()); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // Note: here occ are all cleard, return it + // Note: q loaded but no fence, ignore it. + return o_acc; + } + } + + auto k_dram_block_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}); + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_window = make_tile_window( + bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N + Policy::template MakeBiasDramTileDistribution()); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_k_start}, // TODO: hdim split? + Policy::template MakeVDramTileDistribution()); + + // prefetch K tile + index_t i_total_loops = 0; + constexpr index_t k0_loops = kK0BlockLength / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + static_assert(2 <= k0_loops); + static_assert(1 <= k1_loops); + do + { + // STAGE 1, QK gemm + auto q_dram_window = + make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQDramTileDistribution()); + + auto k_dram_window = + make_tile_window(k_dram_block_window.get_bottom_tensor_view(), + k_dram_block_window.get_window_lengths(), + k_dram_block_window.get_window_origin(), + Policy::template MakeKDramTileDistribution()); + + auto q_block_tile = load_tile(q_dram_window); + auto k_block_tile = load_tile(k_dram_window); + { + move_tile_window(q_dram_window, {0, kK0}); + move_tile_window(k_dram_window, {0, kK0}); + + clear_tile(s_acc); // initialize C + + store_tile(q_lds_window, tile_elementwise_in(q_element_func, q_block_tile)); + q_block_tile = load_tile(q_dram_window); + + store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); + k_block_tile = load_tile(k_dram_window); + } + + if constexpr(kHasBias) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + if constexpr(kHasBias) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + + if constexpr(k0_loops > 2) + { + static_for<0, k0_loops - 2, 1>{}([&](auto) { + block_sync_lds(); + gemm_0(s_acc, q_lds_window, k_lds_window); + block_sync_lds(); + + move_tile_window(q_dram_window, {0, kK0}); + move_tile_window(k_dram_window, {0, kK0}); + + store_tile( + q_lds_window, + tile_elementwise_in(q_element_func, q_block_tile)); // LDS write i + 1 + q_block_tile = load_tile(q_dram_window); // global read i + 2 + + store_tile( + k_lds_window, + tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1 + k_block_tile = load_tile(k_dram_window); // global read i + 2 + }); + } + + const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile + { // tail + block_sync_lds(); + gemm_0(s_acc, q_lds_window, k_lds_window); + block_sync_lds(); + + store_tile(q_lds_window, tile_elementwise_in(q_element_func, q_block_tile)); + store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); + block_sync_lds(); + + gemm_0(s_acc, q_lds_window, k_lds_window); + } + + // STAGE 2, scale, add bias, mask, softmax + if constexpr(kHasBias) + { + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_FMHA_FWD_FAST_EXP2 + x = scale * x + type_convert(bias_element_func(y)); +#else + x = scale * x + math::log2e_v * + type_convert(bias_element_func(y)); +#endif + }, + s_acc, + bias_tile); + } + else + { +#if !CK_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); +#endif + } + move_tile_window(bias_dram_window, {0, kN0}); + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + if(need_perpixel_check) + { + set_tile_if( + s_acc, -NumericLimits::Infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + const auto s = cast_tile(s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + sequence<1>{}, + f_max, + -NumericLimits::Infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s.get_tile_distribution()); // Pcompute{j} + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration + if constexpr(kHasBias || FmhaMask::IsMasking) + { + return raw_m == -NumericLimits::Infinity() + ? type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + auto row_max = scale * get_validated_m(m[i_idx]); +#endif + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_FMHA_FWD_FAST_EXP2 + if constexpr(kHasBias) + { + p_compute(i_j_idx) = math::exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = math::exp2(scale * s[i_j_idx] - row_max); + } +#else + p_compute(i_j_idx) = math::exp(s[i_j_idx] - get_validated_m(m[i_idx])); +#endif + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + const auto tmp = [&]() { + if constexpr(kHasBias) + { + return math::exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale * get_validated_m(m[i_idx]); + return math::exp2(scale * m_old[i_idx] - row_max); + } + }(); +#else + const auto tmp = math::exp(m_old[i_idx] - get_validated_m(m[i_idx])); +#endif + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // FIXME: this use different equation from FA v2 paper, + // but produce correc result. + // Is the equation wrong? + o_acc(i_j_idx) *= tmp; + }); + }); + + block_sync_lds(); + if constexpr(ck_tile::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_prefetch); + store_tile( + v_lds_window, + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch + } + move_tile_window(v_dram_window, {0, kK1}); + + const auto p = cast_tile(p_compute); + + // STAGE 3, KV gemm + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + const auto v = load_tile(v_dram_window); // load next v + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile( + p, sequence<0, i_k1 * kK1>{}, sequence{}), + v_lds_window); + block_sync_lds(); + if constexpr(ck_tile::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v); + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, + v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, v)); // store next v + } + move_tile_window(v_dram_window, {0, kK1}); + }); + } + // move K tile windows + move_tile_window(k_dram_block_window, {kN0, 0}); + // tail + { + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), + v_lds_window); + block_sync_lds(); + } + } while(++i_total_loops < num_total_loop); + + // store lse + if constexpr(kStoreLSE) + { + auto lse = make_static_distributed_tensor(m.get_tile_distribution()); + + constexpr auto lse_spans = decltype(lse)::get_distributed_spans(); + sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + if constexpr(kHasBias) + { + lse(i_idx) = m_[i_idx] / C_LOG2E + math::log(l_[i_idx]); + } + else + { + lse(i_idx) = m_[i_idx] * scale / C_LOG2E + math::log(l_[i_idx]); + } +#else + lse(i_idx) = m_[i_idx] + math::log(l_[i_idx]); +#endif + }); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + return o_acc; + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile + FmhaMask mask, + float scale, + void* smem_ptr) const + { + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + bias_dram_block_window_tmp, + identity{}, + lse_dram_block_window_tmp, + identity{}, + mask, + scale, + smem_ptr); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp new file mode 100644 index 0000000000..e7bbe1fac1 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// This pipeline is qkv all located in LDS +using BlockFmhaPipelineQSKSVSDefaultPolicy = + BlockFmhaPipelineQXKSVSCustomPolicy; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp new file mode 100644 index 0000000000..fd88fee078 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -0,0 +1,953 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" + +// TODO: remove this +#define K_LDS_LOAD_USE_OFFSET_TRANSFORM 0 + +namespace ck_tile { + +template +struct BlockFmhaPipelineQXCustomPolicy; + +template <> +struct BlockFmhaPipelineQXCustomPolicy +{ + static constexpr bool QLoadOnce = true; + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ() + { + return 0; + } + + // TODO: GetAlignment*() currently didn't consider if need padding or not + // so in pipeline still need check padding requirement + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + return WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution() + { + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength; + + constexpr index_t K2 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t K0 = kKPerBlock / (K1 * K2); + + constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t M1 = MWarp; + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 2>>, + sequence<1, 2, 2>, + sequence<0, 0, 2>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + constexpr auto warp_gemm = []() { + if constexpr(is_same_v && + is_same_v && + is_same_v) + { + return warp::WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution{}; + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + return warp::WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{}; + } + else if constexpr(Problem::kIsFp8) + { + constexpr index_t swizzle_factor = 4; // TODO: hard coded here + return warp::WarpGemmImpl< + warp::WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB< + warp::WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base< + typename Problem::QDataType, + typename Problem::KDataType>, + 2, + swizzle_factor>>{}; + } + }(); + + using BlockGemmPolicy = + BlockGemmARegBSmemCRegV2CustomPolicy; + + return BlockGemmARegBSmemCRegV2{}; + } +}; + +template <> +struct BlockFmhaPipelineQXCustomPolicy +{ + static constexpr bool QLoadOnce = false; + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ() + { + constexpr index_t lds_alignment = 16; // optional + constexpr index_t q_smem_size = + ck_tile::integer_divide_ceil( + sizeof(typename Problem::QDataType) * + MakeQLdsBlockDescriptor().get_element_space_size(), + lds_alignment) * + lds_alignment; + return q_smem_size; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() + { + using QDataType = remove_cvref_t; + return 16 / sizeof(QDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution() + { + using QDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + + constexpr index_t K1 = 16 / sizeof(QDataType); // use dwordx4. TODO: change this + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + // 3d + padding + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor() + { + using QDataType = remove_cvref_t; + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPack = 16 / sizeof(QDataType); + + constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number{}, number<1>{}), + number<8>{}, + number<1>{}); + + constexpr auto q_lds_block_desc = transform_tensor_descriptor( + q_lds_block_desc_0, + make_tuple(make_pass_through_transform(kMPerBlock), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return q_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + constexpr auto warp_gemm = []() { + if constexpr(is_same_v && + is_same_v && + is_same_v) + { + return warp::WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution{}; + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + return warp::WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{}; + } + else if constexpr(Problem::kIsFp8) + { + constexpr index_t swizzle_factor = 4; // TODO: hard coded here + return warp::WarpGemmImpl< + warp::WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB< + warp::WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base< + typename Problem::QDataType, + typename Problem::KDataType>, + 2, + swizzle_factor>>{}; + } + }(); + + using BlockGemmPolicy = + BlockGemmASmemBSmemCRegV1CustomPolicy; + + return BlockGemmASmemBSmemCRegV1{}; + } +}; + +// This pipeline is qkv all located in LDS +template +struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy +{ + static constexpr bool AsyncCopyK = AsyncCopyK_; + static constexpr bool AsyncCopyV = AsyncCopyV_; // TODO: this not supported yet + + static constexpr index_t NumPrefetchK = NumPrefetchK_; + static constexpr index_t NumPrefetchV = NumPrefetchK_; + + using QXPolicy = BlockFmhaPipelineQXCustomPolicy; + + template + struct LdsBufferSequence + { + static constexpr auto Make() + { + return transform_sequences( + [&](auto i) { + if(i < k_loops_) + return i % k_prefetches_; + return (i - k_loops_) % v_prefetches_; + }, + typename arithmetic_sequence_gen<0, k_loops_ + v_loops_, 1>::type{}); + }; + + using type = remove_cvref_t; + }; + // clang-format off + template<> struct + LdsBufferSequence<3, 3, 4, 4> { using type = sequence<1, 2, 0, 1, 0, 1, 2, 0>; }; + + template<> struct + LdsBufferSequence<3, 3, 4, 2> { using type = sequence<1, 2, 0, 1, 2, 0>; }; + + template<> struct + LdsBufferSequence<3, 3, 2, 4> { using type = sequence<1, 2, 0, 1, 2, 0>; }; + + template<> struct + LdsBufferSequence<3, 3, 3, 3> { using type = sequence<1, 2, 0, 1, 2, 0>; }; + + template<> struct + LdsBufferSequence<3, 3, 2, 2> { using type = sequence<1, 2, 1, 0>;}; + // clang-format on + + template + CK_TILE_HOST_DEVICE static constexpr auto GetLdsBufferSequence() + { + using BlockFmhaShape = remove_cvref_t; + + constexpr index_t kN0 = BlockFmhaShape::kN0; + constexpr index_t kK0 = BlockFmhaShape::kK0; + constexpr index_t kK1 = BlockFmhaShape::kK1; + constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; + + constexpr index_t k0_loops = kK0BlockLength / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + return typename LdsBufferSequence::type{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK() + { + // TODO: this is for 3d layout + using KDataType = remove_cvref_t; + return 16 / sizeof(KDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK() + { + using KDataType = remove_cvref_t; + if constexpr(AsyncCopyK) + { + return 4 / sizeof(KDataType); + } + else + { + return 16 / sizeof(KDataType); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV() + { + // TODO: this is for 3d layout + using VDataType = remove_cvref_t; + return 16 / sizeof(VDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() + { + using VLayout = remove_cvref_t; + using VDataType = remove_cvref_t; + if constexpr(ck_tile::is_same_v) + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + + // TODO: not correct! + if constexpr(total_pixels > 4) + return 4; + else + return 2; + } + else + { + return 16 / sizeof(VDataType); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBias() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + using CWarpDstr = typename WG::CWarpDstr; + constexpr auto vec = + CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number{}); + return vec; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + using CWarpDstr = typename WG::CWarpDstr; + constexpr auto vec = + CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number{}); + return vec; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSingleSmemElementSpaceSize() + { + // this function assume K/V can share smem + constexpr index_t SingleKSize = [&]() { + if constexpr(!AsyncCopyK) + { + return MakeKLdsBlockDescriptor().get_element_space_size(); + } + else + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t warpSize = ck_tile::get_warp_size(); + + constexpr index_t KPack = GetSmemKPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentK(); // this is for global load + constexpr index_t kPad = KPack; + + static_assert(warpSize * KVector >= kKPerBlock && + warpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; + constexpr index_t LaneGroups = warpSize / LanesPerK; + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + + return NumIssues * NumWarps * (warpSize * KVector + kPad); + } + }(); + + constexpr index_t SingleVSize = [&]() { + using VDataType = remove_cvref_t; + constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); + constexpr index_t kKPack = GetSmemKPackK(); + static_assert(PixelsPerRow % kKPack == 0); + constexpr index_t NPerRow = PixelsPerRow / kKPack; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + static_assert(kNPerBlock % NPerRow == 0); + static_assert(kKPerBlock % kKPack == 0); + + return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack); + }(); + + return math::max(SingleKSize, SingleVSize); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQRegBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength; + + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; + + constexpr auto q_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + q_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode); + + return q_block_dstr; + } + + // TODO: this is used for non async copy desc. unify in the future + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kKPack = GetSmemKPackK(); + + constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(kNPerBlock + 1) * kKPack>{}, number{}, number<1>{}), + number<8>{}, + number<1>{}); + + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple(make_pass_through_transform(kNPerBlock), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return k_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto + MakeKLdsStoreBlockDescriptor(number = number<0>{}) + { + // K is always k-major, we use async-copy to load into LDS + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t warpSize = ck_tile::get_warp_size(); + + constexpr index_t KPack = GetSmemKPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentK(); // this is for global load + constexpr index_t kPad = + KPack; // for async-copy, this pad is between warps. Optimize this for lds_read speed + + static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = + kKPerBlock / KVector; // how many lane (within a wave) to load K + constexpr index_t LaneGroups = + warpSize / + LanesPerK; // how many groups (within a wave), they may load different N, but same K + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset( + make_tuple(number{}, // n0 + number{}, // n1 + number{}, // n2 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number()>{}, + number{}, + number<1>{}); + + // TODO this layout is hard coded, and will be used in async copy buffer view load + // in LDS the real layout is (bufs, N0, N2, N1*K0*K1) + constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple(make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_merge_transform(make_tuple( + number{}, number{}, number{}))), + make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + return k_lds_block_desc_issues_warps_lanes; + } + +#if K_LDS_LOAD_USE_OFFSET_TRANSFORM + template + CK_TILE_HOST_DEVICE static constexpr auto + MakeKLdsLoadBlockDescriptor(number = number<0>{}) + { + // K is always k-major, we use async-copy to load into LDS + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t warpSize = ck_tile::get_warp_size(); + + constexpr index_t KPack = GetSmemKPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentK(); // this is for global load + constexpr index_t kPad = KPack; // for async-copy, this pad is between warps + + static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset( + make_tuple(number{}, // n0 + number{}, // n2 + number{}, // n1 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number()>{}, + number{}, + number<1>{}); + + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple( + make_merge_transform( + make_tuple(number{}, number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<0, 2, 1>{}, sequence<3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return k_lds_block_desc; + } +#else + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsLoadBlockDescriptor() + { + // K is always k-major, we use async-copy to load into LDS + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t warpSize = ck_tile::get_warp_size(); + + constexpr index_t KPack = GetSmemKPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentK(); // this is for global load + constexpr index_t kPad = KPack; // for async-copy, this pad is between warps + + static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + // constexpr index_t SingleKSize = NumIssues * NumWarps * (warpSize * KVector + kPad); + // constexpr index_t SingleVSize = + // MakeVLdsBlockDescriptor().get_element_space_size(); + constexpr index_t BufferSize = + GetSingleSmemElementSpaceSize(); // math::max(SingleKSize, SingleVSize); + + constexpr auto k_lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(number{}, // num_buffers + number{}, // n0 + number{}, // n2 + number{}, // n1 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple( + make_merge_transform(make_tuple(number{}, + number{}, + number{}, + number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<0, 1, 3, 2>{}, sequence<4, 5>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return k_lds_block_desc; + } +#endif + + // 3d + padding + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor() + { + using VDataType = remove_cvref_t; + constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); + constexpr index_t kKPack = GetSmemKPackV(); + static_assert(PixelsPerRow % kKPack == 0); + constexpr index_t NPerRow = PixelsPerRow / kKPack; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + static_assert(kNPerBlock % NPerRow == 0); + static_assert(kKPerBlock % kKPack == 0); + + constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}), + make_tuple(number()>{}, + number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto v_lds_block_desc = transform_tensor_descriptor( + v_lds_block_desc_0, + make_tuple( + make_merge_transform(make_tuple( + number{}, number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return v_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + // TODO: assume Q is in register + // TODO: assume K/V has same data type + constexpr index_t single_smem_size = + GetSingleSmemElementSpaceSize() * sizeof(typename Problem::KDataType); + + return QXPolicy::template GetSmemSizeQ() + + single_smem_size * math::max(NumPrefetchK, NumPrefetchV); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() + { + if constexpr(!AsyncCopyK) + { + using KDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + + constexpr index_t K1 = 16 / sizeof(KDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + else + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t warpSize = ck_tile::get_warp_size(); + + constexpr index_t KVector = GetAlignmentK(); // this is for global load + + static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr index_t N0 = NumIssues; + constexpr index_t N1 = LaneGroups; + constexpr index_t N2 = NumWarps; + constexpr index_t K0 = LanesPerK; + constexpr index_t K1 = KVector; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + } + + template + CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution() + { + using VLayout = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + if constexpr(ck_tile::is_same_v) + { + constexpr index_t N1 = GetAlignmentV(); + constexpr index_t N0 = kNPerBlock / N1; // P + + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? + constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackV(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + if constexpr(get_warp_size() % (K2 * N0) == 0) + { + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + static_assert(kKPerBlock == K0 * K1 * K2 * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + else + { + constexpr index_t K1 = (K2 * N0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = kBlockSize / get_warp_size() / K1; + static_assert(kKPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + } + else + { + constexpr index_t K1 = GetAlignmentV(); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + static_assert(N0 != 0); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBiasDramTileDistribution() + { + constexpr index_t MPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t NPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + + // Construct C-Block-HostTensor + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + + return c_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledVRegBlockDescriptor() + { + // This descriptor only used when V layout is seqlen * hdim + using VLayout = remove_cvref_t; + static_assert(ck_tile::is_same_v); + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + constexpr index_t N1 = GetAlignmentV(); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? + constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackV(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + if constexpr(get_warp_size() % (K2 * N0) == 0) + { + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + else + { + constexpr index_t K1 = (K2 * N0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = kBlockSize / get_warp_size() / K1; + static_assert(kKPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + auto warp_gemm = [&]() { + if constexpr(Problem::kIsFp8) + { + return warp::WarpGemmImpl< + warp::WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution< + warp::WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base< + typename Problem::PDataType, + typename Problem::VDataType>, + 2>>{}; + // return + // warp::WarpGemmImpl>>{}; + } + else + { + return WarpGemmMfmaDispatcher< + typename Problem::PDataType, + typename Problem::VDataType, + typename Problem::OaccDataType, + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), + true>{}; + } + }(); + + using WarpGemm = remove_cvref_t; + + using BlockGemmPolicy = + BlockGemmARegBSmemCRegV2CustomPolicy; + return BlockGemmARegBSmemCRegV2{}; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp new file mode 100644 index 0000000000..c70af47dd0 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct TileFmhaShape +{ + using BlockTile = remove_cvref_t; + using Gemm0BlockWarps = remove_cvref_t; + using Gemm0WarpTile = remove_cvref_t; + using Gemm1BlockWarps = remove_cvref_t; + using Gemm1WarpTile = remove_cvref_t; + + static constexpr index_t NumWarps = + reduce_on_sequence(Gemm0BlockWarps{}, math::multiplies{}, number<1>{}); + + static_assert(NumWarps == + reduce_on_sequence(Gemm1BlockWarps{}, math::multiplies{}, number<1>{})); + + static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen + static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen + static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll + static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim + static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll + static constexpr index_t kK0BlockLength = + BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at + // once (or repeately load Q as a whole tile) + static_assert(kK0BlockLength % kK0 == 0, "kK0BlockLength should be divisible by kK0"); + + // v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen + static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_; + using VLayout = std::conditional_t; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp new file mode 100644 index 0000000000..1b14d0cae0 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct TileFmhaTraits +{ + static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; + static constexpr bool kPadSeqLenK = kPadSeqLenK_; + static constexpr bool kPadHeadDimQ = kPadHeadDimQ_; + static constexpr bool kPadHeadDimV = kPadHeadDimV_; + static constexpr bool kHasBias = kHasBias_; + static constexpr bool kStoreLSE = kStoreLSE_; + static constexpr index_t kBlockPerCu = kBlockPerCu_; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp new file mode 100644 index 0000000000..857305964f --- /dev/null +++ b/include/ck_tile/ops/gemm.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_problem.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp" +#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp" +#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp" diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp new file mode 100644 index 0000000000..1053c751ad --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { +// Problem Description for BlockGemmARegBGmemCReg +template +struct BlockGemmARegBGmemCRegProblem +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = kBlockSize_; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp new file mode 100644 index 0000000000..5d900d6f7c --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp @@ -0,0 +1,133 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// A is block distributed tensor +// B is block window on global memory +// C is block distributed tensor +// This will: +// 1. load B from global memory into shared memory and then +// 2. Call BlockGemmARegSGmemCRegV1 +template +struct BlockGemmARegBGmemCRegV1 +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + // use BlockGemmARegBSmemCRegV1 as the underlying block-GEMM implementation + using BlockGemmARegBSmemCRegImpl = BlockGemmARegBSmemCRegV1< + BlockGemmARegBSmemCRegProblem, + BlockGemmARegBSmemCRegV1DefaultPolicy>; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() + { + return sizeof(BDataType) * + Policy::template MakeBSmemBlockDescriptor().get_element_space_size(); + } + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockTensor& a_block_tensor, + const BBlockGmemWindowTmp& b_block_gmem_window_tmp, + void* smem_ptr) const + { + static_assert( + is_same_v> && + is_same_v> && + is_same_v>, + "wrong!"); + + constexpr index_t MPerBlock = ABlockTensor{}.get_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockGmemWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockTensor{}.get_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + const auto b_block_gmem_window = + make_tile_window(b_block_gmem_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_block_gmem_window_tmp.get_window_origin(), + Policy::template MakeBGmemTileDistribution()); + + // B LDS and LDS window + auto b_block_smem = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeBSmemBlockDescriptor()); + + auto b_block_smem_window = make_tile_window( + b_block_smem, make_tuple(number{}, number{}), {0, 0}); + + // load B tile from global mem + const auto b_block_tile = load_tile(b_block_gmem_window); + + // store B tile into shared mem + store_tile(b_block_smem_window, b_block_tile); + + // wait for store_tile to finish + block_sync_lds(); + + // block GEMM + BlockGemmARegBSmemCRegImpl{}(c_block_tensor, a_block_tensor, b_block_smem_window); + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()(const ABlockTensor& a_block_tensor, + const BBlockGmemWindowTmp& b_block_gmem_window_tmp, + void* smem_ptr) const + { + static_assert(is_same_v> && + is_same_v>, + "wrong!"); + + constexpr index_t MPerBlock = ABlockTensor{}.get_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockGmemWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockTensor{}.get_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + const auto b_block_gmem_window = + make_tile_window(b_block_gmem_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_block_gmem_window_tmp.get_window_origin(), + Policy::template MakeBGmemTileDistribution()); + + // B LDS and LDS window + auto b_block_smem = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeBSmemBlockDescriptor()); + + auto b_block_smem_window = make_tile_window( + b_block_smem, make_tuple(number{}, number{}), {0, 0}); + + // load B tile from global mem + const auto b_block_tile = load_tile(b_block_gmem_window); + + // store B tile into shared mem + store_tile(b_block_smem_window, b_block_tile); + + // wait for store_tile to finish + block_sync_lds(); + + // block GEMM + return BlockGemmARegBSmemCRegImpl{}(a_block_tensor, b_block_smem_window); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp new file mode 100644 index 0000000000..4156398bd3 --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// Default policy for BlockGemmARegBGmemCRegV1 +// Default policy class should not be templated, put template on member functions instead +struct BlockGemmARegBGmemCRegV1DefaultPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBGmemTileDistribution() + { + using BDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = 16 / sizeof(BDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + +#if 0 + // 2d + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr auto b_lds_block_desc = + make_naive_tensor_descriptor_packed(make_tuple(kNPerBlock, kKPerBlock), number<32>{}); + + return b_lds_block_desc; + } +#elif 0 + // 3d + padding + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBSmemBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number<8>{}), + make_tuple(number<(kNPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), + number<8>{}, + number<1>{}); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_pass_through_transform(kNPerBlock), + make_merge_transform(make_tuple(kKPerBlock / 8, 8))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return b_lds_block_desc; + } +#elif 1 + // fake XOR + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBSmemBlockDescriptor() + { + using BDataType = remove_cvref_t; + + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr auto b_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed( + make_tuple(number{}, number<2>{}, number{}), + number{}); + + constexpr index_t kK1 = 16 / sizeof(BDataType); + + constexpr auto b_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor( + b_lds_block_desc_d1_d2_d3, + make_tuple( + make_xor_transform(make_tuple(number{}, number{}), kK1), + make_pass_through_transform(2)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{})); + + constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor( + b_lds_block_desc_d4_d5_d6, + make_tuple(make_merge_transform(make_tuple(number{}, number<2>{})), + make_pass_through_transform(kKPerBlock)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return b_lds_block_desc_n_k; + } +#endif +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_problem.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_problem.hpp new file mode 100644 index 0000000000..7a0390a8a2 --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_problem.hpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// Problem Description for BlockGemmARegBSmemCReg +template +struct BlockGemmARegBSmemCRegProblem +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = kBlockSize_; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp new file mode 100644 index 0000000000..9e1f529426 --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp @@ -0,0 +1,337 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// A is block distributed tensor +// B is block window on shared memory +// C is block distributed tensor +template +struct BlockGemmARegBSmemCRegV1 +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert(is_same_v> && + is_same_v> && + is_same_v>, + "wrong!"); + + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iNWarp = get_warp_id() % NWarp; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); + + // constrcut from A-block-tensor from A-Block-tensor-tmp + // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent + // distribution + auto a_block_tensor = + make_static_distributed_tensor(a_block_dstr); + + a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); + + // construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0}, + make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); + +#if 0 // FIXME: using array will cause register spill + array, NIterPerWarp> b_warp_windows{ + {b_warp_window_tmp}}; + + for(index_t nIter = 0; nIter < NIterPerWarp; nIter++) + { + for(index_t kIter = 0; kIter < KIterPerWarp; kIter++) + { + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + } + } +#else + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); +#endif + + // check C-block-distribution + static_assert( + is_same_v, + remove_cvref_t>, + "wrong!"); + + using AWarpDstr = typename WG::AWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using AWarpTensor = typename WG::AWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert(is_same_v> && + is_same_v>, + "wrong!"); + + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iNWarp = get_warp_id() % NWarp; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + + // constrcut from A-block-tensor from A-Block-tensor-tmp + // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent + // distribution + auto a_block_tensor = + make_static_distributed_tensor(a_block_dstr); + + a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); + + // construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0}, + make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); + +#if 0 // FIXME: using array will cause register spill + array, NIterPerWarp> b_warp_windows{ + {b_warp_window_tmp}}; + + for(index_t nIter = 0; nIter < NIterPerWarp; nIter++) + { + for(index_t kIter = 0; kIter < KIterPerWarp; kIter++) + { + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + } + } +#else + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); +#endif + + // Construct C-Block-HostTensor + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + + using AWarpDstr = typename WG::AWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using AWarpTensor = typename WG::AWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp new file mode 100644 index 0000000000..779113d96a --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct BlockGemmARegBSmemCRegV1CustomPolicy +{ + using AType = remove_cvref_t; + using BType = remove_cvref_t; + using CType = remove_cvref_t; + + using BlockWarps = remove_cvref_t; + + static constexpr index_t kMWarps = BlockWarps::at(number<0>{}); + static constexpr index_t kNWarps = BlockWarps::at(number<1>{}); + static constexpr index_t kKWarps = BlockWarps::at(number<2>{}); + + using WarpGemm = remove_cvref_t; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + return make_tuple(WarpGemm{}, kMWarps, kNWarps); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp new file mode 100644 index 0000000000..d19d167cfc --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// Default policy for BlockGemmARegBSmemCRegV1 +// Default policy class should not be templated, put template on member functions instead +struct BlockGemmARegBSmemCRegV1DefaultPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + if constexpr(is_same_v && + is_same_v && + is_same_v) + { +#if 0 + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + static_assert(kBlockSize % get_warp_size() == 0, "wrong!"); + + constexpr index_t NumWarp = kBlockSize / get_warp_size(); + + // FIXME + if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 && + kNPerBlock % 128 == 0 % kKPerBlock % 16 == 0) + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1); + } + else + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1); + } +#else + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1); +#endif + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 4, 1); + } + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp new file mode 100644 index 0000000000..c03379cc83 --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp @@ -0,0 +1,225 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// A is block distributed tensor +// B is block window on shared memory +// C is block distributed tensor +template +struct BlockGemmARegBSmemCRegV2 +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert(is_same_v> && + is_same_v> && + is_same_v>, + "wrong!"); + + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iNWarp = get_warp_id() % NWarp; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); + + // constrcut from A-block-tensor from A-Block-tensor-tmp + // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent + // distribution + auto a_block_tensor = + make_static_distributed_tensor(a_block_dstr); + + a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); + + // construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0}, + make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); + +#if 0 // FIXME: using array will cause register spill + array, NIterPerWarp> b_warp_windows{ + {b_warp_window_tmp}}; + + for(index_t nIter = 0; nIter < NIterPerWarp; nIter++) + { + for(index_t kIter = 0; kIter < KIterPerWarp; kIter++) + { + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + } + } +#else + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); +#endif + + // check C-block-distribution + static_assert( + is_same_v, + remove_cvref_t>, + "wrong!"); + + using AWarpDstr = typename WG::AWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using AWarpTensor = typename WG::AWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + CK_TILE_DEVICE constexpr auto MakeCBlockTile() const + { + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + // constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp); + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp new file mode 100644 index 0000000000..8bcd04b7b0 --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct BlockGemmARegBSmemCRegV2CustomPolicy +{ + using AType = remove_cvref_t; + using BType = remove_cvref_t; + using CType = remove_cvref_t; + + using BlockWarps = remove_cvref_t; + + static constexpr index_t kMWarps = BlockWarps::at(number<0>{}); + static constexpr index_t kNWarps = BlockWarps::at(number<1>{}); + static constexpr index_t kKWarps = BlockWarps::at(number<2>{}); + + using WarpGemm = remove_cvref_t; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + return make_tuple(WarpGemm{}, kMWarps, kNWarps); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp new file mode 100644 index 0000000000..3c091c0b73 --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// Default policy for BlockGemmARegBSmemCRegV2 +// Default policy class should not be templated, put template on member functions instead +struct BlockGemmARegBSmemCRegV2DefaultPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + +#if 0 + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + static_assert(kBlockSize % get_warp_size() == 0, "wrong!"); + + constexpr index_t NumWarp = kBlockSize / get_warp_size(); + + // FIXME + if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 && + kNPerBlock % 128 == 0 % kKPerBlock % 16 == 0) + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1); + } + else + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1); + } +#else + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1); +#endif + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp new file mode 100644 index 0000000000..ed772891a4 --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// Problem Description for BlockGemmASmemBSmemCRegV1 +template +struct BlockGemmASmemBSmemCRegProblem +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = kBlockSize_; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp new file mode 100644 index 0000000000..de93c7dad4 --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp @@ -0,0 +1,212 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// A is block window on shared memory +// B is block window on shared memory +// C is block distributed tensor +template +struct BlockGemmASmemBSmemCRegV1 +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockWindowTmp& a_block_window_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert(is_same_v && + is_same_v && + is_same_v, + "wrong!"); + + constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iMWarp = get_warp_id() / NWarp; + const index_t iNWarp = get_warp_id() % NWarp; + + // construct A-warp-window + auto a_warp_window_tmp = make_tile_window( + a_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_block_window_tmp.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + +#if 0 // FIXME: using array will cause register spill + array, MIterPerWarp> a_warp_windows{ + {a_warp_window_tmp}}; + + for(index_t mIter = 0; mIter < MIterPerWarp; mIter++) + { + for(index_t kIter = 0; kIter < KIterPerWarp; kIter++) + { + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + } + } +#else + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows; + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); +#endif + + // construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0}, + make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); + +#if 0 // FIXME: using array will cause register spill + array, NIterPerWarp> b_warp_windows{ + {b_warp_window_tmp}}; + + for(index_t nIter = 0; nIter < NIterPerWarp; nIter++) + { + for(index_t kIter = 0; kIter < KIterPerWarp; kIter++) + { + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + } + } +#else + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); +#endif + + using CWarpDstr = typename WG::CWarpDstr; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block window + const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + CK_TILE_DEVICE constexpr auto MakeCBlockTile() const + { + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp); + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp new file mode 100644 index 0000000000..319711088f --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// Default policy for BlockGemmASmemBSmemCRegV1 +// Default policy class should not be templated, put template on member functions instead +template +struct BlockGemmASmemBSmemCRegV1CustomPolicy +{ + using AType = remove_cvref_t; + using BType = remove_cvref_t; + using CType = remove_cvref_t; + + using BlockWarps = remove_cvref_t; + + static constexpr index_t kMWarps = BlockWarps::at(number<0>{}); + static constexpr index_t kNWarps = BlockWarps::at(number<1>{}); + static constexpr index_t kKWarps = BlockWarps::at(number<2>{}); + + using WarpGemm = remove_cvref_t; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + return make_tuple(WarpGemm{}, kMWarps, kNWarps); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp new file mode 100644 index 0000000000..60a25549ed --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// Default policy for BlockGemmASmemBSmemCRegV1 +// Default policy class should not be templated, put template on member functions instead +struct BlockGemmASmemBSmemCRegV1DefaultPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + if constexpr(is_same_v && + is_same_v && + is_same_v) + { +#if 0 + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + static_assert(kBlockSize % get_warp_size() == 0, "wrong!"); + + constexpr index_t NumWarp = kBlockSize / get_warp_size(); + + if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 && + kNPerBlock % 128 == 0 % kKPerBlock % 16 == 0) + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2); + } + else + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2); + } +#else + return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); +#endif + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1); + } + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp new file mode 100644 index 0000000000..0e3ef2d794 --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -0,0 +1,201 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// A Tile Window: global memory +// B Tile Window: global memory +// C Distributed tensor: register +template +struct BlockGemmPipelineAGmemBGmemCRegV1 +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kMPerBlock = BlockGemmShape::kM; + static constexpr index_t kNPerBlock = BlockGemmShape::kN; + static constexpr index_t kKPerBlock = BlockGemmShape::kK; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() + { + return ck_tile::integer_divide_ceil( + sizeof(ADataType) * + Policy::template MakeALdsBlockDescriptor().get_element_space_size(), + 16) * + 16 + + sizeof(BDataType) * + Policy::template MakeBLdsBlockDescriptor().get_element_space_size(); + } + + template + CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const + { + static_assert( + is_same_v> && + is_same_v>, + "wrong!"); + + static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // A tile in LDS + ADataType* p_a_lds = static_cast(p_smem); + + constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); + + auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); + + constexpr index_t a_lds_block_space_size_aligned = + math::integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), + 16) * + 16; + + // B tile in LDS + BDataType* p_b_lds = static_cast( + static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); + + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + Policy::template MakeADramTileDistribution()); + + // A LDS tile window for store + auto a_copy_lds_window = + make_tile_window(a_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + a_copy_dram_window.get_tile_distribution()); + + // B DRAM tile window for load + auto b_copy_dram_window = + make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp.get_window_origin(), + Policy::template MakeBDramTileDistribution()); + + // B LDS tile window for store + auto b_copy_lds_window = + make_tile_window(b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + b_copy_dram_window.get_tile_distribution()); + + // A LDS tile for block GEMM + auto a_lds_gemm_window = make_tile_window( + a_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // B LDS tile for block GEMM + auto b_lds_gemm_window = make_tile_window( + b_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // Block GEMM + constexpr auto block_gemm = Policy::template GetBlockGemm(); + + // Acc register tile + auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){}; + + // prefetch + // global read 0 + auto a_block_tile = load_tile(a_copy_dram_window); + auto b_block_tile = load_tile(b_copy_dram_window); + + { + // move to 1 + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + // LDS write 0 + const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window, a_block_tile_tmp); + + // LDS write 0 + const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); + store_tile(b_copy_lds_window, b_block_tile_tmp); + } + + index_t iCounter = num_loop - 1; + + do + { + // global read i + 1 + a_block_tile = load_tile(a_copy_dram_window); + b_block_tile = load_tile(b_copy_dram_window); + + block_sync_lds(); + + // GEMM i + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + + block_sync_lds(); + + // move to i + 2 + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + // LDS write i + 1 + const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window, a_block_tile_tmp); + + // LDS write i + 1 + const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); + store_tile(b_copy_lds_window, b_block_tile_tmp); + + iCounter--; + + } while(iCounter > 0); + + // tail + { + block_sync_lds(); + + // GEMM num_loop - 1 + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + } + + return c_block_tile; + } + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + return operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + num_loop, + p_smem); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp new file mode 100644 index 0000000000..f706900013 --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp @@ -0,0 +1,251 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// Default policy for BlockGemmPipelineAGmemBGmemCRegV1 +// Default policy class should not be templated, put template on member functions instead +struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy +{ +#if 0 + // 2d + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + { + using namespace ck_tile; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr auto a_lds_block_desc = + make_naive_tensor_descriptor_packed(make_tuple(kMPerBlock, kKPerBlock), number<32>{}); + + return a_lds_block_desc; + } + + // 2d + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() + { + using namespace ck_tile; + + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr auto b_lds_block_desc = + make_naive_tensor_descriptor_packed(make_tuple(kNPerBlock, kKPerBlock), number<32>{}); + + return b_lds_block_desc; + } +#elif 1 + // 3d + padding + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + { + using namespace ck_tile; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number<8>{}), + make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), + number<8>{}, + number<1>{}); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_pass_through_transform(kMPerBlock), + make_merge_transform(make_tuple(kKPerBlock / 8, 8))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return a_lds_block_desc; + } + + // 3d + padding + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() + { + using namespace ck_tile; + + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number<8>{}), + make_tuple(number<(kNPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), + number<8>{}, + number<1>{}); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_pass_through_transform(kNPerBlock), + make_merge_transform(make_tuple(kKPerBlock / 8, 8))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return b_lds_block_desc; + } +#elif 1 + // fake XOR + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + { + using namespace ck_tile; + + using ADataType = remove_cvref_t; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr auto a_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed( + make_tuple(number{}, number<2>{}, number{}), + number{}); + + constexpr index_t kK1 = 16 / sizeof(ADataType); + + constexpr auto a_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor( + a_lds_block_desc_d1_d2_d3, + make_tuple( + make_xor_transform(make_tuple(number{}, number{}), kK1), + make_pass_through_transform(2)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{})); + + constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor( + a_lds_block_desc_d4_d5_d6, + make_tuple(make_merge_transform(make_tuple(number{}, number<2>{})), + make_pass_through_transform(kKPerBlock)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return a_lds_block_desc_m_k; + } + + // fake XOR + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() + { + using namespace ck_tile; + + using BDataType = remove_cvref_t; + + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr auto b_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed( + make_tuple(number{}, number<2>{}, number{}), + number{}); + + constexpr index_t kK1 = 16 / sizeof(BDataType); + + constexpr auto b_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor( + b_lds_block_desc_d1_d2_d3, + make_tuple( + make_xor_transform(make_tuple(number{}, number{}), kK1), + make_pass_through_transform(2)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{})); + + constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor( + b_lds_block_desc_d4_d5_d6, + make_tuple(make_merge_transform(make_tuple(number{}, number<2>{})), + make_pass_through_transform(kKPerBlock)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return b_lds_block_desc_n_k; + } +#endif + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() + { + using ADataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = 16 / sizeof(ADataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; +#if 1 // coalesce reading for each blocks + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); +#else // coalesce reading for each warps + constexpr index_t M0 = kBlockSize / get_warp_size(); + constexpr index_t M1 = kMPerBlock / (M2 * M0); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}); +#endif + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() + { + using BDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = 16 / sizeof(BDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; +#if 1 // coalesce reading for each blocks + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); +#else // coalesce reading for each warps + constexpr index_t N0 = kBlockSize / get_warp_size(); + constexpr index_t N1 = kNPerBlock / (N2 * N0); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}); +#endif + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1DefaultPolicy; + + return BlockGemmASmemBSmemCRegV1{}; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp new file mode 100644 index 0000000000..c9f775ceac --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp @@ -0,0 +1,219 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// A Tile Window: global memory +// B Tile Window: global memory +// C Distributed tensor: register +template +struct BlockGemmPipelineAGmemBGmemCRegV2 +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kMPerBlock = BlockGemmShape::kM; + static constexpr index_t kNPerBlock = BlockGemmShape::kN; + static constexpr index_t kKPerBlock = BlockGemmShape::kK; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() + { + return ck_tile::integer_divide_ceil( + sizeof(ADataType) * + Policy::template MakeALdsBlockDescriptor().get_element_space_size(), + 16) * + 16 + + sizeof(BDataType) * + Policy::template MakeBLdsBlockDescriptor().get_element_space_size(); + } + + template + CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const + { + static_assert( + is_same_v> && + is_same_v>, + "wrong!"); + + static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // A tile in LDS + ADataType* p_a_lds = static_cast(p_smem); + + constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); + + auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); + + constexpr index_t a_lds_block_space_size_aligned = + math::integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), + 16) * + 16; + + // B tile in LDS + BDataType* p_b_lds = static_cast( + static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); + + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + Policy::template MakeADramTileDistribution()); + + // A LDS tile window for store + auto a_copy_lds_window = + make_tile_window(a_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + a_copy_dram_window.get_tile_distribution()); + + // B DRAM tile window for load + auto b_copy_dram_window = + make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp.get_window_origin(), + Policy::template MakeBDramTileDistribution()); + + // B LDS tile window for store + auto b_copy_lds_window = + make_tile_window(b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + b_copy_dram_window.get_tile_distribution()); + + // A LDS tile for block GEMM + auto a_lds_gemm_window = make_tile_window( + a_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // B LDS tile for block GEMM + auto b_lds_gemm_window = make_tile_window( + b_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // Block GEMM + constexpr auto block_gemm = Policy::template GetBlockGemm(); + + // Acc register tile + auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){}; + + // prefetch + // global read 0 + auto a_block_tile = load_tile(a_copy_dram_window); + auto b_block_tile = load_tile(b_copy_dram_window); + + { + // move to 1 + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + // LDS write 0 + const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window, a_block_tile_tmp); + // global read 1 + a_block_tile = load_tile(a_copy_dram_window); + + // LDS write 0 + const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); + store_tile(b_copy_lds_window, b_block_tile_tmp); + // global read 1 + b_block_tile = load_tile(b_copy_dram_window); + } + + index_t iCounter = num_loop - 2; + + do + { + block_sync_lds(); + + // GEMM i + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + + block_sync_lds(); + + // move to i + 2 + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + // LDS write i + 1 + const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window, a_block_tile_tmp); + // global read i + 2 + a_block_tile = load_tile(a_copy_dram_window); + + // LDS write i + 1 + const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); + store_tile(b_copy_lds_window, b_block_tile_tmp); + // global read i + 2 + b_block_tile = load_tile(b_copy_dram_window); + + iCounter--; + + } while(iCounter > 0); + + // tail + { + block_sync_lds(); + + // GEMM num_loop - 2 + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + + block_sync_lds(); + + // LDS write num_loop - 1 + const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window, a_block_tile_tmp); + + const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); + store_tile(b_copy_lds_window, b_block_tile_tmp); + + block_sync_lds(); + + // GEMM num_loop - 1 + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + } + + return c_block_tile; + } + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + return operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + num_loop, + p_smem); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp new file mode 100644 index 0000000000..0596408501 --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// Default policy for BlockGemmPipelineAGmemBGmemCRegV2 +// Default policy class should not be templated, put template on member functions instead +// NOTE: policy should be binded to its corresponding operation. It's just a coincidence that +// BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy is the same as +// BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy +using BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy = + BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp new file mode 100644 index 0000000000..62165ebce2 --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct BlockGemmPipelineProblem +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = kBlockSize_; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp new file mode 100644 index 0000000000..f3c4d8bf67 --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct TileGemmShape +{ + static constexpr index_t kM = kMPerTile; + static constexpr index_t kN = kNPerTile; + static constexpr index_t kK = kKPerTile; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp new file mode 100644 index 0000000000..705af423c4 --- /dev/null +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// fp16 +using WarpGemmMfmaF16F16F32M32N32K8 = + WarpGemmImpl>; + +using WarpGemmMfmaF16F16F32M16N16K16 = + WarpGemmImpl>; + +using WarpGemmMfmaF16F16F32M32N32K16 = + WarpGemmImpl>; + +using WarpGemmMfmaF16F16F32M16N16K32 = + WarpGemmImpl>; + +using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution = WarpGemmImpl< + WarpGemmAtrributeMfmaTransposedCDistribution>; + +using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution = WarpGemmImpl< + WarpGemmAtrributeMfmaTransposedCDistribution>; + +using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution = + WarpGemmImpl>; + +using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = + WarpGemmImpl>; + +using WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution = + WarpGemmImpl>; + +// bf16 +using WarpGemmMfmaBf16Bf16F32M32N32K8 = + WarpGemmImpl>; + +using WarpGemmMfmaBf16Bf16F32M16N16K16 = + WarpGemmImpl>; + +using WarpGemmMfmaBf16Bf16F32M32N32K16 = + WarpGemmImpl>; + +using WarpGemmMfmaBf16Bf16F32M16N16K32 = + WarpGemmImpl>; + +using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution = WarpGemmImpl< + WarpGemmAtrributeMfmaTransposedCDistribution>; + +using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution = WarpGemmImpl< + WarpGemmAtrributeMfmaTransposedCDistribution>; + +using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution = + WarpGemmImpl>; + +using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = + WarpGemmImpl>; + +using WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution = + WarpGemmImpl>; + +// fp8 +using WarpGemmMfma_f32_32x32x16_fp8_fp8 = + WarpGemmImpl>; + +using WarpGemmMfma_f32_32x32x16_fp8_bf8 = + WarpGemmImpl>; + +using WarpGemmMfma_f32_32x32x16_bf8_fp8 = + WarpGemmImpl>; + +using WarpGemmMfma_f32_32x32x16_bf8_bf8 = + WarpGemmImpl>; + +using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed = WarpGemmImpl< + WarpGemmAtrributeMfmaTransposedCDistribution>; + +using WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed = WarpGemmImpl< + WarpGemmAtrributeMfmaTransposedCDistribution>; + +using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed = WarpGemmImpl< + WarpGemmAtrributeMfmaTransposedCDistribution>; + +using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed = WarpGemmImpl< + WarpGemmAtrributeMfmaTransposedCDistribution>; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp new file mode 100644 index 0000000000..f7d048a015 --- /dev/null +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -0,0 +1,455 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct WarpGemmAtrributeMfma +{ + using Impl = remove_cvref_t; + + using ADataType = typename Impl::ADataType; + using BDataType = typename Impl::BDataType; + using CDataType = typename Impl::CDataType; + + using AVecType = typename Impl::AVecType; + using BVecType = typename Impl::BVecType; + using CVecType = typename Impl::CVecType; + + static constexpr index_t kM = Impl::kM; + static constexpr index_t kN = Impl::kN; + static constexpr index_t kK = Impl::kK; + + using AWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>; + + using BWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>; + + using CWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 1>, + sequence<0, 2>>; + + // c_vec += a_vec * b_vec + CK_TILE_DEVICE void + operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + { + Impl{}(c_vec, a_vec, b_vec); + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { + return Impl{}(a_vec, b_vec); + } +}; + +template +struct WarpGemmAtrributeMfmaIterateK +{ + static_assert(kKIter > 0, "wrong!"); + + using Impl = remove_cvref_t; + + using ADataType = typename Impl::ADataType; + using BDataType = typename Impl::BDataType; + using CDataType = typename Impl::CDataType; + + using AVecType = typename vector_type_maker::type::type; + using BVecType = typename vector_type_maker::type::type; + using CVecType = typename Impl::CVecType; + + static constexpr index_t kM = Impl::kM; + static constexpr index_t kN = Impl::kN; + static constexpr index_t kK = Impl::kK * kKIter; + + using AWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>; + + using BWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>; + + using CWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 1>, + sequence<0, 2>>; + + // c_vec += a_vec * b_vec + CK_TILE_DEVICE void + operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + { + const auto a_vector = typename vector_type_maker::type{a_vec}; + const auto b_vector = typename vector_type_maker::type{b_vec}; + + static_for<0, kKIter, 1>{}([&](auto iKIter) { + Impl{}(c_vec, + a_vector.template AsType()[iKIter], + b_vector.template AsType()[iKIter]); + }); + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { + const auto a_vector = typename vector_type_maker::type{a_vec}; + const auto b_vector = typename vector_type_maker::type{b_vec}; + + constexpr auto I0 = number<0>{}; + + // c = a * b + auto c_vec = Impl{}(a_vector.template AsType()[I0], + b_vector.template AsType()[I0]); + + // c += a * b + static_for<1, kKIter, 1>{}([&](auto iKIter) { + Impl{}(c_vec, + a_vector.template AsType()[iKIter], + b_vector.template AsType()[iKIter]); + }); + + return c_vec; + } +}; + +template +struct WarpGemmAtrributeMfmaTransposedCDistribution +{ + using Impl = remove_cvref_t; + + using ADataType = typename Impl::BDataType; + using BDataType = typename Impl::ADataType; + using CDataType = typename Impl::CDataType; + + using AVecType = typename Impl::BVecType; + using BVecType = typename Impl::AVecType; + using CVecType = typename Impl::CVecType; + + static constexpr index_t kM = Impl::kN; + static constexpr index_t kN = Impl::kM; + static constexpr index_t kK = Impl::kK; + + using AWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>; + + using BWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>; + + using CWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>; + + // c_vec += a_vec * b_vec + CK_TILE_DEVICE void + operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + { + // swap A and B + Impl{}(c_vec, b_vec, a_vec); + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { + // swap A and B + return Impl{}(b_vec, a_vec); + } +}; + +template +struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB +{ + using Impl = remove_cvref_t; + + using ADataType = typename Impl::BDataType; + using BDataType = typename Impl::ADataType; + using CDataType = typename Impl::CDataType; + + using AVecType = typename Impl::BVecType; + using BVecType = typename Impl::AVecType; + using CVecType = typename Impl::CVecType; + + static constexpr index_t kM = Impl::kN; + static constexpr index_t kN = Impl::kM; + static constexpr index_t kK = Impl::kK; + + using AWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>; + + using BWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>; + + using CWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>; + + // c_vec += a_vec * b_vec + CK_TILE_DEVICE void + operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + { + // swap A and B + Impl{}(c_vec, b_vec, a_vec); + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { + // swap A and B + return Impl{}(b_vec, a_vec); + } +}; + +template +struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution +{ + using Impl = remove_cvref_t; + + // swap A and B + using ADataType = typename Impl::BDataType; + using BDataType = typename Impl::ADataType; + using CDataType = typename Impl::CDataType; + + using AVecType = typename vector_type_maker::type::type; + using BVecType = typename vector_type_maker::type::type; + using CVecType = typename Impl::CVecType; + + static constexpr index_t kM = Impl::kN; + static constexpr index_t kN = Impl::kM; + static constexpr index_t kK = Impl::kK * kKIter; + + using AWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>; + + using BWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>; + + using CWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>; + + // c_vec += a_vec * b_vec + CK_TILE_DEVICE void + operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + { + const auto a_vector = typename vector_type_maker::type{a_vec}; + + const auto b_vector = typename vector_type_maker::type{b_vec}; + + // swap A and B, value and type + static_for<0, kKIter, 1>{}([&](auto iKIter) { + Impl{}(c_vec, + b_vector.template AsType()[iKIter], + a_vector.template AsType()[iKIter]); + }); + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { + const auto a_vector = typename vector_type_maker::type{a_vec}; + const auto b_vector = typename vector_type_maker::type{b_vec}; + + constexpr auto I0 = number<0>{}; + + // swap A and B, value and type + auto c_vec = Impl{}(b_vector.template AsType()[I0], + a_vector.template AsType()[I0]); + + static_for<1, kKIter, 1>{}([&](auto iKIter) { + Impl{}(c_vec, + b_vector.template AsType()[iKIter], + a_vector.template AsType()[iKIter]); + }); + + return c_vec; + } +}; + +template +struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB +{ + using Impl = remove_cvref_t; + + // swap A and B + using ADataType = typename Impl::BDataType; + using BDataType = typename Impl::ADataType; + using CDataType = typename Impl::CDataType; + + using AVecType = typename vector_type_maker::type::type; + using BVecType = typename vector_type_maker::type::type; + using CVecType = typename Impl::CVecType; + + static constexpr index_t kM = Impl::kN; + static constexpr index_t kN = Impl::kM; + static constexpr index_t kK = Impl::kK * kKIter; + static constexpr index_t SFactor = SFactor_; // group how many CM1 together + + using AWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>; +#if 0 + using BWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>; + + using CWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>; +#else + using BWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>; + + using CWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>; +#endif + // c_vec += a_vec * b_vec + CK_TILE_DEVICE void + operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + { + const auto a_vector = typename vector_type_maker::type{a_vec}; + + const auto b_vector = typename vector_type_maker::type{b_vec}; + + // swap A and B, value and type + static_for<0, kKIter, 1>{}([&](auto iKIter) { + Impl{}(c_vec, + b_vector.template AsType()[iKIter], + a_vector.template AsType()[iKIter]); + }); + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { + const auto a_vector = typename vector_type_maker::type{a_vec}; + const auto b_vector = typename vector_type_maker::type{b_vec}; + + constexpr auto I0 = number<0>{}; + + // swap A and B, value and type + auto c_vec = Impl{}(b_vector.template AsType()[I0], + a_vector.template AsType()[I0]); + + static_for<1, kKIter, 1>{}([&](auto iKIter) { + Impl{}(c_vec, + b_vector.template AsType()[iKIter], + a_vector.template AsType()[iKIter]); + }); + + return c_vec; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp new file mode 100644 index 0000000000..d67c396651 --- /dev/null +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -0,0 +1,247 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// FP16 +struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8 +{ + using ADataType = half_t; + using BDataType = half_t; + using CDataType = float; + + using AVecType = typename vector_type::type; + using BVecType = typename vector_type::type; + using CVecType = typename vector_type::type; + + static constexpr index_t kM = 32; + static constexpr index_t kN = 32; + static constexpr index_t kK = 8; + + static constexpr index_t kAMLane = 32; + static constexpr index_t kBNLane = 32; + static constexpr index_t kABKLane = 2; + static constexpr index_t kABKPerLane = 4; + + static constexpr index_t kCMLane = 2; + static constexpr index_t kCNLane = 32; + static constexpr index_t kCM0PerLane = 4; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + CK_TILE_DEVICE void + operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + { + c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0); + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { + return __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, CVecType{0.f}, 0, 0, 0); + } +}; + +struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16 +{ + using ADataType = half_t; + using BDataType = half_t; + using CDataType = float; + + using AVecType = typename vector_type::type; + using BVecType = typename vector_type::type; + using CVecType = typename vector_type::type; + + static constexpr index_t kM = 16; + static constexpr index_t kN = 16; + static constexpr index_t kK = 16; + + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABKLane = 4; + static constexpr index_t kABKPerLane = 4; + + static constexpr index_t kCMLane = 4; + static constexpr index_t kCNLane = 16; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + CK_TILE_DEVICE void + operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + { + c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0); + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { + return __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, CVecType{0.f}, 0, 0, 0); + } +}; + +// Bf16 +struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 +{ + using ADataType = bhalf_t; + using BDataType = bhalf_t; + using CDataType = float; + + using AVecType = typename vector_type::type; + using BVecType = typename vector_type::type; + using CVecType = typename vector_type::type; + + static constexpr index_t kM = 32; + static constexpr index_t kN = 32; + static constexpr index_t kK = 8; + + static constexpr index_t kAMLane = 32; + static constexpr index_t kBNLane = 32; + static constexpr index_t kABKLane = 2; + static constexpr index_t kABKPerLane = 4; + + static constexpr index_t kCMLane = 2; + static constexpr index_t kCNLane = 32; + static constexpr index_t kCM0PerLane = 4; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + CK_TILE_DEVICE void + operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + { + c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { + return __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, CVecType{0.f}, 0, 0, 0); + } +}; + +struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 +{ + using ADataType = bhalf_t; + using BDataType = bhalf_t; + using CDataType = float; + + using AVecType = typename vector_type::type; + using BVecType = typename vector_type::type; + using CVecType = typename vector_type::type; + + static constexpr index_t kM = 16; + static constexpr index_t kN = 16; + static constexpr index_t kK = 16; + + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABKLane = 4; + static constexpr index_t kABKPerLane = 4; + + static constexpr index_t kCMLane = 4; + static constexpr index_t kCNLane = 16; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + CK_TILE_DEVICE void + operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + { + c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { + return __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, CVecType{0.f}, 0, 0, 0); + } +}; + +// FP8 +template +struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base +{ + using ADataType = AType_; + using BDataType = BType_; + using CDataType = float; + + using AVecType = typename vector_type::type; + using BVecType = typename vector_type::type; + using CVecType = typename vector_type::type; + + static constexpr index_t kM = 32; + static constexpr index_t kN = 32; + static constexpr index_t kK = 16; + + static constexpr index_t kAMLane = 32; + static constexpr index_t kBNLane = 32; + static constexpr index_t kABKLane = 2; + static constexpr index_t kABKPerLane = 8; + + static constexpr index_t kCMLane = 2; + static constexpr index_t kCNLane = 32; + static constexpr index_t kCM0PerLane = 4; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + CK_TILE_DEVICE void + operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + if constexpr(is_same_v && is_same_v) + c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + else if constexpr(is_same_v && is_same_v) + c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + else if constexpr(is_same_v && is_same_v) + c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + else if constexpr(is_same_v && is_same_v) + c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); +#else + vector_type a_(a_vec); + vector_type b_(b_vec); + + static_for<0, 8, 1>{}([&](auto k) { + float a_f32 = type_convert(a_.template AsType()[number{}]); + float b_f32 = type_convert(b_.template AsType()[number{}]); + + c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0); + }); +#endif + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { + if constexpr(is_same_v && is_same_v) + return __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0); + else if constexpr(is_same_v && is_same_v) + return __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8( + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0); + else if constexpr(is_same_v && is_same_v) + return __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8( + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0); + else if constexpr(is_same_v && is_same_v) + return __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8( + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0); + } +}; + +using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8 = + WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; +using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8 = + WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; +using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8 = + WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; +using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8 = + WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp new file mode 100644 index 0000000000..5309794b23 --- /dev/null +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +namespace impl { +template +struct WarpGemmMfmaDispatcher; + +// clang-format off +// fp16 +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; }; + +// bf16 +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; }; + +// fp8 +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; }; + +// clang-format on +} // namespace impl + +template +using WarpGemmMfmaDispatcher = typename impl:: + WarpGemmMfmaDispatcher::Type; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp new file mode 100644 index 0000000000..94d0e02931 --- /dev/null +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +namespace ck_tile { + +template +struct WarpGemmImpl +{ + using WarpGemmAttribute = remove_cvref_t; + + static constexpr index_t kM = WarpGemmAttribute::kM; + static constexpr index_t kN = WarpGemmAttribute::kN; + static constexpr index_t kK = WarpGemmAttribute::kK; + + using ADataType = typename WarpGemmAttribute::ADataType; + using BDataType = typename WarpGemmAttribute::BDataType; + using CDataType = typename WarpGemmAttribute::CDataType; + + using AWarpDstrEncoding = typename WarpGemmAttribute::AWarpDstrEncoding; + using BWarpDstrEncoding = typename WarpGemmAttribute::BWarpDstrEncoding; + using CWarpDstrEncoding = typename WarpGemmAttribute::CWarpDstrEncoding; + + using AWarpDstr = remove_cvref_t; + using BWarpDstr = remove_cvref_t; + using CWarpDstr = remove_cvref_t; + + using AWarpTensor = static_distributed_tensor; + using BWarpTensor = static_distributed_tensor; + using CWarpTensor = static_distributed_tensor; + + CK_TILE_DEVICE void operator()(CWarpTensor& c, const AWarpTensor& a, const BWarpTensor& b) const + { + using AVec = typename vector_type::type; + using BVec = typename vector_type::type; + using CVec = typename vector_type::type; + + constexpr auto I0 = number<0>{}; + + const auto a_vec = a.get_thread_buffer().template get_as(I0); + const auto b_vec = b.get_thread_buffer().template get_as(I0); + auto c_vec = c.get_thread_buffer().template get_as(I0); + + // c_vec += a_vec * b_vec + WarpGemmAttribute{}(c_vec, a_vec, b_vec); + + c.get_thread_buffer().template set_as(I0, c_vec); + } + + CK_TILE_DEVICE auto operator()(const AWarpTensor& a, const BWarpTensor& b) const + { + CWarpTensor c; + + using AVec = typename vector_type::type; + using BVec = typename vector_type::type; + using CVec = typename vector_type::type; + + constexpr auto I0 = number<0>{}; + + const auto a_vec = a.get_thread_buffer().template get_as(I0); + const auto b_vec = b.get_thread_buffer().template get_as(I0); + + // c_vec = a_vec * b_vec + auto c_vec = WarpGemmAttribute{}(a_vec, b_vec); + + c.get_thread_buffer().template set_as(I0, c_vec); + + return c; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/reduce.hpp b/include/ck_tile/ops/reduce.hpp new file mode 100644 index 0000000000..59d761fb17 --- /dev/null +++ b/include/ck_tile/ops/reduce.hpp @@ -0,0 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/reduce/block/block_reduce.hpp" diff --git a/include/ck_tile/ops/reduce/block/block_reduce.hpp b/include/ck_tile/ops/reduce/block/block_reduce.hpp new file mode 100644 index 0000000000..7e0140b9d5 --- /dev/null +++ b/include/ck_tile/ops/reduce/block/block_reduce.hpp @@ -0,0 +1,211 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// synchronize reduce result (cross lane reduction and broadcast on replicated dimension) +template +CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor, + const ReduceFunc& reduce_func, + bool_constant = {}) +{ + using Dstr = typename AccDistributedTensor_::StaticTileDistribution; + using DstrEncode = typename Dstr::DstrEncode; + using DstrEncodeDetail = typename DstrEncode::detail; + + constexpr index_t NDimP = Dstr::GetNumOfDimensionP(); + constexpr index_t NDimR = Dstr::GetNumOfDimensionR(); + + constexpr index_t idim_p_lane = NDimP - 1; + + const auto ps_idx = make_array(get_block_id(), get_lane_id()); + const auto rs_idx = acc_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx); + + constexpr index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size(); + + // loop over thread data + static_for<0, thread_buf_size, 1>{}([&](auto i) { + auto v_local = acc_tensor.get_thread_buffer()[i]; + + // cross-lane reduce for replication + // only reduce on R dimension correspond to lane + // (lane id maps to this R dimension) + static_for<0, NDimR, 1>{}([&](auto idim_r) { + // FIXME: nasty to use does_p_own_r_ + if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r]) + { + constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r]; + + constexpr index_t lid_over_rid_derivative = + DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r]; + + static_assert(math::is_power_of_two_integer(r_length), + "wrong! only support power of 2 reduction"); + + constexpr index_t nstage = math::integer_log2_floor(r_length); + + // reduction sweep forward + static_for<0, nstage, 1>{}([&](auto istage) { + constexpr index_t lid_delta = + lid_over_rid_derivative * (1 << (nstage - istage - 1)); + + // pull data from remote lane + const auto v_remote = warp_shuffle_down(v_local, lid_delta); + + // reduce + v_local = reduce_func(v_local, v_remote); + }); + } + }); + + if constexpr(WithBroadcast) + { + // cross-lane broadcast for replication + // only broadcast on R dimension correspond to lane + // (lane id maps to this R dimension) + static_for<0, NDimR, 1>{}([&](auto idim_r) { + // FIXME: nasty to use does_p_own_r_ + if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r]) + { + const index_t r_id = rs_idx[idim_r]; + + constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r]; + + constexpr index_t lid_over_rid_derivative = + DstrEncodeDetail::ps_over_rs_derivative_[NDimP - 1][idim_r]; + + static_assert(math::is_power_of_two_integer(r_length), + "wrong! only support power of 2 reduction"); + + constexpr index_t nstage = math::integer_log2_floor(r_length); + + // broadcast sweep backward + static_for<0, nstage, 1>{}([&](auto istage) { + // do I hold reduced data? + const bool do_i_hold_reduced_data = r_id < (1 << istage); + + constexpr index_t lid_delta = lid_over_rid_derivative * (1 << istage); + + // pull data from remote lane + const auto v_remote = warp_shuffle_up(v_local, lid_delta); + + // decide whether to update local data with remote data + v_local = do_i_hold_reduced_data ? v_local : v_remote; + }); + } + }); + } + + acc_tensor.get_thread_buffer()(i) = v_local; + }); +} + +// FIXME: this is for 2D to 1D reduce only, need to support n-D +template +CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_& acc_tensor, + const InDistributedTensor_& in_tensor, + sequence, + const ReduceFunc& reduce_func) +{ + constexpr auto I0 = number<0>{}; + constexpr auto I1 = number<1>{}; + +#if 0 + constexpr auto in_reduce_dims = sequence{}; + + constexpr index_t ndim_in = InDistributedTensor_::get_num_of_dimension(); + constexpr index_t ndim_in_reduce = in_reduce_dims.size(); + constexpr index_t ndim_in_free = ndim_in - ndim_in_reduce; + + constexpr auto in_free_dims_arr = [&] { + array is_free_dims{true}; + + for(index_t i = 0; i < ndim_reduce; i++) + { + is_free_dims(in_reduce_dims[i]) = false; + } + + array in_free_dims{-1}; + + index_t cnt = 0; + + for(index_t i = 0; i < ndim_in; i++) + { + if(is_free_dims[i]) + { + in_free_dims(cnt) = i; + + cnt++ + } + } + + return is_free_dims; + }(); + + constexpr auto in_free_dims = TO_SEQUENCE(is_free_dims_arr, ndim_in_free); +#else + + constexpr auto spans = InDistributedTensor_::get_distributed_spans(); + + // in-thread reduction + // FIXME: hard coded to be 2D to 1D reduction + sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) { + constexpr auto acc_dstr_idx = make_tuple(dstr_idx_i0); + + auto acc = acc_tensor[acc_dstr_idx]; + + // FIXME + sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) { + constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1); + + const auto in = in_tensor[in_dstr_idx]; + + acc = reduce_func(acc, in); + }); + + acc_tensor(acc_dstr_idx) = acc; + }); +#endif +} + +template +CK_TILE_DEVICE auto block_tile_reduce(const InDistributedTensor_& in_tensor, + sequence in_reduce_dims, + const ReduceFunc& reduce_func, + const InDataType_& reduce_init) +{ + using InDataType = typename InDistributedTensor_::DataType; + using AccDataType = remove_cvref_t; + + static_assert(is_same_v>, "wrong!"); + + // declare acc_tensor + constexpr auto acc_dstr = + make_static_tile_distribution(ck_tile::detail::make_reduce_tile_distribution_encoding( + InDistributedTensor_::get_tile_distribution().get_static_tile_distribution_encoding(), + sequence{})); + + auto acc_tensor = make_static_distributed_tensor(acc_dstr); + + // init acc_tensor + tile_elementwise_inout([&](auto& acc) { acc = type_convert(reduce_init); }, + acc_tensor); + + // warp reduce + block_tile_reduce(acc_tensor, in_tensor, in_reduce_dims, reduce_func); + + return acc_tensor; +} + +} // namespace ck_tile diff --git a/include/ck_tile/remod.py b/include/ck_tile/remod.py new file mode 100644 index 0000000000..0183a97e84 --- /dev/null +++ b/include/ck_tile/remod.py @@ -0,0 +1,76 @@ +import pathlib +from pathlib import Path +import subprocess +import os + +NS = 'ck_tile' +OPS = 'ops' + +HEADER_COMMON = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n +""" + +# aa/bb/cc/file.hpp -> (aa, bb, cc, file.hpp) +def get_module(f, level = 0): + all_parts = f.parts + return str(all_parts[level]) + +all_files = [] +for p in sorted(Path("./").rglob("*")): + if p.suffix == '.hpp': + all_files.append(pathlib.PurePath(p)) + +class submodule_t: + def __init__(self): + self.m = dict() + def push(self, f): + if len(f.parents) != 1: # ignore ./xxx.hpp + mod = get_module(f) + if mod == OPS: + if mod not in self.m.keys(): + self.m[mod] = dict() + mod2 = get_module(f, 1) + if Path(mod2).suffix != '.hpp': + # ignore ops/xxx.hpp + if mod2 not in self.m[mod].keys(): + self.m[mod][mod2] = list() + self.m[mod][mod2].append(f) + else: + if mod not in self.m.keys(): + self.m[mod] = list() + self.m[mod].append(f) + + def gen(self): + def gen_header(hpath, include_list): + # print(hpath) + if os.path.exists(str(hpath)): + os.remove(str(hpath)) + with hpath.open('w') as f: + f.write(HEADER_COMMON) + f.write('#pragma once\n') + f.write('\n') + for individual_header in include_list: + header_path = NS + '/' + str(individual_header) + f.write(f'#include \"{header_path}\"\n') + f.write('\n') + # print(self.m) + for k, v in self.m.items(): + if k == OPS: + for km, kv in v.items(): + gen_header(Path(k) / (f'{km}.hpp'), kv) + else: + gen_header(Path(f'{k}.hpp'), v) + + +submodule = submodule_t() +# formatting +for x in all_files: + cmd = f'clang-format-12 -style=file -i {str(x)}' + #for xp in x.parents: + #print(get_file_base(x)) + subprocess.Popen(cmd, shell=True) + submodule.push(x) + +submodule.gen() + +#print(all_files)