From 294cb823142a815170cf1faa63e01a431a557a04 Mon Sep 17 00:00:00 2001 From: BrianHarrisonAMD <169072757+BrianHarrisonAMD@users.noreply.github.com> Date: Wed, 2 Oct 2024 10:26:11 -0600 Subject: [PATCH 01/24] Add generating mha static library for gfx90a (#1540) * Add generating mha static library for gfx90a * Update comment to reflect changes --- library/src/tensor_operation_instance/gpu/CMakeLists.txt | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index df3283b543..bc66fe0bed 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -64,9 +64,9 @@ function(add_instance_library INSTANCE_NAME) list(REMOVE_ITEM ARGN "${source}") endif() endforeach() - # Do not build mha instances if gfx94 targets are not on the target list + # Do not build mha instances if gfx94 or gfx90a targets are not on the target list foreach(source IN LISTS ARGN) - if(NOT INST_TARGETS MATCHES "gfx94" AND source MATCHES "mha") + if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx90a" AND source MATCHES "mha") message("removing mha instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() @@ -85,7 +85,7 @@ function(add_instance_library INSTANCE_NAME) elseif(ARGN MATCHES "_wmma") list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) elseif(ARGN MATCHES "mha") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) endif() set(offload_targets) foreach(target IN LISTS INST_TARGETS) @@ -320,8 +320,7 @@ if(CK_DEVICE_CONV_INSTANCES) endif() if(CK_DEVICE_MHA_INSTANCES) set(gpu_list ${INST_TARGETS}) - list(FILTER gpu_list INCLUDE REGEX "^gfx94") - if(gpu_list) + 
if(gpu_list MATCHES "gfx94" OR gpu_list MATCHES "gfx90a") add_library(device_mha_operations STATIC ${CK_DEVICE_MHA_INSTANCES}) add_library(composablekernels::device_mha_operations ALIAS device_mha_operations) target_compile_features(device_mha_operations PUBLIC) From aeb7c91f48a0e8fa1e288d91f719415282c03f03 Mon Sep 17 00:00:00 2001 From: macurtis-amd Date: Wed, 2 Oct 2024 15:56:22 -0500 Subject: [PATCH 02/24] Fix compilation errors generated by forthcoming Clang changes (#1544) Without this change, the following diagnostic is generated: a template argument list is expected after a name prefixed by the template keyword [-Wmissing-template-arg-list-after-template-kw] See C++17 spec [temp.names] p5. --- ...ckwise_gemm_pipeline_xdlops_v1_ab_scale.hpp | 9 +++++---- ...ckwise_gemm_pipeline_xdlops_v2_ab_scale.hpp | 18 ++++++++++-------- ...ckwise_gemm_pipeline_xdlops_v3_ab_scale.hpp | 9 +++++---- 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp index 121593d3cc..821bbb0051 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp @@ -308,7 +308,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale::type; - xdlops_gemm.template Run( + xdlops_gemm.template Run<>( a_thread_vec.template AsType(), b_thread_vec.template AsType(), c_thread_buf_per_scale.GetVectorTypeReference(I0)); @@ -390,9 +390,10 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale::type; - xdlops_gemm.template Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(I0)); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); }); static_for<0, 
xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { constexpr index_t c_offset = diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp index cb7cf605be..40fa776484 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp @@ -350,7 +350,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale::type; - xdlops_gemm.template Run( + xdlops_gemm.template Run<>( a_thread_vec.template AsType(), b_thread_vec.template AsType(), c_thread_buf_per_scale.GetVectorTypeReference(I0)); @@ -443,7 +443,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale::type; - xdlops_gemm.template Run( + xdlops_gemm.template Run<>( a_thread_vec.template AsType(), b_thread_vec.template AsType(), c_thread_buf_per_scale.GetVectorTypeReference(I0)); @@ -518,9 +518,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale::type; - xdlops_gemm.template Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(I0)); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); }); static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { constexpr index_t c_offset = @@ -575,9 +576,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale::type; - xdlops_gemm.template Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(I0)); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); }); static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { constexpr index_t c_offset = diff --git 
a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp index 66c9a5c339..de542866a6 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp @@ -427,7 +427,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale::type; - xdlops_gemm.template Run( + xdlops_gemm.template Run<>( a_thread_vec.template AsType(), b_thread_vec.template AsType(), c_thread_buf_per_scale.GetVectorTypeReference(I0)); @@ -504,9 +504,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale::type; - xdlops_gemm.template Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(I0)); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); }); static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { constexpr index_t c_offset = From 6b54d2faf8d0b106fb31719654b6b4d5f18552f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Fri, 4 Oct 2024 17:32:43 +0200 Subject: [PATCH 03/24] Fix grouped gemm check to avoid overflow (#1545) --- .../device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp | 2 +- ...ce_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp | 2 +- .../device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp | 2 +- .../gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp index 1f60818e39..21afc06040 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp +++ 
b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp @@ -64,7 +64,7 @@ __global__ void const index_t N = gemm_desc_ptr[group_id].N; const index_t K = gemm_desc_ptr[group_id].K; - if(M * N * K == 0) + if(M == 0 || N == 0 || K == 0) return; const auto StrideAs = gemm_desc_ptr[group_id].StrideAs; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp index 8354335577..68c6dcc0f5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp @@ -345,7 +345,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage const index_t N = gemm_descs[i].N_; const index_t K = gemm_descs[i].K_; - if(M * N * K == 0) + if(M == 0 || N == 0 || K == 0) { skipped_group_count_++; continue; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp index 70011124fc..2884e558cd 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp @@ -109,7 +109,7 @@ __global__ void N = gemm_desc_ptr[group_id].N; K = gemm_desc_ptr[group_id].K; - if(M * N * K == 0) + if(M == 0 || N == 0 || K == 0) { grid_size_grp = 0; continue; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp index c98ec6e2aa..ac05a0703f 100644 --- 
a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp @@ -68,7 +68,7 @@ __global__ void const index_t N = gemm_desc_ptr[group_id].N; const index_t K = gemm_desc_ptr[group_id].K; - if(M * N * K == 0) + if(M == 0 || N == 0 || K == 0) return; const auto StrideA = gemm_desc_ptr[group_id].StrideA; From b545de175a7a1410baaea70f1f32ecb819e1d056 Mon Sep 17 00:00:00 2001 From: arai713 <67439843+arai713@users.noreply.github.com> Date: Fri, 4 Oct 2024 10:51:50 -0700 Subject: [PATCH 04/24] Codegen build (#1526) * updating codegen build for MIOpen access: adding .cmake for codegen component (cherry picked from commit 652a7c046381526947f507a89299aa92d89dbd02) * updating CMake (cherry picked from commit a685822e361045f3ef02a2f60c1c0eadd9cc4c85) --- CMakeLists.txt | 4 +++- cmake/Embed.cmake | 4 +++- codegen/CMakeLists.txt | 9 ++++++++- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index fd321f7722..dc73b5f4d4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -569,7 +569,9 @@ if(NOT DEFINED INSTANCES_ONLY) PACKAGE_NAME examples ) add_subdirectory(example) - add_subdirectory(test) + if(BUILD_TESTING) + add_subdirectory(test) + endif() rocm_package_setup_component(profiler LIBRARY_NAME composablekernel diff --git a/cmake/Embed.cmake b/cmake/Embed.cmake index 4bc638b446..3946cf4e8d 100644 --- a/cmake/Embed.cmake +++ b/cmake/Embed.cmake @@ -233,6 +233,8 @@ function(add_embed_library EMBED_NAME) else() target_sources(${EMBED_NAME} INTERFACE $) endif() - target_include_directories(${EMBED_NAME} INTERFACE "${EMBED_DIR}/include") + target_include_directories(${EMBED_NAME} INTERFACE + $ + $) endfunction() diff --git a/codegen/CMakeLists.txt b/codegen/CMakeLists.txt index 3b3e9f06ee..2492804f28 100644 --- a/codegen/CMakeLists.txt +++ b/codegen/CMakeLists.txt @@ -39,6 +39,7 @@ set_target_properties(ck_host PROPERTIES 
target_include_directories(ck_host PUBLIC $ + $ ) add_executable(ck-template-driver driver/main.cpp) @@ -48,6 +49,12 @@ rocm_install( TARGETS ck_host ck_headers EXPORT ck_hostTargets ) +rocm_install(EXPORT ck_hostTargets + FILE composable_kernelck_hostTargets.cmake + NAMESPACE composable_kernel:: + DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel) rocm_install(DIRECTORY include/ck DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) -add_subdirectory(test) +if(BUILD_TESTING) + add_subdirectory(test) +endif() From c24fae234600aa2863e945d072e6f5b3aec2a6b2 Mon Sep 17 00:00:00 2001 From: kylasa Date: Fri, 4 Oct 2024 11:48:47 -0700 Subject: [PATCH 05/24] Adding seed and offset pointer support to the philox random number generator. (#1523) * Adding seed and offset pointer support to the philox random number generator. * Separating seed and offset pointer checks with different condition statements. * Changes include, adding support for device seed and offset pointers, union is used to store seed/offset values and device pointers to minimize device SGPRs. 
* Correcting a typo in the readme file * Re-format files using remod.py * Use STL type for API parameters * Use simpler struct design for drop_seed & drop_offset * Undo unnecessary changes * Sync kargs style for fmha_fwd.hpp/.cpp * Use templated union to reduce code * Use structured binding to make code more readable --------- Co-authored-by: Sudhir Kylasa Co-authored-by: Po Yen Chen --- example/ck_tile/01_fmha/README.md | 7 +- example/ck_tile/01_fmha/fmha_bwd.cpp | 23 ++++- example/ck_tile/01_fmha/fmha_bwd.hpp | 6 +- example/ck_tile/01_fmha/fmha_fwd.cpp | 23 ++++- example/ck_tile/01_fmha/fmha_fwd.hpp | 6 +- .../ops/fmha/kernel/fmha_bwd_kernel.hpp | 91 ++++++++++++++++--- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 83 ++++++++++++++--- 7 files changed, 205 insertions(+), 34 deletions(-) diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index 0bb5408772..0803d54d66 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -70,8 +70,13 @@ args: -seed random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939) -warmup number of iterations before benchmark the kernel (default:5) -repeat number of iterations to benchmark the kernel (default:20) + -drop_seed seed for the random number generator for the dropout layer, default is 1 +-drop_offset offset for the dropout layer which is used during random number generation, default is 0 + -drop_prefs flag indicating whether the `drop_seed` and `drop_offset` values reside in host or GPU memory: 0 - host, 1 - GPU, default is 0 ``` -Example: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case. 
+Example 2: `./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run a fmha case with + batch=1, nhead=8, sequence length=16384, hdim=64, drop_seed=10 (in GPU memory), drop_offset=1234 (in GPU memory), fp16 case. ## support features Currently we are still in rapid development stage, so more features/optimizations will be coming soon. diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp index c2f554f6cc..2d76627a72 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -85,6 +85,9 @@ auto create_args(int argc, char* argv[]) .insert("p_drop", "0", "0~1 probability of dropout") .insert("drop_seed", "1", "seed for random number generator") .insert("drop_offset", "0", "offset for random number generator") + .insert("drop_prefs", + "0", + "seed and offset values are present on GPU; 0 - host, 1 - device/GPU") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("repeat", "20", "number of iterations to benchmark the kernel") @@ -158,6 +161,8 @@ bool run(const ck_tile::ArgParser& arg_parser) float p_drop = arg_parser.get_float("p_drop"); uint64_t drop_seed = arg_parser.get_uint64("drop_seed"); uint64_t drop_offset = arg_parser.get_uint64("drop_offset"); + bool drop_prefs = arg_parser.get_bool("drop_prefs"); + if(use_dbias && bias.type != bias_enum::elementwise_bias) { std::cerr << "dbias only exists when bias type is elementwise" << std::endl; @@ -381,6 +386,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0); + ck_tile::DeviceMem drop_offset_buf(drop_prefs ? 
sizeof(uint64_t) : 0); ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem dq_acc_buf(dq_acc_host.get_element_space_size_in_bytes()); @@ -391,6 +398,8 @@ bool run(const ck_tile::ArgParser& arg_parser) do_buf.ToDevice(do_host.data()); seqstart_q.ToDevice(seqstart_q_host.data()); seqstart_k.ToDevice(seqstart_k_host.data()); + drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr); + drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr); alibi_slope_buf.ToDevice(alibi_slope_host.data()); // clang-format off @@ -472,6 +481,18 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t split_stride_dq_acc = (shape_batch * nhead * shape_seqlen_q * hdim_q); + const auto drop_seed_offset = [&]() -> decltype(fmha_bwd_args::drop_seed_offset) { + if(drop_prefs) + { + return std::make_pair(drop_seed_buf.GetDeviceBuffer(), + drop_offset_buf.GetDeviceBuffer()); + } + else + { + return std::make_pair(drop_seed, drop_offset); + } + }(); + return fmha_bwd_args{q_buf.GetDeviceBuffer(), k_buf.GetDeviceBuffer(), v_buf.GetDeviceBuffer(), @@ -545,7 +566,7 @@ bool run(const ck_tile::ArgParser& arg_parser) static_cast(mask.type), p_drop, p_undrop, - {drop_seed, drop_offset}}; + drop_seed_offset}; }(); float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config); diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index aea42515dc..3b21a3257f 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -9,7 +9,10 @@ #include "ck_tile/ops/epilogue.hpp" #include "mask.hpp" #include "bias.hpp" + #include +#include +#include template struct FmhaBwdTypeConfig; @@ -135,7 +138,8 @@ struct fmha_bwd_args ck_tile::index_t mask_type; float p_drop; float p_undrop; - std::tuple drop_seed_offset; + std::variant, std::pair> + drop_seed_offset; }; template diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 
b9cb9a1ec2..6d519a7ea8 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -122,6 +122,9 @@ auto create_args(int argc, char* argv[]) .insert("p_drop", "0", "0~1 probability of dropout") .insert("drop_seed", "1", "seed for random number generator") .insert("drop_offset", "0", "offset for random number generator") + .insert("drop_prefs", + "0", + "seed and offset values are present on GPU; 0 - host, 1 - device/GPU") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert( "rotary_dim", "0", "RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all") @@ -442,6 +445,8 @@ bool run(const ck_tile::ArgParser& arg_parser) float p_drop = arg_parser.get_float("p_drop"); uint64_t drop_seed = arg_parser.get_uint64("drop_seed"); uint64_t drop_offset = arg_parser.get_uint64("drop_offset"); + bool drop_prefs = arg_parser.get_bool("drop_prefs"); + if(p_drop < 0.0f || p_drop > 1.0f) { std::cerr << "The value of p_drop should be 0~1" << std::endl; @@ -756,6 +761,8 @@ bool run(const ck_tile::ArgParser& arg_parser) need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0); ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem rotary_sin_buf(rotary_sin_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0); + ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0); ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes()); @@ -774,6 +781,8 @@ bool run(const ck_tile::ArgParser& arg_parser) cache_seqlen_k_buf.ToDevice(need_append_kvcache ? 
cache_seqlen_ks.data() : nullptr); rotary_cos_buf.ToDevice(rotary_cos_host.data()); rotary_sin_buf.ToDevice(rotary_sin_host.data()); + drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr); + drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr); alibi_slope_buf.ToDevice(alibi_slope_host.data()); block_table_buf.ToDevice(block_table_host.data()); cache_batch_idx_buf.ToDevice(cache_batch_idx_host.data()); @@ -1013,9 +1022,17 @@ bool run(const ck_tile::ArgParser& arg_parser) args.nhead_stride_randval = nhead_stride_randval; args.batch_stride_randval = batch_stride_randval; - args.p_drop = p_drop; - args.s_randval = s_randval; - args.drop_seed_offset = std::tie(drop_seed, drop_offset); + args.p_drop = p_drop; + args.s_randval = s_randval; + if(drop_prefs) + { + args.drop_seed_offset = std::make_pair(drop_seed_buf.GetDeviceBuffer(), + drop_offset_buf.GetDeviceBuffer()); + } + else + { + args.drop_seed_offset = std::make_pair(drop_seed, drop_offset); + } } else if constexpr(std::is_same_v>) { diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 5dcad7907f..251e61bc76 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -13,6 +13,8 @@ #include "rotary.hpp" #include +#include +#include template struct FmhaFwdTypeConfig; @@ -144,7 +146,9 @@ struct fmha_fwd_args float p_drop; bool s_randval; - std::tuple drop_seed_offset; + + std::variant, std::pair> + drop_seed_offset; }; struct fmha_fwd_splitkv_args diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index 167494b193..c5858a20f7 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -6,8 +6,11 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" + #include #include +#include +#include // S[seqlen_q, seqlen_k] = 
Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] @@ -194,11 +197,23 @@ struct FmhaBwdDQDKDVKernel ck_tile::GenericAttentionMaskEnum mask_type; }; - struct FmhaBwdCommonDropoutKargs + struct FmhaBwdDropoutSeedOffset { - void init_dropout(const float p_drop, - const std::tuple& drop_seed_offset, - const float raw_scale) + template + union ValueOrPointer + { + T val; + const T* ptr; + }; + + ValueOrPointer drop_seed; + ValueOrPointer drop_offset; + bool is_drop_seed_offset_from_host; + }; + + struct FmhaBwdCommonDropoutKargs : FmhaBwdDropoutSeedOffset + { + void init_dropout(float p_drop, uint64_t seed, uint64_t offset, float raw_scale) { float p_undrop = 1.0 - p_drop; p_undrop_in_uint8_t = @@ -206,23 +221,41 @@ struct FmhaBwdDQDKDVKernel rp_undrop = 1.0 / p_undrop; scale_rp_undrop = rp_undrop * raw_scale; - drop_seed = std::get<0>(drop_seed_offset); - drop_offset = std::get<1>(drop_seed_offset); + this->drop_seed.val = seed; + this->drop_offset.val = offset; + this->is_drop_seed_offset_from_host = true; } + + void init_dropout(float p_drop, + const uint64_t* seed_ptr, + const uint64_t* offset_ptr, + float raw_scale) + { + float p_undrop = 1.0 - p_drop; + p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + rp_undrop = 1.0 / p_undrop; + scale_rp_undrop = rp_undrop * raw_scale; + + this->drop_seed.ptr = seed_ptr; + this->drop_offset.ptr = offset_ptr; + this->is_drop_seed_offset_from_host = false; + } + float rp_undrop = 1; float scale_rp_undrop = 1; uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); - uint64_t drop_seed = 1; - uint64_t drop_offset = 0; void* rand_val_ptr = nullptr; ck_tile::index_t stride_randval = 0; ck_tile::index_t nhead_stride_randval = 0; }; + struct FmhaBwdBatchModeDropoutKargs : FmhaBwdCommonDropoutKargs { ck_tile::index_t batch_stride_randval = 0; }; + struct FmhaBwdDeterministicKargs { ck_tile::index_t split_stride_dq_acc = 0; @@ -327,7 +360,8 @@ 
struct FmhaBwdDQDKDVKernel ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, - const std::tuple& drop_seed_offset) + std::variant, std::pair> + drop_seed_offset) { Kargs kargs{{q_ptr, k_ptr, @@ -405,7 +439,20 @@ struct FmhaBwdDQDKDVKernel if constexpr(kHasDropout) { - kargs.init_dropout(p_drop, drop_seed_offset, scale); + if(drop_seed_offset.index() == 0) // seed & offset come from host + { + const auto& [seed, offset] = std::get<0>(drop_seed_offset); + kargs.init_dropout(p_drop, seed, offset, scale); + } + else // seed & offset come from device + { + const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset); + kargs.init_dropout(p_drop, + reinterpret_cast(seed_ptr), + reinterpret_cast(offset_ptr), + scale); + } + if constexpr(kIsStoreRandval) { kargs.rand_val_ptr = rand_val_ptr; @@ -471,7 +518,8 @@ struct FmhaBwdDQDKDVKernel ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, - const std::tuple& drop_seed_offset) + std::variant, std::pair> + drop_seed_offset) { Kargs kargs{{q_ptr, k_ptr, @@ -539,7 +587,20 @@ struct FmhaBwdDQDKDVKernel } if constexpr(kHasDropout) { - kargs.init_dropout(p_drop, drop_seed_offset, scale); + if(drop_seed_offset.index() == 0) // seed & offset come from host + { + const auto& [seed, offset] = std::get<0>(drop_seed_offset); + kargs.init_dropout(p_drop, seed, offset, scale); + } + else // seed & offset come from device + { + const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset); + kargs.init_dropout(p_drop, + reinterpret_cast(seed_ptr), + reinterpret_cast(offset_ptr), + scale); + } + if constexpr(kIsStoreRandval) { kargs.rand_val_ptr = rand_val_ptr; @@ -958,8 +1019,10 @@ struct FmhaBwdDQDKDVKernel return FmhaDropout{i_batch_, i_nhead_, kargs.num_head_q, - kargs.drop_seed, - kargs.drop_offset, + kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val + : *kargs.drop_seed.ptr, + kargs.is_drop_seed_offset_from_host ? 
kargs.drop_offset.val + : *kargs.drop_offset.ptr, kargs.rp_undrop, kargs.p_undrop_in_uint8_t}; } diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 49ef7bf6d9..adabda165c 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -6,8 +6,11 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" + #include #include +#include +#include // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] @@ -170,29 +173,55 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_lse = 0; }; - struct FmhaFwdCommonDropoutKargs + struct FmhaFwdDropoutSeedOffset { - void init_dropout(const float p_drop, - const std::tuple& drop_seed_offset) + template + union ValueOrPointer + { + T val; + const T* ptr; + }; + + ValueOrPointer drop_seed; + ValueOrPointer drop_offset; + bool is_drop_seed_offset_from_host; + }; + + struct FmhaFwdCommonDropoutKargs : FmhaFwdDropoutSeedOffset + { + void init_dropout(float p_drop, uint64_t seed, uint64_t offset) { float p_undrop = 1.0 - p_drop; p_undrop_in_uint8_t = uint8_t(std::floor(p_undrop * std::numeric_limits::max())); rp_undrop = 1.0 / p_undrop; - drop_seed = std::get<0>(drop_seed_offset); - drop_offset = std::get<1>(drop_seed_offset); + this->drop_seed.val = seed; + this->drop_offset.val = offset; + this->is_drop_seed_offset_from_host = true; } + + void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr) + { + float p_undrop = 1.0 - p_drop; + p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + rp_undrop = 1.0 / p_undrop; + + this->drop_seed.ptr = seed_ptr; + this->drop_offset.ptr = offset_ptr; + this->is_drop_seed_offset_from_host = false; + } + float rp_undrop = 1; uint8_t p_undrop_in_uint8_t = 
std::numeric_limits::max(); bool is_store_randval = false; - uint64_t drop_seed = 1; - uint64_t drop_offset = 0; void* rand_val_ptr = nullptr; ck_tile::index_t stride_randval = 0; ck_tile::index_t nhead_stride_randval = 0; }; + struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs { ck_tile::index_t batch_stride_randval = 0; @@ -278,7 +307,8 @@ struct FmhaFwdKernel ck_tile::index_t mask_type, float p_drop, bool s_randval, - const std::tuple& drop_seed_offset) + std::variant, std::pair> + drop_seed_offset) { Kargs kargs{{q_ptr, k_ptr, @@ -344,7 +374,19 @@ struct FmhaFwdKernel } if constexpr(kHasDropout) { - kargs.init_dropout(p_drop, drop_seed_offset); + if(drop_seed_offset.index() == 0) // seed & offset come from host + { + const auto& [seed, offset] = std::get<0>(drop_seed_offset); + kargs.init_dropout(p_drop, seed, offset); + } + else // seed & offset come from device + { + const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset); + kargs.init_dropout(p_drop, + reinterpret_cast(seed_ptr), + reinterpret_cast(offset_ptr)); + } + kargs.rand_val_ptr = rand_val_ptr; kargs.stride_randval = stride_randval; kargs.nhead_stride_randval = nhead_stride_randval; @@ -392,7 +434,8 @@ struct FmhaFwdKernel ck_tile::index_t mask_type, float p_drop, bool s_randval, - const std::tuple& drop_seed_offset) + std::variant, std::pair> + drop_seed_offset) { Kargs kargs{{q_ptr, k_ptr, @@ -455,7 +498,19 @@ struct FmhaFwdKernel } if constexpr(kHasDropout) { - kargs.init_dropout(p_drop, drop_seed_offset); + if(drop_seed_offset.index() == 0) // seed & offset come from host + { + const auto& [seed, offset] = std::get<0>(drop_seed_offset); + kargs.init_dropout(p_drop, seed, offset); + } + else // seed & offset come from device + { + const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset); + kargs.init_dropout(p_drop, + reinterpret_cast(seed_ptr), + reinterpret_cast(offset_ptr)); + } + kargs.rand_val_ptr = rand_val_ptr; kargs.stride_randval = stride_randval; 
kargs.nhead_stride_randval = nhead_stride_randval; @@ -748,8 +803,10 @@ struct FmhaFwdKernel return BlockDropout{i_batch_, i_nhead_, kargs.num_head_q, - kargs.drop_seed, - kargs.drop_offset, + kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val + : *kargs.drop_seed.ptr, + kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val + : *kargs.drop_offset.ptr, kargs.rp_undrop, kargs.p_undrop_in_uint8_t, kargs.is_store_randval}; From 0023f01ab02b9cc05a98ae1a7753df1481252e4d Mon Sep 17 00:00:00 2001 From: rocking Date: Mon, 7 Oct 2024 14:25:53 +0800 Subject: [PATCH 06/24] [Ck tile] Support layernorm one pass (#1512) * Fix compile error * Add one pass pipeline * Extract creating tile_window to operator() * clang format * reduce duplicated code * do not hardcode * Support padding in layernorm --------- Co-authored-by: Po Yen Chen --- .../02_layernorm2d/layernorm2d_fwd.cpp | 4 +- .../kernel/layernorm2d_fwd_kernel.hpp | 333 +++++++++++++----- .../block_layernorm2d_fwd_problem.hpp | 22 +- 3 files changed, 263 insertions(+), 96 deletions(-) diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp index 9cbd286104..35f291e060 100644 --- a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp @@ -35,7 +35,9 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, YDataType, MeanDataType, InvStdDataType, - Shape>; + Shape, + true, + true>; using Kernel = ck_tile::Layernorm2dFwd; diff --git a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp index 4be3e56874..468df793da 100644 --- a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp +++ b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp @@ -31,8 +31,14 @@ struct Layernorm2dFwd static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock; static constexpr ck_tile::index_t kNPerBlock = 
Problem::BlockShape::kNPerBlock; + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp; + static constexpr ck_tile::index_t kNPerThread = Problem::BlockShape::kNPerThread; + + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; struct Kargs { @@ -96,19 +102,25 @@ struct Layernorm2dFwd sequence<2>>{}); } - template - CK_TILE_DEVICE static constexpr auto GetNPerThread(Dstr) + CK_TILE_DEVICE static int GetWelfordMaxCount(int N) { - constexpr auto nDstrSpan = Dstr::get_distributed_spans().template at<1>(); + constexpr ck_tile::index_t kNThreadPerBlock = kNPerBlock / kNPerThread; - using Lengths = decltype(nDstrSpan.impl_); + int thread_id_n = get_thread_id() % kNThreadPerBlock; + int max_count = + __builtin_amdgcn_readfirstlane(N < kNPerBlock ? 0 : kNPerThread * (N / kNPerBlock)); + int n_per_block_tail_loop = + __builtin_amdgcn_readfirstlane(N - max_count * kNThreadPerBlock); - ck_tile::index_t ret = 1; + if(n_per_block_tail_loop > 0) + { + int thread_max_n = (thread_id_n + 1) * kNPerThread; + int delta = thread_max_n - n_per_block_tail_loop; + delta = clamp(thread_max_n - n_per_block_tail_loop, 0, kNPerThread); + max_count += kNPerThread - delta; + } - ck_tile::static_for<0, Lengths::size(), 1>{}( - [&](auto idx) { ret *= Lengths::template at(idx); }); - - return ret; + return max_count; } template @@ -129,42 +141,29 @@ struct Layernorm2dFwd return out_dstr_tensor; } - template - CK_TILE_DEVICE std::enable_if_t TwoPassLayernorm2dFwd(const XDataType* p_x, - const GammaDataType* p_gamma, - const BetaDataType* p_beta, - YDataType* p_y, - MeanDataType* p_mean, - InvStdDataType* p_invStd, - const ComputeDataType epsilon, - ck_tile::index_t M, - ck_tile::index_t N) const + template + CK_TILE_DEVICE std::enable_if_t + TwoPassLayernorm2dFwd(XBlockWindow& x_block_window, + GammaBlockWindow& gamma_block_window, + 
BetaBlockWindow& beta_block_window, + YBlockWindow& y_block_window, + MeanBlockWindow& mean_block_window, + InvStdBlockWindow& inv_std_block_window, + ComputeDataType epsilon, + ck_tile::index_t N) const { - constexpr auto I0 = number<0>{}; - constexpr auto I1 = number<1>{}; + // TODO - Optimize tail loop to reduce move_tile_window() + index_t num_n_tile_iteration = + __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, kNPerBlock)); - const auto x_m_n = make_naive_tensor_view( - p_x, make_tuple(M, N), make_tuple(N, 1), number<32>{}, number<1>{}); - - const auto gamma_n = make_naive_tensor_view( - p_gamma, make_tuple(N), make_tuple(1), number<32>{}, number<1>{}); - - const auto beta_n = make_naive_tensor_view( - p_beta, make_tuple(N), make_tuple(1), number<32>{}, number<1>{}); - - const auto iM = get_block_id() * kMPerBlock; - - constexpr auto xDstr = MakeXBlockTileDistribution(); - - auto x_block_window = make_tile_window( - x_m_n, make_tuple(number{}, number{}), {iM, 0}, xDstr); - - index_t num_n_tile_iteration = __builtin_amdgcn_readfirstlane(N / kNPerBlock); - - // TODO: padding - handle max_count if N % kNPerBlock != 0 - constexpr auto NPerThread = GetNPerThread(xDstr); - ThreadWelford thread_welford{ - type_convert(NPerThread * N / kNPerBlock)}; + int welford_max_count = GetWelfordMaxCount(N); + ThreadWelford thread_welford{welford_max_count}; using XTensorType = decltype(load_tile(x_block_window)); auto mean_compute_block_tensor = @@ -190,44 +189,14 @@ struct Layernorm2dFwd auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon); if constexpr(kSaveMean) - { - const auto mean_m = make_naive_tensor_view_packed( - p_mean, make_tuple(M), number<32>{}); - - auto mean_block_window = - make_tile_window(mean_m, make_tuple(number{}), {iM}); - store_tile(mean_block_window, cast_tile(mean_compute_block_tensor)); - } if constexpr(kSaveInvStd) - { - const auto inv_std_m = make_naive_tensor_view_packed( - p_invStd, make_tuple(M), number<32>{}); - - 
auto inv_std_block_window = - make_tile_window(inv_std_m, make_tuple(number{}), {iM}); - - store_tile(inv_std_block_window, cast_tile(inv_std_compute_block_tensor)); - } - - // TODO: Extract normalize pipeline - const auto y_m_n = make_naive_tensor_view( - p_y, make_tuple(M, N), make_tuple(N, 1), number<32>{}, number<1>{}); - - auto y_block_window = make_tile_window( - y_m_n, make_tuple(number{}, number{}), {iM, 0}); - - constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution(); - constexpr auto betaDstr = gammaDstr; - - auto gamma_block_window = - make_tile_window(gamma_n, make_tuple(number{}), {0}, gammaDstr); - - auto beta_block_window = make_tile_window( - beta_n, make_tuple(number{}, number{}), {0}, betaDstr); + store_tile(inv_std_block_window, + cast_tile(inv_std_compute_block_tensor)); // reverse read x to reuse cache - ck_tile::index_t stride_to_right_most_window = N - kNPerBlock; + ck_tile::index_t stride_to_right_most_window = + N % kNPerBlock == 0 ? N - kNPerBlock : N - N % kNPerBlock; move_tile_window(x_block_window, {0, -kNPerBlock}); move_tile_window(gamma_block_window, {stride_to_right_most_window}); @@ -274,17 +243,209 @@ struct Layernorm2dFwd } } + template + CK_TILE_DEVICE std::enable_if_t + OnePassLayernorm2dFwd(XBlockWindow& x_block_window, + GammaBlockWindow& gamma_block_window, + BetaBlockWindow& beta_block_window, + YBlockWindow& y_block_window, + MeanBlockWindow& mean_block_window, + InvStdBlockWindow& inv_std_block_window, + ComputeDataType epsilon, + ck_tile::index_t N) const + { + int welford_max_count = GetWelfordMaxCount(N); + ThreadWelford thread_welford{welford_max_count}; + + using XTensorType = decltype(load_tile(x_block_window)); + auto mean_compute_block_tensor = + thread_welford.template MakeInitialMeanVarDistributedTensor(); + auto var_compute_block_tensor = + thread_welford.template MakeInitialMeanVarDistributedTensor(); + + clear_tile(mean_compute_block_tensor); + clear_tile(var_compute_block_tensor); + + const auto 
x_block_tensor = load_tile(x_block_window); + thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor); + // TODO: support cross warp Welford + WarpMergeWelford{}( + mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_); + + auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon); + + if constexpr(kSaveMean) + store_tile(mean_block_window, cast_tile(mean_compute_block_tensor)); + if constexpr(kSaveInvStd) + store_tile(inv_std_block_window, + cast_tile(inv_std_compute_block_tensor)); + + // normalize + const auto gamma_block_tensor = load_tile(gamma_block_window); + const auto beta_block_tensor = load_tile(beta_block_window); + + constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans(); + + auto y_block_tensor = + make_static_distributed_tensor(x_block_tensor.get_tile_distribution()); + + sweep_tile_span(x_spans[I1], [&](auto idx1) { + constexpr auto j_idx = make_tuple(idx1); + const auto gamma = type_convert(gamma_block_tensor[j_idx]); + const auto beta = type_convert(beta_block_tensor[j_idx]); + + sweep_tile_span(x_spans[I0], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + const auto mean = mean_compute_block_tensor[i_idx]; + const auto inv_std = inv_std_compute_block_tensor[i_idx]; + + const auto x = type_convert(x_block_tensor[i_j_idx]); + auto y = (x - mean) * inv_std * gamma + beta; + + y_block_tensor(i_j_idx) = type_convert(y); + }); + }); + + store_tile(y_block_window, y_block_tensor); + } + CK_TILE_DEVICE void operator()(Kargs kargs) const { - TwoPassLayernorm2dFwd(static_cast(kargs.p_x), - static_cast(kargs.p_gamma), - static_cast(kargs.p_beta), - static_cast(kargs.p_y), - static_cast(kargs.p_mean), - static_cast(kargs.p_invStd), - static_cast(kargs.epsilon), - kargs.M, - kargs.N); + const auto x_m_n = [&]() { + const auto x_dram_naive = make_naive_tensor_view( + static_cast(kargs.p_x), + 
make_tuple(kargs.M, kargs.N), + make_tuple(kargs.N, 1), + number{}, + number<1>{}); + + return pad_tensor_view(x_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + const auto gamma_n = [&]() { + const auto gamma_dram_naive = make_naive_tensor_view( + static_cast(kargs.p_gamma), + make_tuple(kargs.N), + make_tuple(1), + number{}, + number<1>{}); + + return pad_tensor_view( + gamma_dram_naive, make_tuple(number{}), sequence{}); + }(); + + const auto beta_n = [&]() { + const auto gamma_dram_naive = make_naive_tensor_view( + static_cast(kargs.p_beta), + make_tuple(kargs.N), + make_tuple(1), + number{}, + number<1>{}); + + return pad_tensor_view( + gamma_dram_naive, make_tuple(number{}), sequence{}); + }(); + + const auto iM = get_block_id() * kMPerBlock; + + constexpr auto xDstr = MakeXBlockTileDistribution(); + + auto x_block_window = make_tile_window( + x_m_n, make_tuple(number{}, number{}), {iM, 0}, xDstr); + + const auto y_m_n = [&]() { + const auto y_dram_naive = make_naive_tensor_view( + static_cast(kargs.p_y), + make_tuple(kargs.M, kargs.N), + make_tuple(kargs.N, 1), + number{}, + number<1>{}); + + return pad_tensor_view(y_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto y_block_window = make_tile_window( + y_m_n, make_tuple(number{}, number{}), {iM, 0}); + + constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution(); + constexpr auto betaDstr = gammaDstr; + + auto gamma_block_window = + make_tile_window(gamma_n, make_tuple(number{}), {0}, gammaDstr); + + auto beta_block_window = make_tile_window( + beta_n, make_tuple(number{}, number{}), {0}, betaDstr); + + auto mean_block_window = [&]() { + if constexpr(kSaveMean) + { + const auto mean_m = [&]() { + const auto mean_dram_naive = + make_naive_tensor_view_packed( + static_cast(kargs.p_mean), + make_tuple(kargs.M), + number<1>{}); + + return pad_tensor_view( + mean_dram_naive, make_tuple(number{}), sequence{}); + }(); + + return make_tile_window(mean_m, 
make_tuple(number{}), {iM}); + } + else + return make_null_tile_window(make_tuple(number{})); + }(); + + auto inv_std_block_window = [&]() { + if constexpr(kSaveInvStd) + { + const auto inv_std_m = [&]() { + const auto inv_std_dram_naive = + make_naive_tensor_view_packed( + static_cast(kargs.p_invStd), + make_tuple(kargs.M), + number<1>{}); + + return pad_tensor_view( + inv_std_dram_naive, make_tuple(number{}), sequence{}); + }(); + + return make_tile_window(inv_std_m, make_tuple(number{}), {iM}); + } + else + return make_null_tile_window(make_tuple(number{})); + }(); + + if(kargs.N <= kNPerBlock) + OnePassLayernorm2dFwd(x_block_window, + gamma_block_window, + beta_block_window, + y_block_window, + mean_block_window, + inv_std_block_window, + static_cast(kargs.epsilon), + kargs.N); + else + TwoPassLayernorm2dFwd(x_block_window, + gamma_block_window, + beta_block_window, + y_block_window, + mean_block_window, + inv_std_block_window, + static_cast(kargs.epsilon), + kargs.N); } }; diff --git a/include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp b/include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp index 5206d36d7d..707a38f621 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp @@ -14,17 +14,21 @@ template + typename BlockShape_, + bool kPadM_, + bool kPadN_> struct BlockLayernorm2dFwdProblem { - using XDataType = remove_cvref_t; - using GammaDataType = remove_cvref_t; - using BetaDataType = remove_cvref_t; - using ComputeDataType = remove_cvref_t; - using YDataType = remove_cvref_t; - using MeanDataType = remove_cvref_t; - using InvStdDataType = remove_cvref_t; - using BlockShape = remove_cvref_t; + using XDataType = remove_cvref_t; + using GammaDataType = remove_cvref_t; + using BetaDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using YDataType = remove_cvref_t; + using MeanDataType 
= remove_cvref_t; + using InvStdDataType = remove_cvref_t; + using BlockShape = remove_cvref_t; + static constexpr bool kPadM = kPadM_; + static constexpr bool kPadN = kPadN_; }; } // namespace ck_tile From cc8f466a7ecbdca058aa9b8aeb2c75c57864d2ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Mon, 7 Oct 2024 15:21:21 +0200 Subject: [PATCH 07/24] [CK_TILE] Fix conv param multiple definition (#1550) Co-authored-by: Po Yen Chen --- include/ck_tile/host/convolution_parameter.hpp | 6 ------ 1 file changed, 6 deletions(-) diff --git a/include/ck_tile/host/convolution_parameter.hpp b/include/ck_tile/host/convolution_parameter.hpp index 741a25ad73..81ea51a94f 100644 --- a/include/ck_tile/host/convolution_parameter.hpp +++ b/include/ck_tile/host/convolution_parameter.hpp @@ -13,7 +13,6 @@ namespace conv { struct ConvParam { - ConvParam(); ConvParam(ck_tile::index_t n_dim, ck_tile::index_t group_count, ck_tile::index_t n_batch, @@ -199,11 +198,6 @@ struct ConvParam } }; -ConvParam::ConvParam() - : ConvParam::ConvParam(2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}) -{ -} - CK_TILE_HOST std::string get_conv_param_parser_helper_msg() { std::string msg; From 7d8ea5f08bfea303b978c3fcb4f5b7069985b0ff Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Mon, 7 Oct 2024 08:18:23 -0700 Subject: [PATCH 08/24] Fix build logic using GRU_ARCHS. 
(#1536) * update build logic with GPU_ARCHS * fix the GPU_ARCHS build for codegen * unset GPU_TARGETS when GPU_ARCHS are set --- CMakeLists.txt | 101 +++++++----------- Jenkinsfile | 4 +- README.md | 11 +- codegen/test/CMakeLists.txt | 3 +- example/CMakeLists.txt | 13 +-- include/ck/config.h.in | 7 -- .../gpu/CMakeLists.txt | 19 +--- profiler/src/CMakeLists.txt | 12 +-- test/CMakeLists.txt | 16 +-- 9 files changed, 64 insertions(+), 122 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index dc73b5f4d4..989995d0f5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -98,11 +98,6 @@ if(DL_KERNELS) set(CK_ENABLE_DL_KERNELS "ON") endif() -if(INSTANCES_ONLY) - add_definitions(-DINSTANCES_ONLY) - set(CK_ENABLE_INSTANCES_ONLY "ON") -endif() - include(getopt) # CK version file to record release version as well as git commit hash @@ -127,6 +122,12 @@ rocm_setup_version(VERSION ${version}) list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip "$ENV{ROCM_PATH}" "$ENV{HIP_PATH}") message("GPU_TARGETS= ${GPU_TARGETS}") +message("GPU_ARCHS= ${GPU_ARCHS}") +if(GPU_ARCHS) + #disable GPU_TARGETS to avoid conflicts, this needs to happen before we call hip package + unset(GPU_TARGETS CACHE) + unset(AMDGPU_TARGETS CACHE) +endif() find_package(hip) # No assumption that HIP kernels are launched with uniform block size for backward compatibility @@ -135,55 +136,38 @@ math(EXPR hip_VERSION_FLAT "(${hip_VERSION_MAJOR} * 1000 + ${hip_VERSION_MINOR}) message("hip_version_flat=${hip_VERSION_FLAT}") message("checking which targets are supported") -#This is the list of targets to be used in case GPU_TARGETS is not set on command line -#These targets will be filtered and only supported ones will be used -#Setting GPU_TARGETS on command line will override this list -if(NOT PROFILER_ONLY) - if(NOT ENABLE_ASAN_PACKAGING) - #build CK for all supported targets - if(NOT WIN32 AND ${hip_VERSION_FLAT} 
LESS 600300000) - # WORKAROUND: compiler does not yet fully support gfx12 targets, need to fix version above - rocm_check_target_ids(DEFAULT_GPU_TARGETS - TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102") - else() - rocm_check_target_ids(DEFAULT_GPU_TARGETS - TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201") - endif() +#In order to build just the CK library (without tests and examples) for all supported GPU targets +#use -D GPU_ARCHS="gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201" +#the GPU_TARGETS flag will be reset in this case in order to avoid conflicts. +# +#In order to build CK along with all tests and examples it should be OK to set GPU_TARGETS to just 1 or 2 similar architectures. +if(NOT ENABLE_ASAN_PACKAGING) + if(NOT WIN32 AND ${hip_VERSION_FLAT} LESS 600300000) + # WORKAROUND: compiler does not yet fully support gfx12 targets, need to fix version above + set(CK_GPU_TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102") else() - #build CK only for xnack-supported targets - rocm_check_target_ids(DEFAULT_GPU_TARGETS - TARGETS "gfx908:xnack+;gfx90a:xnack+;gfx940:xnack+;gfx941:xnack+;gfx942:xnack+") - set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE) + set(CK_GPU_TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201") endif() else() - add_definitions(-DPROFILER_ONLY) - set(GPU_TARGETS "" CACHE STRING "" FORCE) + #build CK only for xnack-supported targets when using ASAN + set(CK_GPU_TARGETS "gfx908:xnack+;gfx90a:xnack+;gfx940:xnack+;gfx941:xnack+;gfx942:xnack+") +endif() + +#if user set GPU_ARCHS on the cmake command line, overwrite default target list with user's list +#otherwise, if user set GPU_TARGETS, use that set of targets +if(GPU_ARCHS) + set(CK_GPU_TARGETS ${GPU_ARCHS}) +else() if(GPU_TARGETS) - message(FATAL_ERROR "For PROFILE_ONLY build, please do not set 
GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, gfx11 or gfx12") + set(CK_GPU_TARGETS ${GPU_TARGETS}) endif() - if(GPU_ARCH MATCHES "gfx90") - rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx908;gfx90a") - elseif(GPU_ARCH MATCHES "gfx94") - rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx940;gfx941;gfx942") - elseif(GPU_ARCH MATCHES "gfx10") - rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1030") - elseif(GPU_ARCH MATCHES "gfx11") - rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1100;gfx1101;gfx1102") - elseif(GPU_ARCH MATCHES "gfx12") - rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1200;gfx1201") - else() - message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, gfx11 or gfx12") - endif() - set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE) endif() -message("Supported GPU_TARGETS= ${DEFAULT_GPU_TARGETS}") +#make sure all the targets on the list are actually supported by the current compiler +rocm_check_target_ids(SUPPORTED_GPU_TARGETS + TARGETS ${CK_GPU_TARGETS}) -if(GPU_TARGETS) - message("Building CK for the following targets: ${GPU_TARGETS}") -else() - message("Building CK for the default targets: ${DEFAULT_GPU_TARGETS}") -endif() +message("Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}") if (GPU_TARGETS) if (GPU_TARGETS MATCHES "gfx9") @@ -557,8 +541,7 @@ ENDFOREACH() add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES}) add_subdirectory(library) -if(NOT DEFINED INSTANCES_ONLY) - if(NOT DEFINED PROFILER_ONLY) +if(NOT GPU_ARCHS) rocm_package_setup_component(tests LIBRARY_NAME composablekernel PACKAGE_NAME tests # Prevent -static suffix on package name @@ -572,23 +555,15 @@ if(NOT DEFINED INSTANCES_ONLY) if(BUILD_TESTING) add_subdirectory(test) endif() - - rocm_package_setup_component(profiler - LIBRARY_NAME composablekernel - PACKAGE_NAME ckprofiler - ) - add_subdirectory(profiler) - else() - #When building 
PROFILER_ONLY, label the package with GPU_ARCH - rocm_package_setup_component(profiler - LIBRARY_NAME composablekernel - PACKAGE_NAME ckprofiler_${GPU_ARCH} - ) - add_subdirectory(profiler) - endif() endif() -if(NOT DEFINED PROFILER_ONLY AND (GPU_TARGETS MATCHES "gfx9" OR DEFINED INSTANCES_ONLY)) +rocm_package_setup_component(profiler + LIBRARY_NAME composablekernel + PACKAGE_NAME ckprofiler +) +add_subdirectory(profiler) + +if(GPU_TARGETS MATCHES "gfx9" OR GPU_ARCHS) add_subdirectory(codegen) endif() diff --git a/Jenkinsfile b/Jenkinsfile index 22468401dc..e61fb71e8e 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1138,8 +1138,8 @@ pipeline { execute_args = """ cmake -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_CXX_COMPILER="${build_compiler()}" \ -D CMAKE_BUILD_TYPE=Release \ - -D INSTANCES_ONLY=ON \ - -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j64 """ + -D GPU_ARCHS="gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201" \ + -D CMAKE_CXX_FLAGS=" -O3 " .. && make -j64 """ } steps{ buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args) diff --git a/README.md b/README.md index 4889914691..34ac0919ae 100644 --- a/README.md +++ b/README.md @@ -90,7 +90,12 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa ``` If you don't set `GPU_TARGETS` on the cmake command line, CK is built for all GPU targets - supported by the current compiler (this may take a long time). + supported by the current compiler (this may take a long time). + + NOTE: If you try setting `GPU_TARGETS` to a list of architectures, the build will only work if the + architectures are similar, e.g., `gfx908;gfx90a`, or `gfx1100;gfx1101;gfx11012`. Otherwise, if you + want to build the library for a list of different architectures, + you should use the `GPU_ARCHS` build argument, for example `GPU_ARCHS=gfx908;gfx1030;gfx1100;gfx942`. 4. 
Build the entire CK library: @@ -137,10 +142,6 @@ crash. In such cases, you can reduce the number of threads to 32 by using `-j32` Additional cmake flags can be used to significantly speed-up the build: -* `INSTANCES_ONLY` (default is OFF) must be set to ON in order to build only the instances and library - while skipping all tests, examples, and profiler. This is useful in cases when you plan to use CK as a - dependency and don't plan to run any examples or tests. - * `DTYPES` (default is not set) can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build instances of select data types only. The main default data types are fp32 and fp16; you can safely skip other data types. diff --git a/codegen/test/CMakeLists.txt b/codegen/test/CMakeLists.txt index 6dd130bc3f..1de612e49a 100644 --- a/codegen/test/CMakeLists.txt +++ b/codegen/test/CMakeLists.txt @@ -1,7 +1,8 @@ list(APPEND CMAKE_PREFIX_PATH /opt/rocm) add_subdirectory(rtc) file(GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp) -if(NOT INSTANCES_ONLY) +# do not build the tests when we build the library for various targets +if(NOT GPU_ARCHS) foreach(TEST_SRC ${TEST_SRCS}) set_source_files_properties(${TEST_SRC} PROPERTIES LANGUAGE HIP) get_filename_component(BASE_NAME ${TEST_SRC} NAME_WE) diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index f9e62a2356..ad3f7c787f 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -45,11 +45,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) endforeach() endif() - if(INSTANCES_ONLY) - set(EX_TARGETS ${DEFAULT_GPU_TARGETS}) - else() - set(EX_TARGETS ${GPU_TARGETS}) - endif() + set(EX_TARGETS ${SUPPORTED_GPU_TARGETS}) #Do not build any DL examples if DL_KERNELS not set foreach(source IN LISTS FILE_NAME) @@ -147,11 +143,8 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) endforeach() endif() - if(INSTANCES_ONLY) - set(EX_TARGETS ${DEFAULT_GPU_TARGETS}) - else() - set(EX_TARGETS ${GPU_TARGETS}) - endif() + set(EX_TARGETS 
${SUPPORTED_GPU_TARGETS}) + #Do not build any DL examples if DL_KERNELS not set foreach(source IN LISTS FILE_NAME) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") diff --git a/include/ck/config.h.in b/include/ck/config.h.in index eb9049b599..0f0b7bd607 100644 --- a/include/ck/config.h.in +++ b/include/ck/config.h.in @@ -97,13 +97,6 @@ #cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@ #endif -// -// Instances supports in the current CK build -// -#ifndef CK_ENABLE_INSTANCES_ONLY -#cmakedefine CK_ENABLE_INSTANCES_ONLY @CK_ENABLE_INSTANCES_ONLY@ -#endif - // // CK kernels which support XDL (MI series) // diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index bc66fe0bed..f82176ffc6 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -37,11 +37,7 @@ function(add_instance_library INSTANCE_NAME) endforeach() endif() - if(INSTANCES_ONLY) - set(INST_TARGETS ${DEFAULT_GPU_TARGETS}) - else() - set(INST_TARGETS ${GPU_TARGETS}) - endif() + set(INST_TARGETS ${SUPPORTED_GPU_TARGETS}) # Do not build DL instances if DL_KERNELS macro is not set foreach(source IN LISTS ARGN) @@ -75,11 +71,7 @@ function(add_instance_library INSTANCE_NAME) if(ARGN) set(INST_OBJ) foreach(source IN LISTS ARGN) - if(INSTANCES_ONLY) - set(INST_TARGETS ${DEFAULT_GPU_TARGETS}) - else() - set(INST_TARGETS ${GPU_TARGETS}) - endif() + set(INST_TARGETS ${SUPPORTED_GPU_TARGETS}) if(source MATCHES "_xdl") list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) elseif(ARGN MATCHES "_wmma") @@ -191,12 +183,7 @@ FOREACH(subdir_path ${dir_list}) set(add_inst 1) endif() - if(INSTANCES_ONLY) - set(INST_TARGETS ${DEFAULT_GPU_TARGETS}) - else() - set(INST_TARGETS ${GPU_TARGETS}) - endif() - + set(INST_TARGETS ${SUPPORTED_GPU_TARGETS}) if(("${cmake_instance}" MATCHES "quantization") AND (DEFINED DTYPES) AND 
(NOT DTYPES MATCHES "int8")) message("quantization instances will not be built!") diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index e9528baeb6..7d4df3cf9b 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -24,7 +24,7 @@ set(PROFILER_SOURCES profile_permute_scale.cpp ) -if(GPU_TARGETS MATCHES "gfx9") +if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp) list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp) @@ -49,7 +49,7 @@ if(GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp) endif() list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp) - if(GPU_TARGETS MATCHES "gfx94") + if(SUPPORTED_GPU_TARGETS MATCHES "gfx94") list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp) list(APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp) endif() @@ -69,7 +69,7 @@ if(GPU_TARGETS MATCHES "gfx9") endif() -if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12" OR GPU_TARGETS MATCHES "gfx9") +if(SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx9") if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp) endif() @@ -111,7 +111,7 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_inst target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_transpose_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_permute_scale_instance) -if(GPU_TARGETS MATCHES "gfx9") +if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance) @@ -135,7 +135,7 @@ 
if(GPU_TARGETS MATCHES "gfx9") target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance) - if(GPU_TARGETS MATCHES "gfx94") + if(SUPPORTED_GPU_TARGETS MATCHES "gfx94") target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance) endif() @@ -159,7 +159,7 @@ if(GPU_TARGETS MATCHES "gfx9") target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convinvscale_instance) endif() -if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12") +if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) endif() diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index e61d937f08..b836dd687e 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -41,11 +41,7 @@ function(add_test_executable TEST_NAME) endforeach() endif() - if(INSTANCES_ONLY) - set(TEST_TARGETS ${DEFAULT_GPU_TARGETS}) - else() - set(TEST_TARGETS ${GPU_TARGETS}) - endif() + set(TEST_TARGETS ${SUPPORTED_GPU_TARGETS}) foreach(source IN LISTS ARGN) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") @@ -122,11 +118,7 @@ function(add_gtest_executable TEST_NAME) endforeach() endif() - if(INSTANCES_ONLY) - set(TEST_TARGETS ${DEFAULT_GPU_TARGETS}) - else() - set(TEST_TARGETS ${GPU_TARGETS}) - endif() + set(TEST_TARGETS ${SUPPORTED_GPU_TARGETS}) foreach(source IN LISTS ARGN) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") @@ -211,10 +203,10 @@ add_subdirectory(conv_tensor_rearrange) add_subdirectory(transpose) add_subdirectory(permute_scale) 
add_subdirectory(wrapper) -if(GPU_TARGETS MATCHES "gfx11") +if(SUPPORTED_GPU_TARGETS MATCHES "gfx11") add_subdirectory(wmma_op) endif() -if(GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2) # smfmac needs ROCm6.2 +if(SUPPORTED_GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2) # smfmac needs ROCm6.2 add_subdirectory(smfmac_op) endif() add_subdirectory(position_embedding) From 7733ae167bf7b683a65da66148bd98a0bccd35eb Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Mon, 7 Oct 2024 15:45:19 -0700 Subject: [PATCH 09/24] add a CK_USE_CODEGEN build argument to enable codegen (#1552) * add a CK_USE_CODEGEN build argument to enable codegen * fix cmake codegen logic --- CMakeLists.txt | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 989995d0f5..6ad6307cb3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -97,6 +97,10 @@ if(DL_KERNELS) add_definitions(-DDL_KERNELS) set(CK_ENABLE_DL_KERNELS "ON") endif() +option(CK_USE_CODEGEN "Enable codegen library" OFF) +if(CK_USE_CODEGEN) + add_definitions(-DCK_USE_CODEGEN) +endif() include(getopt) @@ -563,7 +567,7 @@ rocm_package_setup_component(profiler ) add_subdirectory(profiler) -if(GPU_TARGETS MATCHES "gfx9" OR GPU_ARCHS) +if(CK_USE_CODEGEN AND (GPU_TARGETS MATCHES "gfx9" OR GPU_ARCHS)) add_subdirectory(codegen) endif() From 74d68e3b991dbfff7f14881a572bc77f4954c4fc Mon Sep 17 00:00:00 2001 From: Qianfeng Date: Tue, 8 Oct 2024 10:44:34 +0800 Subject: [PATCH 10/24] [CK_TILE] Simplify the codes in splitkv_combine pipeline (#1549) * Simplify the codes in splitkv_combine pipeline * Always set kPadSeqLenK=true for fmha splitkv kernels * Change in Oacc Alignment and TileDistribution to be more adaptable to tile sizes --------- Co-authored-by: Po Yen Chen --- .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 4 +- 
...lock_fmha_fwd_splitkv_combine_pipeline.hpp | 90 ++++++++++--------- ...plitkv_combine_pipeline_default_policy.hpp | 23 +++-- 3 files changed, 67 insertions(+), 50 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index ba826c8fb3..82cf3a5ab2 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -600,8 +600,8 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> # TODO: use async pipeline when compiler is more stable if hdim == 256 or hdim in [32, 64, 128]: # if True: - pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp index 842090afbe..1afe0feab3 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp @@ -172,22 +172,27 @@ struct BlockFmhaFwdSplitKVCombinePipeline lse_accum, sequence<1>{}, f_max, -numeric::infinity()); block_tile_reduce_sync(lse_max, f_max, bool_constant{}); - static const auto get_validated_m = [](LSEDataType raw_m) { - return raw_m == -numeric::infinity() ? 
type_convert(0.f) - : raw_m; - }; - decltype(lse_accum) lse_exp; { constexpr auto spans = decltype(lse_exp)::get_distributed_spans(); sweep_tile_span(spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); - sweep_tile_span(spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); + if(lse_max[i_idx] == -numeric::infinity()) + { + sweep_tile_span(spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); - lse_exp(i_j_idx) = - ck_tile::exp(lse_accum(i_j_idx) - get_validated_m(lse_max(i_idx))); - }); + lse_exp(i_j_idx) = ck_tile::type_convert(0.0f); + }); + } + else + { + sweep_tile_span(spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + lse_exp(i_j_idx) = ck_tile::exp(lse_accum(i_j_idx) - lse_max(i_idx)); + }); + } }); } @@ -201,15 +206,10 @@ struct BlockFmhaFwdSplitKVCombinePipeline sweep_tile_span(spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); - if(lse_sum(i_idx) == 0.f || lse_sum(i_idx) != lse_sum(i_idx)) - { - lse_logsum(i_idx) = numeric::infinity(); - } + if(lse_sum[i_idx] == ck_tile::type_convert(0.0f)) + lse_logsum(i_idx) = -numeric::infinity(); else - { - lse_logsum(i_idx) = - ck_tile::log(lse_sum(i_idx)) + get_validated_m(lse_max(i_idx)); - } + lse_logsum(i_idx) = ck_tile::log(lse_sum(i_idx)) + lse_max(i_idx); }); } @@ -218,37 +218,47 @@ struct BlockFmhaFwdSplitKVCombinePipeline constexpr auto spans = decltype(lse_accum)::get_distributed_spans(); sweep_tile_span(spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); - sweep_tile_span(spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); + if(lse_logsum(i_idx) == -numeric::infinity()) + { + sweep_tile_span(spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); - const auto x_indices = get_x_indices_from_distributed_indices( - lse_accum.get_tile_distribution(), 
i_j_idx); + const auto x_indices = get_x_indices_from_distributed_indices( + lse_accum.get_tile_distribution(), i_j_idx); - const auto col = x_indices.at(number<1>{}); - if(col < num_splits) - { - const auto row = x_indices.at(number<0>{}); + const auto col = x_indices.at(number<1>{}); + if(col < num_splits) + { + const auto row = x_indices.at(number<0>{}); - lse_acc_lds(row, col) = - ck_tile::exp(lse_accum(i_j_idx) - lse_logsum(i_idx)); - } - }); + lse_acc_lds(row, col) = ck_tile::type_convert(0.0f); + } + }); + } + else + { + sweep_tile_span(spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + const auto x_indices = get_x_indices_from_distributed_indices( + lse_accum.get_tile_distribution(), i_j_idx); + + const auto col = x_indices.at(number<1>{}); + if(col < num_splits) + { + const auto row = x_indices.at(number<0>{}); + + lse_acc_lds(row, col) = + ck_tile::exp(lse_accum(i_j_idx) - lse_logsum(i_idx)); + } + }); + } }); } block_sync_lds(); if constexpr(kStoreLSE) { - constexpr auto spans = decltype(lse_logsum)::get_distributed_spans(); - sweep_tile_span(spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - - if(lse_logsum(i_idx) == numeric::infinity()) - { - lse_logsum(i_idx) = -numeric::infinity(); - } - }); - store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum)); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp index 2eb092f055..3327d4af87 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp @@ -21,14 +21,23 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc() { using OaccDataType = remove_cvref_t; - 
return 16 / sizeof(OaccDataType); + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::kM0; + constexpr index_t kNPerBlock = Problem::kN1; + + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size()); + constexpr index_t N0 = get_warp_size() / M2; + constexpr index_t N1 = kNPerBlock / N0; + + return min(N1, static_cast(16 / sizeof(OaccDataType))); } template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO() { - using ODataType = remove_cvref_t; - return 16 / sizeof(ODataType); + return GetAlignmentOacc(); } template @@ -150,16 +159,14 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeOaccDramTileDistribution() { - using OaccDataType = remove_cvref_t; - constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = Problem::kM0; constexpr index_t kNPerBlock = Problem::kN1; - constexpr index_t N1 = 16 / sizeof(OaccDataType); - constexpr index_t N0 = kNPerBlock / N1; - constexpr index_t M2 = get_warp_size() / N0; constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size()); + constexpr index_t N0 = get_warp_size() / M2; + constexpr index_t N1 = kNPerBlock / N0; constexpr index_t M0 = kMPerBlock / (M2 * M1); return make_static_tile_distribution( From 0c094daa7e3fcc3c4b4a6d75c85c31f2925f02a8 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Tue, 8 Oct 2024 10:45:12 +0800 Subject: [PATCH 11/24] [CK_TILE] Update example README files & fix script compatibility issue (#1548) * Fix text alignment of ArgParser::print() * Update example README files * Clarify make-ck-dev.sh usage * Only keep some of the argument from '-?' 
output * Undo command line output changes in README * Only keep existing argument on doc and update description * Fix text alignment * Make cmake-ck-*.sh compatible with 'sh' command --- example/ck_tile/01_fmha/README.md | 45 ++++++++++++------------ example/ck_tile/02_layernorm2d/README.md | 3 +- example/ck_tile/03_gemm/README.md | 20 +++++++---- example/ck_tile/04_img2col/README.md | 3 +- include/ck_tile/host/arg_parser.hpp | 20 ++++++++--- script/cmake-ck-dev.sh | 3 +- script/cmake-ck-release.sh | 3 +- 7 files changed, 60 insertions(+), 37 deletions(-) diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index 0803d54d66..c7ab296c3b 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -6,7 +6,8 @@ This folder contains example for fmha(fused multi-head attention) using ck_tile ``` # in the root of ck_tile mkdir build && cd build -sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... +# you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank +sh ../script/cmake-ck-dev.sh ../ make tile_example_fmha_fwd -j ``` This will result in an executable `build/bin/tile_example_fmha_fwd` @@ -23,7 +24,7 @@ There are 3 template parameters for this kernel template. To speed up compile time, we instantiate the kernels into separate file. In this way we can benefit from parallel building from CMake/Make system. This is achieved by `generate.py` script. Besides, you can look into this script to learn how to instantiate a kernel instance step by step, which is described in `FMHA_FWD_KERNEL_BODY` variable. ## executable -`tile_example_fmha_fwd` is the example executable, implemented in `fmha_fwd.cpp`. You can type `./bin/tile_example_fmha_fwd -?` to list all supported args. Below is an example of the output (may subject to change) +`tile_example_fmha_fwd` is the example executable, implemented in `fmha_fwd.cpp`. 
You can type `./bin/tile_example_fmha_fwd -?` to list all the arguments. Below is an example of the output (may subject to change) ``` args: -v weather do CPU validation or not (default:1) @@ -31,48 +32,48 @@ args: -b batch size (default:2) -h num of head, for q (default:8) -h_k num of head, for k/v, -1 means equal to h (default:-1) - if not equal to h, then this is GQA/MQA case + if not equal to h, then this is GQA/MQA case -s seqlen_q. if group-mode, means the average value of seqlen_q (default:3328) - total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary - also with "-s=s0,s1,s2..." comma seperated int to set per batch seqlen(group-mode) - -s_k seqlen_k, -1 means equal to s (default:-1) + total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary + also with "-s=s0,s1,s2..." comma seperated int to set per batch seqlen(group-mode) + -s_k seqlen_k (including new key/value), -1 means equal to s (default:-1) -d head dim for q, k (default:128) -d_v head dim for v, -1 means equal to d (default:-1) -scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0) - note when squant=1, this value will be modified by range_q/k + note when squant=1, this value will be modified by range_q/k -range_q per-tensor quantization range of q. used if squant=1. (default:16) -range_k per-tensor quantization range of k. used if squant=1. (default:16) -range_v per-tensor quantization range of v. used if squant=1. (default:16) -range_p per-tensor quantization range of p [e^(s-m)]. used if squant=1. (default:1) -range_o per-tensor quantization range of o (p*v). used if squant=1. (default:16) -squant if using static quantization fusion or not. auto: fp8 will default use squant, other will not (default:auto) - 0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to P and O. 
- calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, range_p, range_o + 0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to P and O. + calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, range_p, range_o -iperm permute input (default:1) - if true, will be b*h*s*d, else b*s*h*d + if true, will be b*h*s*d, else b*s*h*d -operm permute output (default:1) -bias n or 0, no bias (default:n) - e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s - a(libi) or 2, alibi with 1*h. a:1, b*h + e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s + a(libi) or 2, alibi with 1*h. a:1, b*h -prec data type. fp16/bf16/fp8/bf8 (default:fp16) -mask 0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b') (default:0) - 't', top-left causal mask, 'b', bottom-r causal mask - 't:l,r', top-left sliding window attn(swa) with FA style left right size - 'b:l,r', bottom-r sliding window attn(swa) with FA style left right size - 'xt:window_size', xformer style masking from top-left, window_size negative is causal, positive is swa - 'xb:window_size', xformer style masking from bottom-r, window_size negative is causal, positive is swa - 'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for now) + 't', top-left causal mask, 'b', bottom-r causal mask + 't:l,r', top-left sliding window attn(swa) with FA style left right size + 'b:l,r', bottom-r sliding window attn(swa) with FA style left right size + 'xt:window_size', xformer style masking from top-left, window_size negative is causal, positive is swa + 'xb:window_size', xformer style masking from bottom-r, window_size negative is causal, positive is swa + 'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for now) -vlayout r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r) -lse 0 not store lse, 1 store lse (default:0) -kname if set to 1 will 
print kernel name (default:0) -init init method. ui, uniform random int, ni, normalized random int (default:uf) - uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, quantization + uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, quantization -seed random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939) + -drop_seed seed for random number generator (default:1) +-drop_offset offset for random number generator (default:0) + -drop_prefs seed and offset values are present on GPU; 0 - host, 1 - device/GPU (default:0) -warmup number of iterations before benchmark the kernel (default:5) -repeat number of iterations to benchmark the kernel (default:20) - -drop_seed seed for the random number generator for the dropout layer, default is 1 --drop_offset offset for the dropout layer which is used during random number generation, default is 0 - -drop_prefs flag to indicate `drop_seed` and `drop_offset` values if present on the GPU, default is 0, 0 - host, 1 - GPU ``` Example 1: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case. Example 2: `./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run a fmha case with diff --git a/example/ck_tile/02_layernorm2d/README.md b/example/ck_tile/02_layernorm2d/README.md index 433dad04e6..66b16c1b7f 100644 --- a/example/ck_tile/02_layernorm2d/README.md +++ b/example/ck_tile/02_layernorm2d/README.md @@ -6,7 +6,8 @@ This folder contains example for Layernorm2D forward using ck_tile tile-programm ``` # in the root of ck_tile mkdir build && cd build -sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... 
+# you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank +sh ../script/cmake-ck-dev.sh ../ make tile_example_layernorm2d_fwd -j ``` This will result in an executable `build/bin/tile_example_layernorm2d_fwd` diff --git a/example/ck_tile/03_gemm/README.md b/example/ck_tile/03_gemm/README.md index 00303bf62c..aacbdf6863 100644 --- a/example/ck_tile/03_gemm/README.md +++ b/example/ck_tile/03_gemm/README.md @@ -6,7 +6,8 @@ This folder contains example for GEMM using ck_tile tile-programming implementat ``` # in the root of ck_tile mkdir build && cd build -sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... +# you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank +sh ../script/cmake-ck-dev.sh ../ make tile_example_gemm_basic -j ``` This will result in an executable `build/bin/tile_example_gemm_basic` @@ -14,10 +15,17 @@ This will result in an executable `build/bin/tile_example_gemm_basic` ## example ``` args: - -m m dimension (default:3328) - -n m dimension (default:4096) + -b batch size (default:1) + -m m dimension (default:1024) + -n n dimension (default:2048) -k k dimension (default:64) - -e epsilon (default:1e-5) - -v cpu validation or not (default:1) - -prec precision (default:fp16) + -stride_a Tensor A stride (default:0) + -stride_b Tensor B stride (default:0) + -stride_c Tensor C stride (default:0) + -v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2) + -e Absolute error tolerance (default:1e-5) + -prec data type. 
fp16/bf16/fp8/bf8 (default:fp16) + -warmup number of iterations before benchmark the kernel (default:10) + -repeat number of iterations to benchmark the kernel (default:100) + -timer gpu:gpu timer, cpu:cpu timer (default:gpu) ``` diff --git a/example/ck_tile/04_img2col/README.md b/example/ck_tile/04_img2col/README.md index 6ae2cea5e5..df5c51a9c0 100644 --- a/example/ck_tile/04_img2col/README.md +++ b/example/ck_tile/04_img2col/README.md @@ -6,7 +6,8 @@ This folder contains example for Image to Column using ck_tile tile-programming ``` # in the root of ck_tile mkdir build && cd build -sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... +# you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank +sh ../script/cmake-ck-dev.sh ../ make tile_example_img2col -j ``` This will result in an executable `build/bin/tile_example_img2col` diff --git a/include/ck_tile/host/arg_parser.hpp b/include/ck_tile/host/arg_parser.hpp index 5f8a78b4c9..3765156df0 100644 --- a/include/ck_tile/host/arg_parser.hpp +++ b/include/ck_tile/host/arg_parser.hpp @@ -50,12 +50,22 @@ class ArgParser } return *this; } - void print() + void print() const { + // find max key length + std::string::size_type max_key_length = 11; + for(auto& key : keys) + { + if(max_key_length < key.length()) + { + max_key_length = key.length(); + } + } + printf("args:\n"); for(auto& key : keys) { - auto value = input_map[key]; + auto value = input_map.at(key); std::vector help_text_lines; size_t pos = 0; for(size_t next_pos = value.help_text.find('\n', pos); next_pos != std::string::npos;) @@ -69,8 +79,7 @@ class ArgParser std::string(value.help_text.begin() + pos, value.help_text.end())); std::string default_value = std::string("(default:") + value.value + std::string(")"); - - std::cout << std::setw(2) << std::setw(12 - value.name.length()) << "-" << key + std::cout << std::setw(1 + max_key_length - value.name.length()) << "-" << key << std::setw(4) << " 
" << help_text_lines[0] << " " << default_value << std::endl; @@ -78,7 +87,8 @@ class ArgParser help_next_line != help_text_lines.end(); ++help_next_line) { - std::cout << std::setw(17) << " " << *help_next_line << std::endl; + std::cout << std::setw(1 + max_key_length + 4) << " " << *help_next_line + << std::endl; } } } diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index 5dae86089a..4097ca98f6 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -7,7 +7,8 @@ MY_PROJECT_SOURCE=$1 if [ $# -ge 2 ] ; then GPU_TARGETS=$2 - REST_ARGS=${@:3} + shift 2 + REST_ARGS=$@ else GPU_TARGETS="gfx908;gfx90a;gfx940" REST_ARGS= diff --git a/script/cmake-ck-release.sh b/script/cmake-ck-release.sh index f65ec610dd..5e3f7faac2 100755 --- a/script/cmake-ck-release.sh +++ b/script/cmake-ck-release.sh @@ -7,7 +7,8 @@ MY_PROJECT_SOURCE=$1 if [ $# -ge 2 ] ; then GPU_TARGETS=$2 - REST_ARGS=${@:3} + shift 2 + REST_ARGS=$@ else GPU_TARGETS="gfx908;gfx90a;gfx940" REST_ARGS= From aa932445eae1d2d8a6abb6c8a78c3fc41489ecf9 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Tue, 8 Oct 2024 11:05:28 -0500 Subject: [PATCH 12/24] Add a gpu gemm reference kernel (#1528) * Add a gpu gemm reference kernel * Switch to gpu reference in gemm examples * Remove redundant arguments * Update all related examples * Update more examples * Try less threads per block * Try even less threads per block * Add support for all matrix layouts * Increase block size * Clean up * Remove hardcoded strides * Clean up * Try a column-major case * Revert back to row-major * Run both CPU and GPU veriffication --------- Co-authored-by: Po Yen Chen --- example/01_gemm/common.hpp | 33 +-- example/01_gemm/gemm_dl_fp16.cpp | 13 +- example/01_gemm/gemm_dl_fp32.cpp | 13 +- example/01_gemm/gemm_dl_int8.cpp | 13 +- example/01_gemm/gemm_dpp_fp16.cpp | 5 +- example/01_gemm/gemm_wmma_fp16.cpp | 13 +- example/01_gemm/gemm_xdl_bf16.cpp | 16 +- 
example/01_gemm/gemm_xdl_bf16_rtn.cpp | 16 +- example/01_gemm/gemm_xdl_fp16.cpp | 13 +- example/01_gemm/gemm_xdl_fp16_fp8.cpp | 13 +- example/01_gemm/gemm_xdl_fp16_v2.cpp | 13 +- example/01_gemm/gemm_xdl_fp64.cpp | 13 +- example/01_gemm/gemm_xdl_fp8.cpp | 14 + example/01_gemm/gemm_xdl_fp8_bf8.cpp | 13 +- example/01_gemm/gemm_xdl_int8.cpp | 13 +- .../01_gemm/gemm_xdl_lds_direct_load_fp16.cpp | 13 +- .../01_gemm/gemm_xdl_lds_direct_load_fp32.cpp | 13 +- example/01_gemm/gemm_xdl_streamk.cpp | 13 +- example/01_gemm/gemm_xdl_wavelet_fp16.cpp | 13 +- example/01_gemm/run_gemm_example.inc | 46 +++- .../gpu/reference_gemm.hpp | 245 ++++++++++++++++++ 21 files changed, 518 insertions(+), 39 deletions(-) create mode 100644 library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp diff --git a/example/01_gemm/common.hpp b/example/01_gemm/common.hpp index 3d8f4565cb..eb1738e760 100644 --- a/example/01_gemm/common.hpp +++ b/example/01_gemm/common.hpp @@ -21,6 +21,7 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/reference_tensor_operation/gpu/reference_gemm.hpp" struct ProblemSize final { @@ -28,9 +29,9 @@ struct ProblemSize final ck::index_t N = 4096; ck::index_t K = 4096; - ck::index_t StrideA = 4096; - ck::index_t StrideB = 4096; - ck::index_t StrideC = 4096; + ck::index_t StrideA = 0; + ck::index_t StrideB = 0; + ck::index_t StrideC = 0; }; struct ProblemSizeStreamK final @@ -39,9 +40,9 @@ struct ProblemSizeStreamK final ck::index_t N = 4096; ck::index_t K = 4096; - ck::index_t StrideA = 4096; - ck::index_t StrideB = 4096; - ck::index_t StrideC = 4096; + ck::index_t StrideA = 0; + ck::index_t StrideB = 0; + ck::index_t StrideC = 0; ck::index_t NumSKBlocks = -1; }; @@ -51,9 +52,9 @@ struct ProblemSizeStreamK_universal final ck::index_t N = 4096; ck::index_t K = 4096; - ck::index_t StrideA = 4096; - ck::index_t 
StrideB = 4096; - ck::index_t StrideC = 4096; + ck::index_t StrideA = 0; + ck::index_t StrideB = 0; + ck::index_t StrideC = 0; ck::index_t Grid_size = -1; // defaults to max occupancy ck::index_t Streamk_sel = 1; // defaults to 1-tile SK @@ -65,9 +66,9 @@ struct ProblemSizeSplitK final ck::index_t N = 4096; ck::index_t K = 4096; - ck::index_t StrideA = 4096; - ck::index_t StrideB = 4096; - ck::index_t StrideC = 4096; + ck::index_t StrideA = 0; + ck::index_t StrideB = 0; + ck::index_t StrideC = 0; ck::index_t KBatch = 1; }; @@ -125,7 +126,7 @@ bool parse_cmd_args(int argc, } else { - std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl + std::cerr << "arg1: verification (0=no, 1=CPU and GPU)" << std::endl << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl << "arg3: time kernel (0=no, 1=yes)" << std::endl @@ -175,7 +176,7 @@ bool parse_cmd_args(int argc, else { std::cerr - << "arg1: verification (0=no, 1=yes)" << std::endl + << "arg1: verification (0=no, 1=CPU and GPU)" << std::endl << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl << "arg3: time kernel (0=no, 1=yes)" << std::endl << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl @@ -224,7 +225,7 @@ bool parse_cmd_args(int argc, } else { - std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl + std::cerr << "arg1: verification (0=no, 1=CPU and GPU)" << std::endl << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl << "arg3: time kernel (0=no, 1=yes)" << std::endl @@ -274,7 +275,7 @@ bool parse_cmd_args(int argc, } else { - std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl + std::cerr << "arg1: verification (0=no, 1=CPU and GPU)" << std::endl << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl << "arg3: time kernel (0=no, 1=yes)" << std::endl diff --git a/example/01_gemm/gemm_dl_fp16.cpp b/example/01_gemm/gemm_dl_fp16.cpp index 
b5fecb9752..b9284b2783 100644 --- a/example/01_gemm/gemm_dl_fp16.cpp +++ b/example/01_gemm/gemm_dl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -32,6 +32,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_dl_fp32.cpp b/example/01_gemm/gemm_dl_fp32.cpp index 212b72f2a6..1684213641 100644 --- a/example/01_gemm/gemm_dl_fp32.cpp +++ b/example/01_gemm/gemm_dl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -32,6 +32,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_dl_int8.cpp b/example/01_gemm/gemm_dl_int8.cpp index 1840390aa9..1e64e9a0a3 100644 --- a/example/01_gemm/gemm_dl_int8.cpp +++ b/example/01_gemm/gemm_dl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include "common.hpp" @@ -32,6 +32,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_dpp_fp16.cpp b/example/01_gemm/gemm_dpp_fp16.cpp index 7a9e3f6186..30faf542dd 100644 --- a/example/01_gemm/gemm_dpp_fp16.cpp +++ b/example/01_gemm/gemm_dpp_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -34,6 +34,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDpp using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device:: + ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_wmma_fp16.cpp b/example/01_gemm/gemm_wmma_fp16.cpp index f8afe8d6db..28ab878ac3 100644 --- a/example/01_gemm/gemm_wmma_fp16.cpp +++ b/example/01_gemm/gemm_wmma_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include "common.hpp" @@ -68,6 +68,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_bf16.cpp b/example/01_gemm/gemm_xdl_bf16.cpp index 3cac55ef47..6cfff30dbd 100644 --- a/example/01_gemm/gemm_xdl_bf16.cpp +++ b/example/01_gemm/gemm_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -33,6 +33,20 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceComputeType = float; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_bf16_rtn.cpp b/example/01_gemm/gemm_xdl_bf16_rtn.cpp index cc14dcb8eb..108c100cbd 100644 --- a/example/01_gemm/gemm_xdl_bf16_rtn.cpp +++ b/example/01_gemm/gemm_xdl_bf16_rtn.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include "common.hpp" @@ -34,6 +34,20 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceComputeType = float; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp index 2338cdc9c1..07d51855d6 100644 --- a/example/01_gemm/gemm_xdl_fp16.cpp +++ b/example/01_gemm/gemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -47,6 +47,17 @@ using DeviceGemmInstance = DeviceGemmInstance1; using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16_fp8.cpp b/example/01_gemm/gemm_xdl_fp16_fp8.cpp index 979a200791..a996d034e6 100644 --- a/example/01_gemm/gemm_xdl_fp16_fp8.cpp +++ b/example/01_gemm/gemm_xdl_fp16_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include "common.hpp" @@ -42,6 +42,17 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16_v2.cpp b/example/01_gemm/gemm_xdl_fp16_v2.cpp index eba0ea9d11..ecd3b7be5d 100644 --- a/example/01_gemm/gemm_xdl_fp16_v2.cpp +++ b/example/01_gemm/gemm_xdl_fp16_v2.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -46,6 +46,17 @@ using DeviceGemmInstance = using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp64.cpp b/example/01_gemm/gemm_xdl_fp64.cpp index 8361576299..5afb3d1554 100644 --- a/example/01_gemm/gemm_xdl_fp64.cpp +++ b/example/01_gemm/gemm_xdl_fp64.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include "common.hpp" @@ -41,6 +41,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl BElementOp, CElementOp>; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp8.cpp b/example/01_gemm/gemm_xdl_fp8.cpp index fe41602301..3c75a44d21 100644 --- a/example/01_gemm/gemm_xdl_fp8.cpp +++ b/example/01_gemm/gemm_xdl_fp8.cpp @@ -37,6 +37,20 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceComputeType = float; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp8_bf8.cpp b/example/01_gemm/gemm_xdl_fp8_bf8.cpp index acc5fbc515..1dec165abd 100644 --- a/example/01_gemm/gemm_xdl_fp8_bf8.cpp +++ b/example/01_gemm/gemm_xdl_fp8_bf8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -44,6 +44,17 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_int8.cpp b/example/01_gemm/gemm_xdl_int8.cpp index cc03200b9d..3237f1a61c 100644 --- a/example/01_gemm/gemm_xdl_int8.cpp +++ b/example/01_gemm/gemm_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -33,6 +33,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_lds_direct_load_fp16.cpp b/example/01_gemm/gemm_xdl_lds_direct_load_fp16.cpp index d29cb74cd6..62037f7740 100644 --- a/example/01_gemm/gemm_xdl_lds_direct_load_fp16.cpp +++ b/example/01_gemm/gemm_xdl_lds_direct_load_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include @@ -53,6 +53,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_lds_direct_load_fp32.cpp b/example/01_gemm/gemm_xdl_lds_direct_load_fp32.cpp index e99249389e..75971bdecf 100644 --- a/example/01_gemm/gemm_xdl_lds_direct_load_fp32.cpp +++ b/example/01_gemm/gemm_xdl_lds_direct_load_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include @@ -52,6 +52,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_streamk.cpp b/example/01_gemm/gemm_xdl_streamk.cpp index 7d433b6145..5a02457daf 100644 --- a/example/01_gemm/gemm_xdl_streamk.cpp +++ b/example/01_gemm/gemm_xdl_streamk.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -44,6 +44,17 @@ using DeviceGemmInstance = DeviceGemmStreamK; using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_streamk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_wavelet_fp16.cpp b/example/01_gemm/gemm_xdl_wavelet_fp16.cpp index b0f963fee5..d8672f6a0c 100644 --- a/example/01_gemm/gemm_xdl_wavelet_fp16.cpp +++ b/example/01_gemm/gemm_xdl_wavelet_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include "common.hpp" @@ -37,6 +37,17 @@ using DeviceGemmInstance = DeviceGemmInstance; using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index a6f0d0bcfe..f66d2adc11 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -173,6 +173,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_ref_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; @@ -193,6 +194,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_ref_buf(sizeof(CDataType) * + c_m_n_device_ref_result.mDesc.GetElementSpaceSize()); a_m_k_device_buf.ToDevice(a_m_k.mData.data()); b_k_n_device_buf.ToDevice(b_k_n.mData.data()); @@ -325,14 +328,18 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << gemm.GetTypeString() << std::endl; + bool pass = true; + if(config.do_verification) { + // CPU verification auto ref_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_argument = 
ref_gemm.MakeArgument( a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + std::cout << "Running verification on CPU." << std::endl; ref_invoker.Run(ref_argument); #ifdef BUILD_INT4_EXAMPLE @@ -346,15 +353,42 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) #else c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); - return ck::utils::check_err(c_m_n_device_result, - c_m_n_host_result, - "Error: Incorrect results!", - get_rtol(), - get_atol()); + pass &= !ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); #endif + + // GPU verification + auto ref_gemm_gpu = ReferenceGemmInstanceGPU{}; + auto ref_invoker_gpu = ref_gemm_gpu.MakeInvoker(); + + auto ref_argument_gpu = ref_gemm_gpu.MakeArgument( + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_ref_buf.GetDeviceBuffer()), + M, + N, + K, + a_element_op, + b_element_op, + c_element_op); + + std::cout << "Running verification on GPU." << std::endl; + ref_invoker_gpu.Run(ref_argument_gpu, StreamConfig{}); + + c_m_n_device_ref_buf.FromDevice(c_m_n_device_ref_result.mData.data()); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= !ck::utils::check_err(c_m_n_device_result, + c_m_n_device_ref_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); } - return true; + return !pass; } bool run_gemm_example(int argc, char* argv[]) diff --git a/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp new file mode 100644 index 0000000000..639b5fe80f --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp @@ -0,0 +1,245 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + naive_gemm_kernel(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + CDataType* __restrict__ p_c_grid, + index_t m, + index_t n, + index_t k, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation c_element_op) +{ + using RowMajor = ck::tensor_layout::gemm::RowMajor; + + const int row_idx = blockIdx.x * blockDim.x + threadIdx.x; + const int col_idx = blockIdx.y * blockDim.y + threadIdx.y; + + if(row_idx < m && col_idx < n) + { + + AccDataType v_acc = static_cast(0.0); + ComputeTypeA v_a = static_cast(0.0); + ComputeTypeB v_b = static_cast(0.0); + CDataType v_c = static_cast(0.0); + + for(int k_idx = 0; k_idx < k; ++k_idx) + { + // check input matrices layout + int element_idx_a = 0; + int element_idx_b = 0; + if constexpr(std::is_same_v) + { + element_idx_a = row_idx * k + k_idx; + } + else + { + element_idx_a = row_idx + m * k_idx; + } + if constexpr(std::is_same_v) + { + element_idx_b = k_idx * n + col_idx; + } + else + { + element_idx_b = k_idx + k * col_idx; + } + // apply a_element_op + a_element_op(v_a, p_a_grid[element_idx_a]); + // apply b_element_op + b_element_op(v_b, p_b_grid[element_idx_b]); + // multiply and accumulate + v_acc += static_cast(v_a) * static_cast(v_b); + } + // apply c_element_op + c_element_op(v_c, v_acc); + // check output matrix layout + int element_idx_c = 0; + if constexpr(std::is_same_v) + { + element_idx_c = row_idx * n + col_idx; + } + else + { + element_idx_c = row_idx + m * col_idx; + } + // prepare output + p_c_grid[element_idx_c] = v_c; + } +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { 
+namespace device { + +template +struct ReferenceGemm : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + void* p_c_grid, + index_t m, + index_t n, + index_t k, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_c_grid_{static_cast(p_c_grid)}, + m_{m}, + n_{n}, + k_{k}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + + index_t m_; + index_t n_; + index_t k_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceGemm::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + int block_size = 16; + dim3 block_dim(block_size, block_size, 1); + dim3 grid_dim( + (arg.m_ + block_size - 1) / block_size, (arg.n_ + block_size - 1) / block_size, 1); + + auto launch_kernel = [&]() { + const auto kernel = naive_gemm_kernel; + + return launch_and_time_kernel(stream_config, + kernel, + grid_dim, + block_dim, + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.m_, + arg.n_, + arg.k_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + }; + + return launch_kernel(); + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const void* p_a_grid, + const void* p_b_grid, + void* p_c_grid, + index_t m, + index_t n, + index_t k, + 
AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{ + p_a_grid, p_b_grid, p_c_grid, m, n, k, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "Device Reference Gemm" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck From ceaed8e097cbba105e23f465c6226ab48e37a3a8 Mon Sep 17 00:00:00 2001 From: Christopher Millette <63608002+cgmillette@users.noreply.github.com> Date: Wed, 9 Oct 2024 01:41:35 -0600 Subject: [PATCH 13/24] Fixes small memory leak from missing hipEventDestroy (#1554) --- include/ck/host_utility/kernel_launch.hpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/include/ck/host_utility/kernel_launch.hpp b/include/ck/host_utility/kernel_launch.hpp index a616433ac9..962f89e479 100644 --- a/include/ck/host_utility/kernel_launch.hpp +++ b/include/ck/host_utility/kernel_launch.hpp @@ -66,6 +66,9 @@ float launch_and_time_kernel(const StreamConfig& stream_config, hip_check_error(hipEventElapsedTime(&total_time, start, stop)); + hip_check_error(hipEventDestroy(start)); + hip_check_error(hipEventDestroy(stop)); + return total_time / nrepeat; } else @@ -143,6 +146,9 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, hip_check_error(hipEventElapsedTime(&total_time, start, stop)); + hip_check_error(hipEventDestroy(start)); + hip_check_error(hipEventDestroy(stop)); + return total_time / nrepeat; } else From cfac9497e28a7489d5cde5bf2b4f40691dd5659c Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Wed, 9 Oct 2024 10:18:05 -0700 Subject: [PATCH 14/24] remove 
gfx12 targets from daily builds with rocm6.2 (#1560) --- Jenkinsfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Jenkinsfile b/Jenkinsfile index e61fb71e8e..a79ed859f2 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1138,7 +1138,7 @@ pipeline { execute_args = """ cmake -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_CXX_COMPILER="${build_compiler()}" \ -D CMAKE_BUILD_TYPE=Release \ - -D GPU_ARCHS="gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201" \ + -D GPU_ARCHS="gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102" \ -D CMAKE_CXX_FLAGS=" -O3 " .. && make -j64 """ } steps{ From 2e1165c1a73552dbacf08ccd351314ae95de14f7 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Wed, 9 Oct 2024 15:21:57 -0700 Subject: [PATCH 15/24] fix the target selection logic (#1561) --- CMakeLists.txt | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6ad6307cb3..3f22bb4b61 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -132,7 +132,11 @@ if(GPU_ARCHS) unset(GPU_TARGETS CACHE) unset(AMDGPU_TARGETS CACHE) endif() - +if(GPU_TARGETS) + set(USER_GPU_TARGETS 1) +else() + set(USER_GPU_TARGETS 0) +endif() find_package(hip) # No assumption that HIP kernels are launched with uniform block size for backward compatibility # SWDEV-413293 and https://reviews.llvm.org/D155213 @@ -162,7 +166,7 @@ endif() if(GPU_ARCHS) set(CK_GPU_TARGETS ${GPU_ARCHS}) else() - if(GPU_TARGETS) + if(USER_GPU_TARGETS) set(CK_GPU_TARGETS ${GPU_TARGETS}) endif() endif() From 6f27bc987248633255cc400437bd017dca70cf1e Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Thu, 10 Oct 2024 03:02:22 -0700 Subject: [PATCH 16/24] Ck tile gemm cshuffle & CK Tile GEMM restructure (#1535) * Make the cshuffle compilable * modify the reference on gpu and cpu.
Correct access of cshuffle * fix the cpu reference code * Complete the in tile shuffle logic * restructure the kernel template input * change the naming pattern of ck_tile gemm pipeline * Re-format files using remod.py * Solve the fmha conflict with gemm * Comment Addressed from Carlus --------- Co-authored-by: Po Yen, Chen --- example/ck_tile/03_gemm/gemm_basic.cpp | 55 ++++-- .../ck_tile/core/container/thread_buffer.hpp | 2 +- .../ck_tile/host/reference/reference_gemm.hpp | 47 ++++- include/ck_tile/ops/epilogue.hpp | 1 + .../ops/epilogue/cshuffle_epilogue.hpp | 171 ++++++++++++++++++ ...block_fmha_bwd_pipeline_default_policy.hpp | 133 ++++++++------ ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 81 +++++---- include/ck_tile/ops/gemm.hpp | 11 +- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 15 +- ... => gemm_pipeline_agmem_bgmem_creg_v1.hpp} | 10 +- ...ne_agmem_bgmem_creg_v1_default_policy.hpp} | 4 +- ... => gemm_pipeline_agmem_bgmem_creg_v2.hpp} | 6 +- ...ne_agmem_bgmem_creg_v2_default_policy.hpp} | 9 +- ..._problem.hpp => gemm_pipeline_problem.hpp} | 17 +- .../ops/gemm/pipeline/tile_gemm_traits.hpp | 27 +++ 15 files changed, 447 insertions(+), 142 deletions(-) create mode 100644 include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp rename include/ck_tile/ops/gemm/pipeline/{block_gemm_pipeline_agmem_bgmem_creg_v1.hpp => gemm_pipeline_agmem_bgmem_creg_v1.hpp} (95%) rename include/ck_tile/ops/gemm/pipeline/{block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp => gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp} (99%) rename include/ck_tile/ops/gemm/pipeline/{block_gemm_pipeline_agmem_bgmem_creg_v2.hpp => gemm_pipeline_agmem_bgmem_creg_v2.hpp} (97%) rename include/ck_tile/ops/gemm/pipeline/{block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp => gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp} (57%) rename include/ck_tile/ops/gemm/pipeline/{block_gemm_pipeline_problem.hpp => gemm_pipeline_problem.hpp} (65%) create mode 100644
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 9f790f6acb..e3c8d72590 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -41,18 +41,39 @@ template ; - using GemmEpilogue = ck_tile::Default2DEpilogue< - ck_tile::Default2DEpilogueProblem>; + + // The rank and permutation will also be generate out by the CodeGen part. + constexpr ck_tile::index_t kOutputRank = 2; + + // Whether doing the CShuffle (transpose before the global memory), depending on the output + // layout. + constexpr bool CShuffleEpilogue = + std::is_same_v; + + using GemmEpilogue = std::conditional_t< + CShuffleEpilogue, + ck_tile::CShuffleEpilogue>, + ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem>>; // ToDo: Will add the codegen part to test different pipeline policies in GEMM. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. - using Kernel = - ck_tile::GemmKernel; + using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKargs(args.p_a, args.p_b, @@ -255,15 +276,13 @@ int main(int argc, char* argv[]) ck_tile::sequence, ck_tile::sequence>; - using CodegenPipelineProblem = ck_tile::BlockGemmPipelineProblem; + using CodegenGemmTraits = ck_tile:: + TileGemmTraits; - using CodegenGemmPipeline = ck_tile::BlockGemmPipelineAGmemBGmemCRegV1; + using CodegenPipelineProblem = ck_tile:: + GemmPipelineProblem; + + using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; invoke_gemm c_host_gpu_ref(c_dimensions); ck_tile::DeviceMem c_gpu_buf(c_host_gpu_ref.get_element_space_size_in_bytes()); - ck_tile::reference_gemm_gpu( + ck_tile::reference_gemm_gpu( a_buf, b_buf, c_gpu_buf, M, N, K, stride_a, stride_b, stride_c); c_buf.FromDevice(c_host_gpu_ref.data()); diff --git a/include/ck_tile/core/container/thread_buffer.hpp b/include/ck_tile/core/container/thread_buffer.hpp index a7dad5233b..279a48acb3 100644 --- 
a/include/ck_tile/core/container/thread_buffer.hpp +++ b/include/ck_tile/core/container/thread_buffer.hpp @@ -58,7 +58,7 @@ struct thread_buffer { template CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); } template CK_TILE_HOST_DEVICE constexpr auto& at(number) { return get(I); } template CK_TILE_HOST_DEVICE constexpr const auto& at(number) const { return get(I); } - + template ::value, bool>::type = false> CK_TILE_HOST_DEVICE constexpr auto _get_as() const diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index a0ddd02d9e..a496c91e00 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -27,7 +27,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k, const BElementOp& b_element_op = {}, const ACCElementOp& acc_element_op = {}) { - const int N = b_n_k.mDesc.get_lengths()[0]; + const int N = (std::is_same_v) + ? b_n_k.mDesc.get_lengths()[0] + : b_n_k.mDesc.get_lengths()[1]; const int K = (std::is_same_v) ? a_m_k.mDesc.get_lengths()[1] : a_m_k.mDesc.get_lengths()[0]; @@ -45,20 +47,31 @@ CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k, ADataType v_a = (std::is_same_v) ? a_element_op(a_m_k(m, k)) : a_element_op(a_m_k(k, m)); - BDataType v_b = b_element_op(b_n_k(n, k)); + BDataType v_b = (std::is_same_v) + ? b_element_op(b_n_k(n, k)) + : b_element_op(b_n_k(k, n)); v_acc += ck_tile::type_convert(v_a) * ck_tile::type_convert(v_b); } - c_m_n(m, n) = ck_tile::type_convert(acc_element_op(v_acc)); + CDataType& c_ref = (std::is_same_v) + ? 
c_m_n(m, n) + : c_m_n(n, m); + c_ref = ck_tile::type_convert(acc_element_op(v_acc)); } }; make_ParallelTensorFunctor(f, M)(std::thread::hardware_concurrency()); } -template +template __global__ void naive_gemm_kernel(ADataType* A, BDataType* B, CDataType* C, @@ -76,18 +89,32 @@ __global__ void naive_gemm_kernel(ADataType* A, if(row < M && col < N) { AccDataType acc = 0.0; - for(int k = 0; k < K; ++k) { - acc += static_cast(A[row * strideA + k]) * - static_cast(B[col * strideB + k]); + // Adjust indexing based on matrix layout + int a_index = (std::is_same_v) + ? row * strideA + k + : k * strideA + row; + int b_index = (std::is_same_v) + ? col * strideB + k + : k * strideB + col; + acc += static_cast(A[a_index]) * static_cast(B[b_index]); } - C[row * strideC + col] = acc; // Store as AccDataType + int c_index = (std::is_same_v) + ? row * strideC + col + : col * strideC + row; + C[c_index] = acc; } } -template +template void reference_gemm_gpu(DeviceMem& a_device, DeviceMem& b_device, DeviceMem& c_device, @@ -145,7 +172,7 @@ void reference_gemm_gpu(DeviceMem& a_device, int numThreadsPerBlock = 256; // Common choice for threads per block int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock; - naive_gemm_kernel + naive_gemm_kernel <<>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c); errC = hipMemcpy( c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost); diff --git a/include/ck_tile/ops/epilogue.hpp b/include/ck_tile/ops/epilogue.hpp index 388f52c898..a98f60b364 100644 --- a/include/ck_tile/ops/epilogue.hpp +++ b/include/ck_tile/ops/epilogue.hpp @@ -3,5 +3,6 @@ #pragma once +#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp new file mode 100644 index 0000000000..9625b137bd --- /dev/null 
+++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -0,0 +1,171 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +#define CK_TILE_MAX_RANK 5 + +namespace ck_tile { + +// this epilogue aiming to store a matrix with different layout from the shared memory to the global +// memory. +template +struct CShuffleEpilogueProblem +{ + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + static constexpr bool kPadM = kPadM_; + static constexpr bool kPadN = kPadN_; + static constexpr bool kTilePermute = kTilePermute_; + static constexpr index_t kRank = kRank_; + static constexpr index_t kPerm[CK_TILE_MAX_RANK] = {kPerm0, kPerm1, kPerm2, kPerm3, kPerm4}; + static constexpr index_t tile_sizes[CK_TILE_MAX_RANK] = { + TileSize0, TileSize1, TileSize2, TileSize3, TileSize4}; +}; + +template +struct CShuffleEpilogue +{ + using Problem = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + const index_t* kPerm = Problem::kPerm; + static constexpr bool kTilePermute = Problem::kTilePermute; + static constexpr index_t kRank = Problem::kRank; + const index_t* tile_sizes = Problem::tile_sizes; + + // No additional shared memory needed + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; } + + template + CK_TILE_DEVICE void permute_tile_data(OAccTile& o_acc_tile) + { + using DataType = typename OAccTile::DataType; + + // Get thread buffer + auto& thread_buf = o_acc_tile.get_thread_buffer(); + + // Create a temporary buffer to hold the permuted data + thread_buffer permuted_thread_buf; + + // Get the lengths of each dimension + auto thread_tensor_lengths = o_acc_tile.get_lengths(); + + // Total number of elements + index_t total_elements = OAccTile::kThreadElementSpaceSize; + + // Iterate over all 
elements + for(index_t linear_idx = 0; linear_idx < total_elements; ++linear_idx) + { + // Convert linear index to multi-dimensional indices + array indices; + index_t remaining = linear_idx; + static_for<0, kRank, 1>{}([&](auto i) { + constexpr auto rev_i = kRank - 1 - i; + indices(rev_i) = remaining % thread_tensor_lengths.get(number{}); + remaining /= thread_tensor_lengths.get(number{}); + }); + + // Apply the permutation + array permuted_indices; + static_for<0, kRank, 1>{}( + [&](auto i) { permuted_indices(i) = indices.get(number{}); }); + + // Compute offsets + index_t dst_offset = 0; + index_t stride = 1; + + static_for<0, kRank, 1>{}([&](auto i) { + constexpr auto rev_i = kRank - 1 - i; + dst_offset += permuted_indices[rev_i] * stride; + stride *= thread_tensor_lengths.get(number{}); + }); + + // Move the data + permuted_thread_buf(dst_offset) = thread_buf[linear_idx]; + } + + // Copy the permuted data back to the original thread buffer + for(index_t i = 0; i < total_elements; ++i) + { + thread_buf.set_as(i, permuted_thread_buf.get(i)); + } + } + + template + CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, OAccTile& o_acc_tile) + { + const auto& current_window_origin = o_dram_window_tmp.get_window_origin(); + + // Compute the tile coordinates by dividing the window origin by the tile sizes + index_t tile_coords[CK_TILE_MAX_RANK] = {0}; + for(index_t i = 0; i < kRank; ++i) + { + tile_coords[i] = current_window_origin[i] / tile_sizes[i]; + // printf("The tile_coord is: %d", tile_coords[i]); + } + + // Apply the permutation to the tile coordinates + index_t permuted_tile_coords[CK_TILE_MAX_RANK]; + for(index_t i = 0; i < kRank; ++i) + { + permuted_tile_coords[i] = tile_coords[kPerm[i]]; + // printf("The new permuted_tile_coords is: %d", permuted_tile_coords[i]); + } + + // Compute the permuted window origin + index_t permuted_window_origin[CK_TILE_MAX_RANK] = {0}; + for(index_t i = 0; i < kRank; ++i) + { + permuted_window_origin[i] = 
permuted_tile_coords[i] * tile_sizes[i]; + // printf("The new permuted_window_origin is: %d", permuted_window_origin[i]); + } + + typename ODramWindowTmp::BottomTensorIndex step = {}; + for(index_t i = 0; i < kRank; ++i) + { + step[i] = permuted_window_origin[i] - current_window_origin[i]; + } + + // Move the window + move_tile_window(o_dram_window_tmp, step); + + // Permute the data within the tile if necessary + if constexpr(kTilePermute) + { + permute_tile_data(o_acc_tile); + } + + // Store the tile data to the permuted location + if constexpr(kPadM || kPadN) + { + store_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tile)); + buffer_store_fence(); + } + else + { + store_tile(o_dram_window_tmp, cast_tile(o_acc_tile)); + } + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index 8647a7d25a..e1f05d39db 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -5,8 +5,9 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" -#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" @@ -25,15 +26,21 @@ struct BlockFmhaBwdPipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() { - using BlockGemmProblem = BlockGemmPipelineProblem< - typename Problem::QDataType, - typename Problem::KDataType, - typename Problem::AccDataType, - TileGemmShape, - typename 
Problem::BlockFmhaShape::Gemm0BlockWarps, - typename Problem::BlockFmhaShape::Gemm0WarpTile>>; + using GemmProblem = + GemmPipelineProblem, + typename Problem::BlockFmhaShape::Gemm0BlockWarps, + typename Problem::BlockFmhaShape::Gemm0WarpTile>, + TileGemmTraits>; using WarpGemm = WarpGemmMfmaDispatcher< typename Problem::QDataType, @@ -52,21 +59,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy typename Problem::BlockFmhaShape::Gemm0BlockWarps, WarpGemm>; - return BlockGemmARegBRegCRegV1{}; + return BlockGemmARegBRegCRegV1{}; } template CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm() { - using BlockGemmProblem = BlockGemmPipelineProblem< - typename Problem::GemmDataType, - typename Problem::OGradDataType, - typename Problem::AccDataType, - TileGemmShape, - typename Problem::BlockFmhaShape::Gemm1BlockWarps, - typename Problem::BlockFmhaShape::Gemm1WarpTile>>; + using GemmProblem = + GemmPipelineProblem, + typename Problem::BlockFmhaShape::Gemm1BlockWarps, + typename Problem::BlockFmhaShape::Gemm1WarpTile>, + TileGemmTraits>; using WarpGemm = WarpGemmMfmaDispatcher; - return BlockGemmARegBRegCRegV1{}; + return BlockGemmARegBRegCRegV1{}; } template CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm() { - using BlockGemmProblem = BlockGemmPipelineProblem< - typename Problem::OGradDataType, - typename Problem::VDataType, - typename Problem::AccDataType, - TileGemmShape, - typename Problem::BlockFmhaShape::Gemm2BlockWarps, - typename Problem::BlockFmhaShape::Gemm2WarpTile>>; + using GemmProblem = + GemmPipelineProblem, + typename Problem::BlockFmhaShape::Gemm2BlockWarps, + typename Problem::BlockFmhaShape::Gemm2WarpTile>, + TileGemmTraits>; using WarpGemm = WarpGemmMfmaDispatcher< typename Problem::OGradDataType, @@ -117,21 +136,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy typename Problem::BlockFmhaShape::Gemm2BlockWarps, WarpGemm>; - return BlockGemmARegBRegCRegV1{}; + return BlockGemmARegBRegCRegV1{}; } template CK_TILE_HOST_DEVICE static 
constexpr auto GetSGradTQTBlockGemm() { - using BlockGemmProblem = BlockGemmPipelineProblem< - typename Problem::GemmDataType, - typename Problem::QDataType, - typename Problem::AccDataType, - TileGemmShape, - typename Problem::BlockFmhaShape::Gemm3BlockWarps, - typename Problem::BlockFmhaShape::Gemm3WarpTile>>; + using GemmProblem = + GemmPipelineProblem, + typename Problem::BlockFmhaShape::Gemm3BlockWarps, + typename Problem::BlockFmhaShape::Gemm3WarpTile>, + TileGemmTraits>; using WarpGemm = WarpGemmMfmaDispatcher; - return BlockGemmARegBRegCRegV1{}; + return BlockGemmARegBRegCRegV1{}; } template CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm() { - using BlockGemmProblem = BlockGemmPipelineProblem< - typename Problem::GemmDataType, - typename Problem::KDataType, - typename Problem::AccDataType, - TileGemmShape, - typename Problem::BlockFmhaShape::Gemm4BlockWarps, - typename Problem::BlockFmhaShape::Gemm4WarpTile>>; + using GemmProblem = + GemmPipelineProblem, + typename Problem::BlockFmhaShape::Gemm4BlockWarps, + typename Problem::BlockFmhaShape::Gemm4WarpTile>, + TileGemmTraits>; using WarpGemm = WarpGemmMfmaDispatcher; - return BlockGemmARegBRegCRegV1{}; + return BlockGemmARegBRegCRegV1{}; } // these are for global load diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index ae9e320f67..4ea0c4c9f2 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -5,8 +5,9 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" -#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" #include 
"ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" @@ -75,15 +76,21 @@ struct BlockFmhaPipelineQXCustomPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() { - using BlockGemmProblem = BlockGemmPipelineProblem< - typename Problem::QDataType, - typename Problem::KDataType, - typename Problem::SaccDataType, - TileGemmShape, - typename Problem::BlockFmhaShape::Gemm0BlockWarps, - typename Problem::BlockFmhaShape::Gemm0WarpTile>>; + using GemmProblem = + GemmPipelineProblem, + typename Problem::BlockFmhaShape::Gemm0BlockWarps, + typename Problem::BlockFmhaShape::Gemm0WarpTile>, + TileGemmTraits>; constexpr auto warp_gemm = []() { if constexpr(std::is_same_v && @@ -116,7 +123,7 @@ struct BlockFmhaPipelineQXCustomPolicy typename Problem::BlockFmhaShape::Gemm0BlockWarps, decltype(warp_gemm)>; - return BlockGemmARegBSmemCRegV2{}; + return BlockGemmARegBSmemCRegV2{}; } }; @@ -199,15 +206,21 @@ struct BlockFmhaPipelineQXCustomPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() { - using BlockGemmProblem = BlockGemmPipelineProblem< - typename Problem::QDataType, - typename Problem::KDataType, - typename Problem::SaccDataType, - TileGemmShape, - typename Problem::BlockFmhaShape::Gemm0BlockWarps, - typename Problem::BlockFmhaShape::Gemm0WarpTile>>; + using GemmProblem = + GemmPipelineProblem, + typename Problem::BlockFmhaShape::Gemm0BlockWarps, + typename Problem::BlockFmhaShape::Gemm0WarpTile>, + TileGemmTraits>; constexpr auto warp_gemm = []() { if constexpr(std::is_same_v && @@ -240,7 +253,7 @@ struct BlockFmhaPipelineQXCustomPolicy typename Problem::BlockFmhaShape::Gemm0BlockWarps, decltype(warp_gemm)>; - return BlockGemmASmemBSmemCRegV1{}; + return BlockGemmASmemBSmemCRegV1{}; } }; @@ -954,15 +967,21 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy CK_TILE_HOST_DEVICE 
static constexpr auto GetKVBlockGemm() { - using BlockGemmProblem = BlockGemmPipelineProblem< - typename Problem::PDataType, - typename Problem::VDataType, - typename Problem::OaccDataType, - TileGemmShape, - typename Problem::BlockFmhaShape::Gemm1BlockWarps, - typename Problem::BlockFmhaShape::Gemm1WarpTile>>; + using GemmProblem = + GemmPipelineProblem, + typename Problem::BlockFmhaShape::Gemm1BlockWarps, + typename Problem::BlockFmhaShape::Gemm1WarpTile>, + TileGemmTraits>; auto warp_gemm = [&]() { if constexpr(std::is_same_v && @@ -996,7 +1015,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy; - return BlockGemmARegBSmemCRegV2{}; + return BlockGemmARegBSmemCRegV2{}; } }; diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index e9005462b0..dc5983e4d1 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -23,12 +23,13 @@ #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" -#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp" -#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" -#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp" -#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp" -#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" +#include 
"ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp" diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index e24d7f9ea0..48329c8ba5 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -11,20 +11,12 @@ namespace ck_tile { -template +template struct GemmKernel { using TilePartitioner = remove_cvref_t; using GemmPipeline = remove_cvref_t; using EpiloguePipeline = remove_cvref_t; - using LayoutA = remove_cvref_t; - using LayoutB = remove_cvref_t; - using LayoutC = remove_cvref_t; static constexpr index_t KernelBlockSize = GemmPipeline::kBlockSize; using ADataType = remove_cvref_t; @@ -32,6 +24,10 @@ struct GemmKernel using CAccDataType = remove_cvref_t; using CODataType = remove_cvref_t; + using LayoutA = remove_cvref_t; + using LayoutB = remove_cvref_t; + using LayoutC = remove_cvref_t; + __host__ static constexpr auto GridSize(index_t M_size, index_t N_size, index_t Batch_size) { return TilePartitioner::GridSize(M_size, N_size, Batch_size); @@ -184,6 +180,7 @@ struct GemmKernel c_pad_view, make_tuple(number{}, number{}), {i_m, i_n}); + EpiloguePipeline{}(CBlockWindow_pad, acc); } }; diff --git a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp similarity index 95% rename from include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp rename to include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index bec8a204cc..5ed7d036ea 100644 --- a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -4,15 +4,15 @@ #pragma once #include 
"ck_tile/core.hpp" -#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" namespace ck_tile { // A Tile Window: global memory // B Tile Window: global memory // C Distributed tensor: register -template -struct BlockGemmPipelineAGmemBGmemCRegV1 +template +struct GemmPipelineAGmemBGmemCRegV1 { using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; @@ -33,6 +33,10 @@ struct BlockGemmPipelineAGmemBGmemCRegV1 static constexpr bool kPadB = Problem::kPadB; static constexpr bool kPadC = Problem::kPadC; + using LayoutA = remove_cvref_t; + using LayoutB = remove_cvref_t; + using LayoutC = remove_cvref_t; + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() { return ck_tile::integer_divide_ceil( diff --git a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp similarity index 99% rename from include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp rename to include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp index 3048adad67..8639f00fbb 100644 --- a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp @@ -7,9 +7,9 @@ namespace ck_tile { -// Default policy for BlockGemmPipelineAGmemBGmemCRegV1 +// Default policy for GemmPipelineAGmemBGmemCRegV1 // Default policy class should not be templated, put template on member functions instead -struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy +struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy { #if 0 // 2d diff --git a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp 
b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp similarity index 97% rename from include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp rename to include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp index ab5fe79114..bff7fc0a0e 100644 --- a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp @@ -4,15 +4,15 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp" namespace ck_tile { // A Tile Window: global memory // B Tile Window: global memory // C Distributed tensor: register -template -struct BlockGemmPipelineAGmemBGmemCRegV2 +template +struct GemmPipelineAGmemBGmemCRegV2 { using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp similarity index 57% rename from include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp rename to include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp index 0596408501..7dad55d6b9 100644 --- a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp @@ -7,12 +7,11 @@ namespace ck_tile { -// Default policy for BlockGemmPipelineAGmemBGmemCRegV2 +// Default policy for GemmPipelineAGmemBGmemCRegV2 // Default policy class should not be templated, put template on member functions instead // NOTE: policy should be bound to its corresponding operation.
It's just a coincidence that -// BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy is the same as -// BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy -using BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy = - BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy; +// GemmPipelineAGmemBGmemCRegV2DefaultPolicy is the same as +// GemmPipelineAGmemBGmemCRegV1DefaultPolicy +using GemmPipelineAGmemBGmemCRegV2DefaultPolicy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp similarity index 65% rename from include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp rename to include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 8dfba08ad7..d7b3b24a4a 100644 --- a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -13,20 +13,23 @@ template -struct BlockGemmPipelineProblem + typename TileGemmTraits_> +struct GemmPipelineProblem { using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; + using GemmTraits = remove_cvref_t; static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); - static constexpr bool kPadA = kPadA_; - static constexpr bool kPadB = kPadB_; - static constexpr bool kPadC = kPadC_; + static constexpr bool kPadA = GemmTraits::kPadA; + static constexpr bool kPadB = GemmTraits::kPadB; + static constexpr bool kPadC = GemmTraits::kPadC; + + using LayoutA = remove_cvref_t; + using LayoutB = remove_cvref_t; + using LayoutC = remove_cvref_t; static constexpr index_t AlignmentA = kPadA ? 1 : VectorLoadSize / sizeof(ADataType); static constexpr index_t AlignmentB = kPadB ? 
1 : VectorLoadSize / sizeof(BDataType); diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp new file mode 100644 index 0000000000..98da1510c7 --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct TileGemmTraits +{ + static constexpr bool kPadA = kPadA_; + static constexpr bool kPadB = kPadB_; + static constexpr bool kPadC = kPadC_; + + using LayoutA = LayoutA_; + using LayoutB = LayoutB_; + using LayoutC = LayoutC_; +}; + +} // namespace ck_tile From d18fc0797ff483dee4446e643798be699713d22c Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Thu, 10 Oct 2024 09:37:09 -0500 Subject: [PATCH 17/24] Fix default stride value (#1559) --- example/01_gemm/run_gemm_example_streamk_v2.inc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/example/01_gemm/run_gemm_example_streamk_v2.inc b/example/01_gemm/run_gemm_example_streamk_v2.inc index 6679f95157..32bd3a19a6 100644 --- a/example/01_gemm/run_gemm_example_streamk_v2.inc +++ b/example/01_gemm/run_gemm_example_streamk_v2.inc @@ -117,9 +117,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) auto f_get_default_stride = [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { - if(stride == -1) + if(stride == 0) { - // give a chance if stride is -1, return a default packed stride + // give a chance if stride is 0, return a default packed stride if constexpr(std::is_same_v) { return static_cast(col); From 14c52befdaadc392e93450df6b5501f70c43f34d Mon Sep 17 00:00:00 2001 From: spolifroni-amd Date: Thu, 10 Oct 2024 16:57:23 -0400 Subject: [PATCH 18/24] removed API usage header (#1566) --- 
docs/reference/API_Reference_Guide.rst | 6 ------ 1 file changed, 6 deletions(-) diff --git a/docs/reference/API_Reference_Guide.rst b/docs/reference/API_Reference_Guide.rst index 22222b0cf0..0d2d41c1eb 100644 --- a/docs/reference/API_Reference_Guide.rst +++ b/docs/reference/API_Reference_Guide.rst @@ -12,12 +12,6 @@ API reference guide This document contains details of the APIs for the Composable Kernel (CK) library and introduces some of the key design principles that are used to write new classes that extend CK functionality. -================= -Using CK API -================= - -This section describes how to use the CK library API. - ================= CK Datatypes ================= From f46a9eee9dbcf44697b3dad27f0675ca6d877d99 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 10 Oct 2024 15:31:56 -0700 Subject: [PATCH 19/24] only build tests and examples if user sets GPU_TARGETS (#1565) --- CMakeLists.txt | 2 +- README.md | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3f22bb4b61..cfcfa24b37 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -549,7 +549,7 @@ ENDFOREACH() add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES}) add_subdirectory(library) -if(NOT GPU_ARCHS) +if(NOT GPU_ARCHS AND USER_GPU_TARGETS) rocm_package_setup_component(tests LIBRARY_NAME composablekernel PACKAGE_NAME tests # Prevent -static suffix on package name diff --git a/README.md b/README.md index 34ac0919ae..4366ec0329 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,7 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa If you don't set `GPU_TARGETS` on the cmake command line, CK is built for all GPU targets supported by the current compiler (this may take a long time). + Tests and examples are only built if `GPU_TARGETS` is set by the user on the cmake command line.
NOTE: If you try setting `GPU_TARGETS` to a list of architectures, the build will only work if the architectures are similar, e.g., `gfx908;gfx90a`, or `gfx1100;gfx1101;gfx1102`. Otherwise, if you From 11444e4cf2d158500a10dbd2ace3bbd27cc65776 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Fri, 11 Oct 2024 14:29:46 -0700 Subject: [PATCH 20/24] [CI] remove the --rm docker container flags (#1568) --- Jenkinsfile | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index a79ed859f2..132257ad80 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -353,7 +353,7 @@ def buildHipClangJob(Map conf=[:]){ def prefixpath = conf.get("prefixpath", "/opt/rocm") // Jenkins is complaining about the render group - def dockerOpts="--rm --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } @@ -412,7 +412,7 @@ def runCKProfiler(Map conf=[:]){ def prefixpath = conf.get("prefixpath", "/opt/rocm") // Jenkins is complaining about the render group - def dockerOpts="--rm --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } @@ -544,7 +544,7 @@ def Build_CK(Map conf=[:]){ def prefixpath = conf.get("prefixpath", "/opt/rocm") // Jenkins is complaining about the render group - def dockerOpts="--rm --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE
--security-opt seccomp=unconfined" + def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } @@ -660,7 +660,7 @@ def process_results(Map conf=[:]){ def prefixpath = "/opt/rocm" // Jenkins is complaining about the render group - def dockerOpts="--rm --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + def dockerOpts="--cap-add=SYS_PTRACE --security-opt seccomp=unconfined" if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } From 29d384d0b2f266ba8fbf3f7728d2bba4f5a7b852 Mon Sep 17 00:00:00 2001 From: Adam Osewski <19374865+aosewski@users.noreply.github.com> Date: Sat, 12 Oct 2024 08:05:11 +0200 Subject: [PATCH 21/24] Implement GetWorkSpaceSize from BaseOperator. (#1564) --- .../gpu/device/device_cgemm.hpp | 6 +++--- .../impl/device_cgemm_4gemm_xdl_cshuffle.hpp | 18 +++++++++++++++++- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/device_cgemm.hpp b/include/ck/tensor_operation/gpu/device/device_cgemm.hpp index 8484212118..44dedeeef9 100644 --- a/include/ck/tensor_operation/gpu/device/device_cgemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_cgemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once #include "device_base.hpp" @@ -31,13 +31,13 @@ struct DeviceCGemm : public BaseOperator CElementwiseOperation c_element_op, ck::index_t KBatch = 1) = 0; - virtual std::unique_ptr MakeInvokerPointer() = 0; + virtual std::unique_ptr MakeInvokerPointer() = 0; virtual std::size_t GetWorkspaceSize(index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, - index_t StrideC) = 0; + index_t StrideC) const = 0; }; template (base_arg); + + if(!parg) + { + std::ostringstream err; + err << "Provided argument pointer is not of an Argument class!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + return GetWorkspaceSize( + parg->M, parg->N, parg->K, parg->StrideA, parg->StrideB, parg->StrideC); + } }; } // namespace device From 35c1777d59d89ccab1b25391daf3836af5a75522 Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Mon, 14 Oct 2024 13:59:26 +0800 Subject: [PATCH 22/24] decouple the calling from gemm_pipeline (#1571) * decouple the calling from gemm_pipeline * clang format --- ...block_fmha_bwd_pipeline_default_policy.hpp | 118 +++++++----------- ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 71 +++++------ 2 files changed, 74 insertions(+), 115 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index e1f05d39db..0afad0446c 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -5,9 +5,8 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" -#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" #include 
"ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" @@ -27,20 +26,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() { using GemmProblem = - GemmPipelineProblem, - typename Problem::BlockFmhaShape::Gemm0BlockWarps, - typename Problem::BlockFmhaShape::Gemm0WarpTile>, - TileGemmTraits>; + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm0BlockWarps, + typename Problem::BlockFmhaShape::Gemm0WarpTile>>; using WarpGemm = WarpGemmMfmaDispatcher< typename Problem::QDataType, @@ -66,20 +60,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm() { using GemmProblem = - GemmPipelineProblem, - typename Problem::BlockFmhaShape::Gemm1BlockWarps, - typename Problem::BlockFmhaShape::Gemm1WarpTile>, - TileGemmTraits>; + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm1BlockWarps, + typename Problem::BlockFmhaShape::Gemm1WarpTile>>; using WarpGemm = WarpGemmMfmaDispatcher, - typename Problem::BlockFmhaShape::Gemm2BlockWarps, - typename Problem::BlockFmhaShape::Gemm2WarpTile>, - TileGemmTraits>; + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm2BlockWarps, + typename Problem::BlockFmhaShape::Gemm2WarpTile>>; using WarpGemm = WarpGemmMfmaDispatcher< typename Problem::OGradDataType, @@ -143,20 +127,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm() { using GemmProblem = - GemmPipelineProblem, - typename Problem::BlockFmhaShape::Gemm3BlockWarps, - typename Problem::BlockFmhaShape::Gemm3WarpTile>, - TileGemmTraits>; + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm3BlockWarps, + typename Problem::BlockFmhaShape::Gemm3WarpTile>>; using WarpGemm = WarpGemmMfmaDispatcher, - typename Problem::BlockFmhaShape::Gemm4BlockWarps, - typename 
Problem::BlockFmhaShape::Gemm4WarpTile>, - TileGemmTraits>; + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm4BlockWarps, + typename Problem::BlockFmhaShape::Gemm4WarpTile>>; using WarpGemm = WarpGemmMfmaDispatcher CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() { using GemmProblem = - GemmPipelineProblem, - typename Problem::BlockFmhaShape::Gemm0BlockWarps, - typename Problem::BlockFmhaShape::Gemm0WarpTile>, - TileGemmTraits>; + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm0BlockWarps, + typename Problem::BlockFmhaShape::Gemm0WarpTile>>; constexpr auto warp_gemm = []() { if constexpr(std::is_same_v && @@ -207,20 +202,15 @@ struct BlockFmhaPipelineQXCustomPolicy CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() { using GemmProblem = - GemmPipelineProblem, - typename Problem::BlockFmhaShape::Gemm0BlockWarps, - typename Problem::BlockFmhaShape::Gemm0WarpTile>, - TileGemmTraits>; + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm0BlockWarps, + typename Problem::BlockFmhaShape::Gemm0WarpTile>>; constexpr auto warp_gemm = []() { if constexpr(std::is_same_v && @@ -968,20 +958,15 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy, - typename Problem::BlockFmhaShape::Gemm1BlockWarps, - typename Problem::BlockFmhaShape::Gemm1WarpTile>, - TileGemmTraits>; + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm1BlockWarps, + typename Problem::BlockFmhaShape::Gemm1WarpTile>>; auto warp_gemm = [&]() { if constexpr(std::is_same_v && From f21cda25366311091c4d7e97ac0f3d739f102c10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Mon, 14 Oct 2024 17:39:38 +0200 Subject: [PATCH 23/24] Add transpose scale amax example (#1547) * Add transpose scale amax example * fixes * Tune reduce instance --- example/44_elementwise_permute/CMakeLists.txt | 1 + ...entwise_scale_permute_amax_2D_fp16_fp8.cpp | 247 ++++++++++++++++++ .../element/unary_element_wise_operation.hpp | 6 + 
include/ck/utility/math_v2.hpp | 4 + 4 files changed, 258 insertions(+) create mode 100644 example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp diff --git a/example/44_elementwise_permute/CMakeLists.txt b/example/44_elementwise_permute/CMakeLists.txt index afbf948683..867493465d 100644 --- a/example/44_elementwise_permute/CMakeLists.txt +++ b/example/44_elementwise_permute/CMakeLists.txt @@ -5,3 +5,4 @@ add_example_executable(example_elementwise_permute_4D_fp32_col elementwise_permu add_example_executable(example_elementwise_permute_4D_fp16_col elementwise_permute_4D_fp16_col.cpp) add_example_executable(example_elementwise_binary_4D_fp16 elementwise_binary_4D_fp16.cpp) add_example_executable(example_elementwise_trinary_4D_fp16 elementwise_trinary_4D_fp16.cpp) +add_example_executable(elementwise_scale_permute_amax_2D_fp16_fp8 elementwise_scale_permute_amax_2D_fp16_fp8.cpp) diff --git a/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp b/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp new file mode 100644 index 0000000000..7ac3c4e239 --- /dev/null +++ b/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp @@ -0,0 +1,247 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/utility/reduction_enums.hpp" + +using F16 = ck::half_t; +using F32 = float; +using F8 = ck::f8_t; + +using InputDataType = F16; +using ScaleDataType = F32; +using OutputDataType = F8; + +static constexpr ck::index_t NumDim = 2; + +constexpr ck::ReduceTensorOp ReduceOpId = ck::ReduceTensorOp::MAX; +constexpr bool PropagateNan = true; +constexpr bool OutputIndex = false; + +using ReduceOperation = typename ck::reduce_binary_operator::opType; + +struct ScalePassThrough +{ + ScalePassThrough(const float alpha = 1.f) : alpha_(alpha) {} + + __host__ __device__ constexpr void + operator()(OutputDataType& y0, OutputDataType& y1, const InputDataType& x0) const + { + y0 = ck::type_convert(ck::type_convert(x0) * alpha_); + y1 = y0; + } + + const ScaleDataType alpha_; +}; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using UnaryAbs = ck::tensor_operation::element_wise::UnaryAbs; + +using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< + ck::Tuple, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + ScalePassThrough, // Elementwise + NumDim, // NumDim + 256, // BlockSize + 128, // M0PerBlock + 128, // M1PerBlock + 8, // M0PerThread + 8, // M1PerThread + ck::Sequence<1, 0>, // ThreadClusterArrangeOrder + 
ck::Sequence<8>, // InScalarPerVectorSeq + ck::Sequence<8, 1>>; // OutScalarPerVectorSeq + +using DeviceReduceInstance = + ck::tensor_operation::device::DeviceReduceMultiBlock; // OutDstVectorSize + +void reference_scale_permute_amax(Tensor& input, + Tensor& host_output_scaled_casted_transposed, + Tensor& host_output_scaled_casted, + Tensor& host_output_amax, + const float scale) +{ + ScalePassThrough out_element_op(scale); + const ck::index_t M = input.GetLengths()[0]; + const ck::index_t K = input.GetLengths()[1]; + for(ck::index_t m = 0; m < M; m++) + { + for(ck::index_t k = 0; k < K; k++) + { + OutputDataType y0, y1; + out_element_op(y0, y1, input(m, k)); + + host_output_scaled_casted(m, k) = y0; + host_output_scaled_casted_transposed(m, k) = y1; + const OutputDataType y_fabs = + ck::type_convert(ck::math::abs(ck::type_convert(y0))); + host_output_amax(0) = ck::math::max(y_fabs, host_output_amax(0)); + } + } +} + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + bool time_kernel = true; + + const float scale = 2.f; + + ck::index_t M = 1024; + ck::index_t K = 1024; + + if(argc == 3) + { + M = std::stoi(argv[1]); + K = std::stoi(argv[2]); + } + + std::array dims = {M, K}; + std::array in_strides = {K, 1}; + std::array out_strides = {1, M}; + + Tensor input(dims, in_strides); + Tensor output_scaled_casted_transposed(dims, out_strides); + Tensor output_scaled_casted(dims, in_strides); + Tensor output_amax({1}); + + input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem input_dev_buf(sizeof(InputDataType) * input.mDesc.GetElementSpaceSize()); + DeviceMem output_scaled_casted_transposed_dev_buf( + sizeof(OutputDataType) * output_scaled_casted_transposed.mDesc.GetElementSpaceSize()); + DeviceMem output_scaled_casted_dev_buf(sizeof(OutputDataType) * + output_scaled_casted.mDesc.GetElementSpaceSize()); + DeviceMem output_amax_dev_buf(sizeof(OutputDataType) * output_amax.mDesc.GetElementSpaceSize()); + + 
input_dev_buf.ToDevice(input.mData.data()); + + std::array inputs = {input_dev_buf.GetDeviceBuffer()}; + std::array outputs = {output_scaled_casted_transposed_dev_buf.GetDeviceBuffer(), + output_scaled_casted_dev_buf.GetDeviceBuffer()}; + + std::cout << "Input: " << input.mDesc << std::endl; + std::cout << "Scale: " << scale << std::endl; + std::cout << "Output scaled casted transposed: " << output_scaled_casted_transposed.mDesc + << std::endl; + std::cout << "Output scaled casted: " << output_scaled_casted.mDesc << std::endl; + std::cout << "Output amax: " << output_amax.mDesc << std::endl; + + auto launch_transpose_scale = [&]() { + auto transposeScale = DeviceElementwisePermuteInstance{}; + auto argument = transposeScale.MakeArgumentPointer(dims, + {in_strides}, + {out_strides, in_strides}, + inputs, + outputs, + ScalePassThrough{scale}); + + if(!transposeScale.IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + auto transposeScale_invoker_ptr = transposeScale.MakeInvokerPointer(); + return transposeScale_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + }; + + auto launch_reduce = [&]() { + auto reduce = DeviceReduceInstance{}; + auto reduce_argument_ptr = + reduce.MakeArgumentPointer(dims, + in_strides, + {1}, // Output Lengths + {1}, // Output Strides + {0, 1}, // Reduce Dims + static_cast(1.f), + static_cast(0.f), + output_scaled_casted_dev_buf.GetDeviceBuffer(), + nullptr, + output_amax_dev_buf.GetDeviceBuffer(), + nullptr, + UnaryAbs{}, + PassThrough{}); + + if(!reduce.IsSupportedArgument(reduce_argument_ptr.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + auto invoker_ptr = reduce.MakeInvokerPointer(); + + return invoker_ptr->Run(reduce_argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + }; + + float ave_time = launch_transpose_scale(); + 
ave_time += launch_reduce(); + std::cout << "Perf: " << ave_time << " ms" << std::endl; + bool pass = true; + + if(do_verification) + { + Tensor host_output_scaled_casted_transposed(dims, out_strides); + Tensor host_output_scaled_casted(dims, in_strides); + Tensor host_output_amax({1}); + + reference_scale_permute_amax(input, + host_output_scaled_casted_transposed, + host_output_scaled_casted, + host_output_amax, + scale); + + output_scaled_casted_transposed_dev_buf.FromDevice( + output_scaled_casted_transposed.mData.data()); + output_scaled_casted_dev_buf.FromDevice(output_scaled_casted.mData.data()); + output_amax_dev_buf.FromDevice(output_amax.mData.data()); + + pass &= ck::utils::check_err(output_scaled_casted_transposed.mData, + host_output_scaled_casted_transposed.mData, + "Error: Incorrect results scaled transposed", + 1e-3, + 1e-3); + pass &= ck::utils::check_err(output_scaled_casted.mData, + host_output_scaled_casted.mData, + "Error: Incorrect results scaled", + 1e-3, + 1e-3); + pass &= ck::utils::check_err( + output_amax.mData, host_output_amax.mData, "Error: Incorrect results amax", 1e-3, 1e-3); + } + + return pass ? 
0 : 1; +} diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 8079b04b84..ab6b1691af 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -419,6 +419,12 @@ struct UnaryAbs y = ck::math::abs(x); }; + + template <> + __host__ __device__ void operator()(f8_t& y, const f8_t& x) const + { + y = ck::type_convert(ck::math::abs(ck::type_convert(x))); + }; }; struct UnarySqrt diff --git a/include/ck/utility/math_v2.hpp b/include/ck/utility/math_v2.hpp index d961cdb198..cbbe155859 100644 --- a/include/ck/utility/math_v2.hpp +++ b/include/ck/utility/math_v2.hpp @@ -80,6 +80,8 @@ static inline __host__ bool isnan(half_t x) return (xx & 0x7FFF) > 0x7C00; }; +static inline __host__ bool isnan(f8_t x) { return (x & 0x80); }; + #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 static inline __host__ bool isnan(int4_t x) { @@ -529,6 +531,8 @@ static inline __device__ bool isnan(half_t x) return (xx & 0x7FFF) > 0x7C00; }; +static inline __device__ bool isnan(f8_t x) { return (x & 0x80); }; + static inline __device__ half_t sqrt(half_t x) { return static_cast(__builtin_amdgcn_sqrtf(static_cast(x))); From 4cf70b36c1330b3ee25e00473b219857575d3df2 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Mon, 14 Oct 2024 11:56:45 -0500 Subject: [PATCH 24/24] Add custom type vector support (#1333) * Add non_native_vector_type * Add a test * Add non-native vector type * Fix CTOR * Fix non-native vector type of 1 * Fix CTORs * Use vector_type to cover non-native implementation as well * Update the test * Format * Format * Fix copyright years * Remove BoolVecT so far * Add AsType test cases * Update assert error message * Remove redundant type * Update naming * Add complex half type with tests * Add tests for vector reshaping * 
Add missing alignas * Update test/data_type/test_custom_type.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Compare custom types to built-in types * Add default constructor test * Add an alignment test --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> Co-authored-by: Po Yen Chen --- include/ck/utility/data_type.hpp | 655 ++++++++++++++++++++- test/data_type/CMakeLists.txt | 5 + test/data_type/test_custom_type.cpp | 874 ++++++++++++++++++++++++++++ 3 files changed, 1504 insertions(+), 30 deletions(-) create mode 100644 test/data_type/test_custom_type.cpp diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 4df14c6211..debeb472ad 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -13,8 +13,24 @@ using int4_t = _BitInt(4); using f8_t = _BitInt(8); using bf8_t = unsigned _BitInt(8); +inline constexpr auto next_pow2(uint32_t x) +{ + // Precondition: x > 1. + return x > 1u ? 
(1u << (32u - __builtin_clz(x - 1u))) : x; +} + +// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_t, bf8_t, bool +template +inline constexpr bool is_native_type() +{ + return is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value || + is_same::value; +} + // vector_type -template +template struct vector_type; // Caution: DO NOT REMOVE @@ -171,7 +187,7 @@ struct scalar_type }; template -struct vector_type +struct vector_type()>> { using d1_t = T; using type = d1_t; @@ -189,7 +205,8 @@ struct vector_type template __host__ __device__ constexpr const auto& AsType() const { - static_assert(is_same::value, "wrong!"); + static_assert(is_same::value, + "Something went wrong, please check src and dst types."); return data_.d1x1_; } @@ -197,7 +214,8 @@ struct vector_type template __host__ __device__ constexpr auto& AsType() { - static_assert(is_same::value, "wrong!"); + static_assert(is_same::value, + "Something went wrong, please check src and dst types."); return data_.d1x1_; } @@ -205,7 +223,7 @@ struct vector_type __device__ int static err = 0; template -struct vector_type +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -226,7 +244,8 @@ struct vector_type template __host__ __device__ constexpr const auto& AsType() const { - static_assert(is_same::value || is_same::value, "wrong!"); + static_assert(is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -245,7 +264,8 @@ struct vector_type template __host__ __device__ constexpr auto& AsType() { - static_assert(is_same::value || is_same::value, "wrong!"); + static_assert(is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -263,7 +283,7 @@ struct vector_type }; template -struct vector_type 
+struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -287,7 +307,7 @@ struct vector_type __host__ __device__ constexpr const auto& AsType() const { static_assert(is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -311,7 +331,7 @@ struct vector_type __host__ __device__ constexpr auto& AsType() { static_assert(is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -333,7 +353,7 @@ struct vector_type }; template -struct vector_type +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -360,7 +380,7 @@ struct vector_type { static_assert(is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -389,7 +409,7 @@ struct vector_type { static_assert(is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -415,7 +435,7 @@ struct vector_type }; template -struct vector_type +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -445,7 +465,7 @@ struct vector_type static_assert(is_same::value || is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -479,7 +499,7 @@ struct vector_type static_assert(is_same::value || is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -509,7 +529,7 @@ struct vector_type }; template -struct vector_type +struct vector_type()>> { using 
d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -541,7 +561,7 @@ struct vector_type static_assert(is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -579,7 +599,7 @@ struct vector_type static_assert(is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -613,7 +633,7 @@ struct vector_type }; template -struct vector_type +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -648,7 +668,7 @@ struct vector_type is_same::value || is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -691,7 +711,7 @@ struct vector_type is_same::value || is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -729,7 +749,7 @@ struct vector_type }; template -struct vector_type +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -766,7 +786,7 @@ struct vector_type is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -813,7 +833,7 @@ struct vector_type is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -855,7 +875,7 @@ struct vector_type }; template -struct vector_type +struct vector_type()>> { using d1_t = T; typedef 
T d2_t __attribute__((ext_vector_type(2))); @@ -894,7 +914,7 @@ struct vector_type is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -945,7 +965,7 @@ struct vector_type is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -990,6 +1010,581 @@ struct vector_type } }; +template +struct non_native_vector_base +{ + using type = non_native_vector_base; + + __host__ __device__ non_native_vector_base() = default; + __host__ __device__ non_native_vector_base(const type&) = default; + __host__ __device__ non_native_vector_base(type&&) = default; + __host__ __device__ ~non_native_vector_base() = default; + + T d[N]; +}; + +// non-native vector_type implementation +template +struct vector_type()>> +{ + using d1_t = T; + using type = d1_t; + + union alignas(next_pow2(1 * sizeof(T))) + { + d1_t d1_; + StaticallyIndexedArray d1x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value, + "Something went wrong, please check src and dst types."); + + return data_.d1x1_; + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value, + "Something went wrong, please check src and dst types."); + + return data_.d1x1_; + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d2_t = non_native_vector_base; + + using type = d2_t; + + union alignas(next_pow2(2 * sizeof(T))) + { + d2_t d2_; + StaticallyIndexedArray 
d1x2_; + StaticallyIndexedArray d2x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x2_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x2_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + + using type = d4_t; + + union alignas(next_pow2(4 * sizeof(T))) + { + d4_t d4_; + StaticallyIndexedArray d1x4_; + StaticallyIndexedArray d2x2_; + StaticallyIndexedArray d4x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + 
return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + using d8_t = non_native_vector_base; + + using type = d8_t; + + union alignas(next_pow2(8 * sizeof(T))) + { + d8_t d8_; + StaticallyIndexedArray d1x8_; + StaticallyIndexedArray d2x4_; + StaticallyIndexedArray d4x2_; + StaticallyIndexedArray d8x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + using d8_t = non_native_vector_base; + using d16_t = non_native_vector_base; + + using type = d16_t; + + union 
alignas(next_pow2(16 * sizeof(T))) + { + d16_t d16_; + StaticallyIndexedArray d1x16_; + StaticallyIndexedArray d2x8_; + StaticallyIndexedArray d4x4_; + StaticallyIndexedArray d8x2_; + StaticallyIndexedArray d16x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + using d8_t = non_native_vector_base; + using d16_t = non_native_vector_base; + using d32_t = non_native_vector_base; + + using type = d32_t; + + union alignas(next_pow2(32 * sizeof(T))) + { + d32_t d32_; + StaticallyIndexedArray d1x32_; + StaticallyIndexedArray d2x16_; + StaticallyIndexedArray 
d4x8_; + StaticallyIndexedArray d8x4_; + StaticallyIndexedArray d16x2_; + StaticallyIndexedArray d32x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + using d8_t = non_native_vector_base; + using d16_t = non_native_vector_base; + using d32_t = non_native_vector_base; + using d64_t = non_native_vector_base; + + using type = d64_t; + + union alignas(next_pow2(64 * sizeof(T))) + { + 
d64_t d64_; + StaticallyIndexedArray d1x64_; + StaticallyIndexedArray d2x32_; + StaticallyIndexedArray d4x16_; + StaticallyIndexedArray d8x8_; + StaticallyIndexedArray d16x4_; + StaticallyIndexedArray d32x2_; + StaticallyIndexedArray d64x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + else + { + return err; + } + } +}; + using int64_t = long; // fp64 @@ -1051,8 +1646,8 @@ using 
bf8x8_t = typename vector_type::type; using bf8x16_t = typename vector_type::type; using bf8x32_t = typename vector_type::type; using bf8x64_t = typename vector_type::type; + // u8 -// i8 using uint8x2_t = typename vector_type::type; using uint8x4_t = typename vector_type::type; using uint8x8_t = typename vector_type::type; diff --git a/test/data_type/CMakeLists.txt b/test/data_type/CMakeLists.txt index 95f1367fbf..a783be7bb0 100644 --- a/test/data_type/CMakeLists.txt +++ b/test/data_type/CMakeLists.txt @@ -18,4 +18,9 @@ if(result EQUAL 0) target_link_libraries(test_bf8 PRIVATE utility) endif() +add_gtest_executable(test_custom_type test_custom_type.cpp) +if(result EQUAL 0) + target_link_libraries(test_custom_type PRIVATE utility) +endif() + add_gtest_executable(test_type_convert_const type_convert_const.cpp) diff --git a/test/data_type/test_custom_type.cpp b/test/data_type/test_custom_type.cpp new file mode 100644 index 0000000000..1016812544 --- /dev/null +++ b/test/data_type/test_custom_type.cpp @@ -0,0 +1,874 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "gtest/gtest.h" +#include "ck/utility/data_type.hpp" +#include "ck/utility/type_convert.hpp" + +using ck::bf8_t; +using ck::bhalf_t; +using ck::f8_t; +using ck::half_t; +using ck::Number; +using ck::type_convert; +using ck::vector_type; + +TEST(Custom_bool, TestSize) +{ + struct custom_bool_t + { + bool data; + }; + ASSERT_EQ(sizeof(custom_bool_t), sizeof(bool)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); +} + +TEST(Custom_bool, TestAsType) +{ + struct custom_bool_t + { + using type = bool; + type data; + custom_bool_t() : data{type{}} {} + custom_bool_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector test_vec = {false, true, false, true}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template AsType()(Number{}).data, false); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_bool_t{test_vec.at(i)}; + }); + // copy the vector + vector_type left_vec{right_vec}; + // check if values were copied correctly + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_bool, TestAsTypeReshape) +{ + struct custom_bool_t + { + using type = bool; + type data; + custom_bool_t() : data{type{}} {} + custom_bool_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector test_vec = {false, true, false, true}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template 
AsType()(Number{}).data, false); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_bool_t{test_vec.at(i)}; + }); + // copy the first half of a vector + vector_type left_vec{ + right_vec.template AsType::type>()(Number<0>{})}; + // check if values were copied correctly + ck::static_for<0, size / 2, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_int8, TestSize) +{ + struct custom_int8_t + { + int8_t data; + }; + ASSERT_EQ(sizeof(custom_int8_t), sizeof(int8_t)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); +} + +TEST(Custom_int8, TestAsType) +{ + struct custom_int8_t + { + using type = int8_t; + type data; + custom_int8_t() : data{type{}} {} + custom_int8_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector test_vec = {3, -6, 8, -2}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template AsType()(Number{}).data, 0); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_int8_t{test_vec.at(i)}; + }); + // copy the vector + vector_type left_vec{right_vec}; + // check if values were copied correctly + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_int8, TestAsTypeReshape) +{ + struct custom_int8_t + { + using type = int8_t; + type data; + custom_int8_t() : data{type{}} {} + custom_int8_t(type init) : data{init} {} + }; + + // test size + const 
int size = 4; + std::vector test_vec = {3, -6, 8, -2}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template AsType()(Number{}).data, 0); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_int8_t{test_vec.at(i)}; + }); + // copy the first half of a vector + vector_type left_vec{ + right_vec.template AsType::type>()(Number<0>{})}; + // check if values were copied correctly + ck::static_for<0, size / 2, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_uint8, TestSize) +{ + struct custom_uint8_t + { + uint8_t data; + }; + ASSERT_EQ(sizeof(custom_uint8_t), sizeof(uint8_t)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); +} + +TEST(Custom_uint8, TestAsType) +{ + struct custom_uint8_t + { + using type = uint8_t; + type data; + custom_uint8_t() : data{type{}} {} + custom_uint8_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector test_vec = {3, 6, 8, 2}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template AsType()(Number{}).data, 0); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_uint8_t{test_vec.at(i)}; + }); + // copy the vector + vector_type left_vec{right_vec}; + // check if values were copied correctly + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + 
+TEST(Custom_uint8, TestAsTypeReshape) +{ + struct custom_uint8_t + { + using type = uint8_t; + type data; + custom_uint8_t() : data{type{}} {} + custom_uint8_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector test_vec = {3, 6, 8, 2}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template AsType()(Number{}).data, 0); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_uint8_t{test_vec.at(i)}; + }); + // copy the first half of a vector + vector_type left_vec{ + right_vec.template AsType::type>()(Number<0>{})}; + // check if values were copied correctly + ck::static_for<0, size / 2, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_f8, TestSize) +{ + struct custom_f8_t + { + _BitInt(8) data; + }; + ASSERT_EQ(sizeof(custom_f8_t), sizeof(_BitInt(8))); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type<_BitInt(8), 2>)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type<_BitInt(8), 4>)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type<_BitInt(8), 8>)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type<_BitInt(8), 16>)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type<_BitInt(8), 32>)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type<_BitInt(8), 64>)); +} + +TEST(Custom_f8, TestAsType) +{ + struct custom_f8_t + { + using type = _BitInt(8); + type data; + custom_f8_t() : data{type{}} {} + custom_f8_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector<_BitInt(8)> test_vec = {type_convert<_BitInt(8)>(0.3f), + type_convert<_BitInt(8)>(-0.6f), + type_convert<_BitInt(8)>(0.8f), + type_convert<_BitInt(8)>(-0.2f)}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}( + [&](auto i) { 
ASSERT_EQ(right_vec.template AsType()(Number{}).data, 0); }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_f8_t{test_vec.at(i)}; + }); + // copy the vector + vector_type left_vec{right_vec}; + // check if values were copied correctly + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_f8, TestAsTypeReshape) +{ + struct custom_f8_t + { + using type = _BitInt(8); + type data; + custom_f8_t() : data{type{}} {} + custom_f8_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector<_BitInt(8)> test_vec = {type_convert<_BitInt(8)>(0.3f), + type_convert<_BitInt(8)>(-0.6f), + type_convert<_BitInt(8)>(0.8f), + type_convert<_BitInt(8)>(-0.2f)}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}( + [&](auto i) { ASSERT_EQ(right_vec.template AsType()(Number{}).data, 0); }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_f8_t{test_vec.at(i)}; + }); + // copy the first half of a vector + vector_type left_vec{ + right_vec.template AsType::type>()(Number<0>{})}; + // check if values were copied correctly + ck::static_for<0, size / 2, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_bf8, TestSize) +{ + struct custom_bf8_t + { + unsigned _BitInt(8) data; + }; + ASSERT_EQ(sizeof(custom_bf8_t), sizeof(unsigned _BitInt(8))); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); +} + +TEST(Custom_bf8, TestAsType) +{ + 
struct custom_bf8_t + { + using type = unsigned _BitInt(8); + type data; + custom_bf8_t() : data{type{}} {} + custom_bf8_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector test_vec = {type_convert(0.3f), + type_convert(-0.6f), + type_convert(0.8f), + type_convert(-0.2f)}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}( + [&](auto i) { ASSERT_EQ(right_vec.template AsType()(Number{}).data, 0); }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_bf8_t{test_vec.at(i)}; + }); + // copy the vector + vector_type left_vec{right_vec}; + // check if values were copied correctly + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_bf8, TestAsTypeReshape) +{ + struct custom_bf8_t + { + using type = unsigned _BitInt(8); + type data; + custom_bf8_t() : data{type{}} {} + custom_bf8_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector test_vec = {type_convert(0.3f), + type_convert(-0.6f), + type_convert(0.8f), + type_convert(-0.2f)}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}( + [&](auto i) { ASSERT_EQ(right_vec.template AsType()(Number{}).data, 0); }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_bf8_t{test_vec.at(i)}; + }); + // copy the first half of a vector + vector_type left_vec{ + right_vec.template AsType::type>()(Number<0>{})}; + // check if values were copied correctly + ck::static_for<0, size / 2, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_half, TestSize) +{ + struct custom_half_t + { + half_t data; + }; + ASSERT_EQ(sizeof(custom_half_t), sizeof(half_t)); + 
ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); +} + +TEST(Custom_half, TestAsType) +{ + struct custom_half_t + { + using type = half_t; + type data; + custom_half_t() : data{type{}} {} + custom_half_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector test_vec = {half_t{0.3f}, half_t{-0.6f}, half_t{0.8f}, half_t{-0.2f}}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template AsType()(Number{}).data, + type_convert(0.0f)); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_half_t{test_vec.at(i)}; + }); + // copy the vector + vector_type left_vec{right_vec}; + // check if values were copied correctly + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_half, TestAsTypeReshape) +{ + struct custom_half_t + { + using type = half_t; + type data; + custom_half_t() : data{type{}} {} + custom_half_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector test_vec = {half_t{0.3f}, half_t{-0.6f}, half_t{0.8f}, half_t{-0.2f}}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template AsType()(Number{}).data, + type_convert(0.0f)); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_half_t{test_vec.at(i)}; + }); + // copy the first half of a vector + vector_type left_vec{ + right_vec.template 
AsType::type>()(Number<0>{})}; + // check if values were copied correctly + ck::static_for<0, size / 2, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_bhalf, TestSize) +{ + struct custom_bhalf_t + { + bhalf_t data; + }; + ASSERT_EQ(sizeof(custom_bhalf_t), sizeof(bhalf_t)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); +} + +TEST(Custom_bhalf, TestAsType) +{ + struct custom_bhalf_t + { + using type = bhalf_t; + type data; + custom_bhalf_t() : data{type{}} {} + custom_bhalf_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector test_vec = {type_convert(0.3f), + type_convert(-0.6f), + type_convert(0.8f), + type_convert(-0.2f)}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template AsType()(Number{}).data, + type_convert(0.0f)); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_bhalf_t{test_vec.at(i)}; + }); + // copy the vector + vector_type left_vec{right_vec}; + // check if values were copied correctly + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_bhalf, TestAsTypeReshape) +{ + struct custom_bhalf_t + { + using type = bhalf_t; + type data; + custom_bhalf_t() : data{type{}} {} + custom_bhalf_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector test_vec = {type_convert(0.3f), + type_convert(-0.6f), + type_convert(0.8f), + type_convert(-0.2f)}; + // reference vector + vector_type 
right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template AsType()(Number{}).data, + type_convert(0.0f)); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_bhalf_t{test_vec.at(i)}; + }); + // copy the first half of a vector + vector_type left_vec{ + right_vec.template AsType::type>()(Number<0>{})}; + // check if values were copied correctly + ck::static_for<0, size / 2, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_float, TestSize) +{ + struct custom_float_t + { + float data; + }; + ASSERT_EQ(sizeof(custom_float_t), sizeof(float)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); +} + +TEST(Custom_float, TestAsType) +{ + struct custom_float_t + { + using type = float; + type data; + custom_float_t() : data{type{}} {} + custom_float_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector test_vec = {0.3f, -0.6f, 0.8f, -0.2f}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template AsType()(Number{}).data, 0.0f); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_float_t{test_vec.at(i)}; + }); + // copy the vector + vector_type left_vec{right_vec}; + // check if values were copied correctly + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_float, TestAsTypeReshape) +{ + struct custom_float_t 
+ { + using type = float; + type data; + custom_float_t() : data{type{}} {} + custom_float_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector test_vec = {0.3f, -0.6f, 0.8f, -0.2f}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template AsType()(Number{}).data, 0.0f); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_float_t{test_vec.at(i)}; + }); + // copy the first half of a vector + vector_type left_vec{ + right_vec.template AsType::type>()(Number<0>{})}; + // check if values were copied correctly + ck::static_for<0, size / 2, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_double, TestSize) +{ + struct custom_double_t + { + double data; + }; + ASSERT_EQ(sizeof(custom_double_t), sizeof(double)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); +} + +TEST(Custom_double, TestAsType) +{ + struct custom_double_t + { + using type = double; + type data; + custom_double_t() : data{type{}} {} + custom_double_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector test_vec = {0.3, 0.6, 0.8, 0.2}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template AsType()(Number{}).data, 0.0); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_double_t{test_vec.at(i)}; + }); + // copy the vector + vector_type 
left_vec{right_vec}; + // check if values were copied correctly + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_double, TestAsTypeReshape) +{ + struct custom_double_t + { + using type = double; + type data; + custom_double_t() : data{type{}} {} + custom_double_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector test_vec = {0.3, 0.6, 0.8, 0.2}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template AsType()(Number{}).data, 0.0); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_double_t{test_vec.at(i)}; + }); + // copy the first half of a vector + vector_type left_vec{ + right_vec.template AsType::type>()(Number<0>{})}; + // check if values were copied correctly + ck::static_for<0, size / 2, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Complex_half, TestSize) +{ + struct complex_half_t + { + half_t real; + half_t img; + }; + ASSERT_EQ(sizeof(complex_half_t), sizeof(half_t) + sizeof(half_t)); + ASSERT_EQ(sizeof(vector_type), + sizeof(vector_type) + sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), + sizeof(vector_type) + sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), + sizeof(vector_type) + sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), + sizeof(vector_type) + sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), + sizeof(vector_type) + sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), + sizeof(vector_type) + sizeof(vector_type)); +} + +TEST(Complex_half, TestAlignment) +{ + struct complex_half_t + { + half_t real; + half_t img; + }; + ASSERT_EQ(alignof(vector_type), + alignof(vector_type) + alignof(vector_type)); + ASSERT_EQ(alignof(vector_type), + alignof(vector_type) + 
alignof(vector_type)); + ASSERT_EQ(alignof(vector_type), + alignof(vector_type) + alignof(vector_type)); + ASSERT_EQ(alignof(vector_type), + alignof(vector_type) + alignof(vector_type)); + ASSERT_EQ(alignof(vector_type), + alignof(vector_type) + alignof(vector_type)); + ASSERT_EQ(alignof(vector_type), + alignof(vector_type) + alignof(vector_type)); +} + +TEST(Complex_half, TestAsType) +{ + struct complex_half_t + { + using type = half_t; + type real; + type img; + complex_half_t() : real{type{}}, img{type{}} {} + complex_half_t(type real_init, type img_init) : real{real_init}, img{img_init} {} + }; + + // test size + const int size = 4; + // custom type number of elements + const int num_elem = sizeof(complex_half_t) / sizeof(complex_half_t::type); + std::vector test_vec = {half_t{0.3f}, + half_t{-0.6f}, + half_t{0.8f}, + half_t{-0.2f}, + half_t{0.5f}, + half_t{-0.7f}, + half_t{0.9f}, + half_t{-0.3f}}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template AsType()(Number{}).real, + type_convert(0.0f)); + ASSERT_EQ(right_vec.template AsType()(Number{}).img, + type_convert(0.0f)); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = + complex_half_t{test_vec.at(num_elem * i), test_vec.at(num_elem * i + 1)}; + }); + // copy the vector + vector_type left_vec{right_vec}; + // check if values were copied correctly + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).real, + test_vec.at(num_elem * i)); + ASSERT_EQ(left_vec.template AsType()(Number{}).img, + test_vec.at(num_elem * i + 1)); + }); +} + +TEST(Complex_half, TestAsTypeReshape) +{ + struct complex_half_t + { + using type = half_t; + type real; + type img; + complex_half_t() : real{type{}}, img{type{}} {} + complex_half_t(type real_init, type img_init) : real{real_init}, img{img_init} {} + }; + + // test 
size + const int size = 4; + // custom type number of elements + const int num_elem = sizeof(complex_half_t) / sizeof(complex_half_t::type); + std::vector test_vec = {half_t{0.3f}, + half_t{-0.6f}, + half_t{0.8f}, + half_t{-0.2f}, + half_t{0.5f}, + half_t{-0.7f}, + half_t{0.9f}, + half_t{-0.3f}}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template AsType()(Number{}).real, + type_convert(0.0f)); + ASSERT_EQ(right_vec.template AsType()(Number{}).img, + type_convert(0.0f)); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = + complex_half_t{test_vec.at(num_elem * i), test_vec.at(num_elem * i + 1)}; + }); + // copy the first half of a vector + vector_type left_vec{ + right_vec.template AsType::type>()(Number<0>{})}; + // check if values were copied correctly + ck::static_for<0, size / 2, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).real, + test_vec.at(num_elem * i)); + ASSERT_EQ(left_vec.template AsType()(Number{}).img, + test_vec.at(num_elem * i + 1)); + }); +}