From fabac7e2c38f134e70c4caab718579d4d44c2870 Mon Sep 17 00:00:00 2001 From: Johannes Graner Date: Thu, 29 Jan 2026 18:40:28 +0100 Subject: [PATCH 1/8] [Conv] Enable bwd weight splitk autodeduction with cap (#3656) * Enable bwd weight splitk autodeduction with cap * Fix error threshold calculations * Add missing logic to wmma multiple d kernel * Fix threshold calculation * Update test with new applicability --- .../device/device_grouped_conv_bwd_weight.hpp | 2 - ...ice_grouped_conv_bwd_weight_multiple_d.hpp | 2 - ...evice_grouped_conv_bwd_weight_explicit.hpp | 15 ++---- ...bwd_weight_multiple_d_wmma_cshuffle_v3.hpp | 53 +++++++++++++++---- ...onv_bwd_weight_multiple_d_xdl_cshuffle.hpp | 11 ++-- ...ouped_conv_bwd_weight_wmma_cshuffle_v3.hpp | 12 ++--- ...e_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 11 ++-- ...rouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 12 ++--- .../profile_grouped_conv_bwd_weight_impl.hpp | 47 ++++++++++------ ...rouped_convnd_bwd_weight_interface_xdl.cpp | 2 +- 10 files changed, 91 insertions(+), 76 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp index 58da96e2f0..eadfa29c9f 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp @@ -11,8 +11,6 @@ namespace ck { namespace tensor_operation { namespace device { -#define DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS 1 - template ()) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index bc072a7019..f662ff834f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -22,6 +22,7 @@ #include #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/host_utility/device_prop.hpp" @@ -524,6 +525,44 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 decltype(GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( CGridDesc_M_N{}, 1, 1)); + struct ActiveWorkgroupsPerCU + { + ActiveWorkgroupsPerCU() + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return; + } + constexpr int dynamic_smem_size = 0; + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; + int max_occupancy = 0; + + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // TODO: implement + } + else + { + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, + kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>, + BlockSize, + dynamic_smem_size)); + } + max_occupancy_ = std::max(1, max_occupancy); + } + int max_occupancy_; + }; + struct Argument : public BaseArgument, public ArgumentSplitK { Argument( @@ -574,6 +613,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 input_left_pads_{input_left_pads}, input_right_pads_{input_right_pads} { + static ActiveWorkgroupsPerCU active_workgroups_per_cu; + constexpr index_t spatial_offset = 3; std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset, end(b_g_n_c_wis_lengths), @@ -585,7 +626,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); -#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS if(split_k < 0) { ck::index_t gemmM, gemmN, gemmK; @@ -602,6 +642,9 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock); k_batch_ = std::min(k_batch_, k_batch_max); + // Cap k_batch_ to 128 to avoid accuracy issues + k_batch_ = std::min(k_batch_, 128); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max @@ -611,7 +654,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 } } else -#endif { k_batch_ = split_k; } @@ -988,13 +1030,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 static bool IsSupportedArgument(const Argument& arg) { -#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS - if(arg.k_batch_ < 0) - { - return false; - } -#endif - const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index 51dc56e306..1e23fef191 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -677,7 +677,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); -#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS if(split_k < 0) { ck::index_t gemmM, gemmN; @@ -688,9 +687,11 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle calculate_mn_grid_size(gemmM, gemmN) * Conv_G_; k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_, grid_size); + + // Cap k_batch_ to 128 to avoid accuracy issues + k_batch_ = std::min(k_batch_, 128); } else -#endif { k_batch_ = split_k; } @@ -947,12 +948,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { -#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS - if(arg.k_batch_ < 0) - { - return false; - } -#endif if(!ck::is_xdl_wmma_supported()) { return false; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index 3f8093afe1..b2ae092c27 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -511,7 +511,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 std::copy(begin(a_g_n_k_wos_lengths) + spatial_offset, end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); -#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS + if(split_k < 0) { ck::index_t gemmM, gemmN, gemmK; @@ -528,6 +528,9 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock); k_batch_ = std::min(k_batch_, k_batch_max); + // Cap k_batch_ to 128 to avoid accuracy issues + k_batch_ = std::min(k_batch_, 128); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max @@ -537,7 +540,6 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 } } else -#endif { k_batch_ = split_k; } @@ -1040,12 +1042,6 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 static bool IsSupportedArgument(const Argument& arg) { -#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS - if(arg.k_batch_ < 0) - { - return false; - } -#endif const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 0ea94806d0..1f6f2fb789 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -651,7 +651,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths, e_g_k_c_xs_strides); -#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS if(split_k < 0) { ck::index_t gemmM, gemmN; @@ -662,9 +661,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle calculate_mn_grid_size(gemmM, gemmN) * Conv_G_; k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_, grid_size); + + // Cap k_batch_ to 128 to avoid accuracy issues + k_batch_ = std::min(k_batch_, 128); } else -#endif { k_batch_ = split_k; } @@ -1083,12 +1084,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { -#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS - if(arg.k_batch_ < 0) - { - return false; - } -#endif if(!ck::is_xdl_wmma_supported()) { return false; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 26cf586017..ac83cee251 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -594,7 +594,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); -#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS if(split_k < 0) { ck::index_t gemmM, gemmN, gemmK; @@ -611,6 +610,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 const auto k_batch_max = static_cast((gemmK - 1) / K0PerBlock); k_batch_ = std::max(std::min(k_batch_, k_batch_max), 1); + // Cap k_batch_ to 128 to avoid accuracy issues + k_batch_ = std::min(k_batch_, 128); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max @@ -620,7 +622,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 } } else -#endif { k_batch_ = split_k; } @@ -1399,13 +1400,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 static bool IsSupportedArgument(const Argument& arg) { -#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS - if(arg.k_batch_ < 0) - { - return false; - } -#endif - // check device if constexpr(DirectLoad) { diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp index 3a9f14e595..afc88150ed 100644 --- a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp @@ -364,26 +364,39 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, using AccDataType = std::conditional_t, int32_t, float>; - // Calculate number of accumulations accounting for split_k - const int num_accums = - static_cast(output.GetElementSize() / conv_param.K_ / split_k_value); - - // Additional tolerance for split_k accumulation if needed - int total_accums = num_accums; - if(split_k_value > 1) - { - total_accums = std::max(num_accums, static_cast(split_k_value)); - } - - // Perform GPU verification (max value computed internally on GPU) + const index_t num_accums = output.GetElementSize() / conv_param.K_; + const index_t num_accums_split_k = split_k_value; + // Get maximum accumulated value from reference const std::size_t tensor_size = weight_device_result.mDesc.GetElementSpaceSize(); + max_accumulated_value = + gpu_reduce_max(gpu_ref_wei_buf.GetDeviceBuffer(), tensor_size); + // Calculate thresholds + auto rtol = + ck::utils::get_relative_threshold( + num_accums / num_accums_split_k); + auto atol = + ck::utils::get_absolute_threshold( + max_accumulated_value / num_accums_split_k, + num_accums / num_accums_split_k); + // Calculate error due to split_k accumulation + auto rtol_split_k = + ck::utils::get_relative_threshold( + num_accums_split_k); + auto atol_split_k = + ck::utils::get_absolute_threshold( + max_accumulated_value, num_accums_split_k); + // Use higher threshold + rtol = std::max(rtol, rtol_split_k); + atol = std::max(atol, atol_split_k); + + // Perform GPU verification auto gpu_result = - ck::profiler::gpu_verify( - wei_device_buf.GetDeviceBuffer(), - gpu_ref_wei_buf.GetDeviceBuffer(), - total_accums, - tensor_size); + ck::profiler::gpu_verify(wei_device_buf.GetDeviceBuffer(), + gpu_ref_wei_buf.GetDeviceBuffer(), + rtol, + atol, + tensor_size); if(!gpu_result) { diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp index bce6da4b68..5aa0b13c07 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp @@ -184,5 +184,5 @@ TYPED_TEST(TestGroupedConvndBwdWeightDefault, SingleStageAutoDeduce) this->conv_param = {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; this->split_k_ = -1; bool is_supported = this->template Run<2>(); - EXPECT_FALSE(is_supported); + EXPECT_TRUE(is_supported); } From f16d9100e42a978261f76319c66a7995e5f6d555 Mon Sep 17 00:00:00 2001 From: Enrico Degregori <73224202+EnricoDeg@users.noreply.github.com> Date: Thu, 29 Jan 2026 19:29:40 +0100 Subject: [PATCH 2/8] Multi AB support for wave transfer (#3578) * Add multi AB support to wave transfer * Improviments to multi ABD examples * Add instances and use intrawave v1 instead of interwave * Apply changes to other transfers * Wave transfer: add support for multiple internal vgpr buffers * Fix compilation error gfx11 --- ...m_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp | 29 +- .../gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp | 30 ++- .../gemm_multi_ABD_wmma_fp16.cpp | 27 +- ...BD_wmma_multiply_bias_fastgelu_bf16_i8.cpp | 29 +- ...ead_group_tensor_slice_transfer_global.hpp | 247 +++++++++++------- .../gridwise_ab_transfer_thread_tiles.hpp | 13 + ...se_ab_transfer_thread_tiles_preshuffle.hpp | 13 + .../grid/gridwise_ab_transfer_wave_tiles.hpp | 46 ++-- ...wise_ab_transfer_wave_tiles_interleave.hpp | 43 +-- .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 27 +- ...multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp | 19 +- ..._abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp | 2 +- ...multi_abd_bf16_i8_bf16_mk_nk_mn_common.hpp | 15 +- ...bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp | 2 +- ...gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp | 2 +- ...gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp | 8 +- ...gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp | 2 +- ...iply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp | 2 +- ...bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp | 2 +- ...gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp | 2 +- ...gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp | 2 +- 21 files changed, 374 insertions(+), 188 deletions(-) diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp index cf8dd31c3f..78d98e92ce 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp @@ -96,11 +96,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm 8, 8, 0, - S<8, 32, 1>, + S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, - 1, + 8, 8, 0, 1, @@ -108,7 +108,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm S<1, 32, 1, 8>, S<8, 8, 8>, ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3>; + ck::BlockGemmPipelineVersion::v1>; int main(int argc, char* argv[]) { @@ -174,6 +174,29 @@ int main(int argc, char* argv[]) } }; + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1 || stride == 0) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, A0Layout{}); + StrideB = f_get_default_stride(K, N, StrideB, B0Layout{}); + StrideD = f_get_default_stride(M, N, StrideD, D0Layout{}); + StrideE = f_get_default_stride(M, N, StrideE, ELayout{}); + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); Tensor b1_k_n(f_host_tensor_descriptor(K, N, StrideB, B1Layout{})); diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp index e4033e5bac..089404757a 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp @@ -94,11 +94,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm 8, 8, 0, - S<8, 32, 1>, + S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, - 1, + 8, 8, 0, 1, @@ -106,7 +106,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm S<1, 32, 1, 8>, S<8, 8, 8>, ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3>; + ck::BlockGemmPipelineVersion::v1>; int main(int argc, char* argv[]) { @@ -133,7 +133,7 @@ int main(int argc, char* argv[]) init_method = std::stoi(argv[2]); time_kernel = std::stoi(argv[3]); } - else if(argc == 11) + else if(argc == 10) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); @@ -170,6 +170,28 @@ int main(int argc, char* argv[]) } }; + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1 || stride == 0) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, A0Layout{}); + StrideB = f_get_default_stride(K, N, StrideB, B0Layout{}); + StrideE = f_get_default_stride(M, N, StrideE, ELayout{}); + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); Tensor b1_k_n(f_host_tensor_descriptor(K, N, StrideB, B1Layout{})); diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fp16.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fp16.cpp index 5817269fdf..d5ccf7eb59 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fp16.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fp16.cpp @@ -141,11 +141,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm 8, 8, 0, - S<4, 64, 1>, + S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, - 1, + 8, 8, 0, 1, @@ -233,6 +233,29 @@ int main(int argc, char* argv[]) } }; + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1 || stride == 0) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideD = f_get_default_stride(M, N, StrideD, DLayout{}); + StrideE = f_get_default_stride(M, N, StrideE, ELayout{}); + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor a1_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp index 4fb1a5ab4e..2d07bc480d 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp @@ -95,11 +95,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm 8, 8, 0, - S<8, 32, 1>, + S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, - 1, + 8, 8, 0, 1, @@ -107,7 +107,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm S<1, 32, 1, 8>, S<8, 8, 8>, ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3>; + ck::BlockGemmPipelineVersion::v1>; int main(int argc, char* argv[]) { @@ -173,6 +173,29 @@ int main(int argc, char* argv[]) } }; + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1 || stride == 0) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, A0Layout{}); + StrideB = f_get_default_stride(K, N, StrideB, B0Layout{}); + StrideD = f_get_default_stride(M, N, StrideD, D0Layout{}); + StrideE = f_get_default_stride(M, N, StrideE, ELayout{}); + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp index 1c322fe4a7..d1c6f30a14 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp @@ -12,16 +12,17 @@ namespace ck { -template + bool DoTranspose, + index_t NumThreadScratch = 1> struct ThreadGroupTransferGlobal { static constexpr auto I0 = Number<0>{}; @@ -32,24 +33,57 @@ struct ThreadGroupTransferGlobal static constexpr auto I5 = Number<5>{}; static constexpr auto I6 = Number<6>{}; - static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); - using Index = MultiIndex; - using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); - using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + // return a tuple of coordiantes for a tuple of tensor + template = false> + static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices) + { + return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); }, + Number{}); + } - __device__ ThreadGroupTransferGlobal(const SrcDesc& src_desc, - const DstDesc& dst_desc, - const Index& src_block_slice_origin, - const Index& dst_block_slice_origin, - const ElementwiseOperation& element_op) - : src_coord_(make_tensor_coordinate(src_desc, src_block_slice_origin)), + static constexpr index_t nDim = + remove_cvref_t>::GetNumOfDimension(); + static constexpr index_t nSrc = SrcDescs::Size(); + using Index = MultiIndex; + using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + __device__ + ThreadGroupTransferGlobal(const SrcDescs& src_descs, + const DstDesc& dst_desc, + const StaticallyIndexedArray& src_block_slice_origins, + const Index& dst_block_slice_origin, + const ElementwiseOperation& element_op) + : src_coords_(MakeCoordinates(src_descs, src_block_slice_origins)), dst_coord_(make_tensor_coordinate(dst_desc, dst_block_slice_origin)), element_op_(element_op) { } - template - __device__ void RunRead(const SrcDesc& src_desc, const GridBufferType& grid_buf) + template + __device__ static auto generate_vectors() + { + auto data_types = DataTypes_{}; + + constexpr index_t num = data_types.Size(); + + return generate_tuple( + [&](auto i) { + using DataType = remove_cvref_t; + + return vector_type_maker_t{}; + }, + Number{}); + } + + template = false> + __device__ void RunRead(SrcDescs& src_descs, + const GridBufferTypes& grid_bufs, + Number thread_scratch_id = Number{}) { constexpr auto src_access_lengths = NumberOfIterations{}; constexpr auto src_dim_access_order = IterationOrder{}; @@ -57,36 +91,6 @@ struct ThreadGroupTransferGlobal container_reorder_given_new2old(src_access_lengths, src_dim_access_order); constexpr auto ordered_fwd_step = StepsPerIteration{}; - // make forward steps - // forward step for each iteration just add 1 - const auto src_forward_steps = generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = (i.value == j.value) ? ordered_fwd_step[i] : 0; - }); - - return make_tensor_coordinate_step(src_desc, forward_step_idx); - }, - Number{}); - - // make backward steps - // backward step at the end of the dimension iteration subtract IterationLength - 1 - const auto src_backward_steps = generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) - ? (-src_access_lengths[i] + 1) * ordered_fwd_step[i] - : 0; - }); - - return make_tensor_coordinate_step(src_desc, backward_step_idx); - }, - Number{}); - static_ford{}([&](auto ordered_src_access_idx) { // judge move forward or move backward constexpr auto forward_sweep = [&]() { @@ -157,10 +161,26 @@ struct ThreadGroupTransferGlobal }, Number{}); - // check if src element is valid - const bool is_src_valid = - coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); - oob_thread_scratch_.template SetAsType(vgpr_data_idx_seq, is_src_valid); + auto src_vectors = generate_vectors(); + bool oob_val = true; + + static_for<0, nSrc, 1>{}([&](auto i) { + using src_vector_t = typename remove_cvref_t::type; + // check if src element is valid + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i], + src_coords_[i]); + + oob_val = oob_val & is_src_valid; + + // Load data from memory in src_vector first + auto index = is_src_valid || !DoTranspose ? src_coords_[i].GetOffset() : 0; + src_vectors(i).template AsType()(I0) = + grid_bufs[i].template Get(index, true); + }); + + oob_thread_scratch_(thread_scratch_id) + .template SetAsType(vgpr_data_idx_seq, oob_val); // Vector length of elementwise operation constexpr auto get_elem_op_vec_len = []() { @@ -185,57 +205,105 @@ struct ThreadGroupTransferGlobal } }; - // This is 1 for pass through because internally it's doing type conversion constexpr index_t elem_op_vec_len = get_elem_op_vec_len(); - using src_vector_container = vector_type_maker_t; - using src_vector_container_t = typename src_vector_container::type; - - using elem_op_vec_t = typename vector_type::type; - using dst_vector_type = vector_type_maker_t; using dst_vector_t = typename dst_vector_type::type; - dst_vector_type op_r_v; - // Load data from memory in src_vector first - auto index = is_src_valid || !DoTranspose ? src_coord_.GetOffset() : 0; - src_vector_container src_vector = src_vector_container{ - grid_buf.template Get(index, true)}; - // apply the src elementwise op and convert to DstData under the hood if needed static_for<0, VectorSize / elem_op_vec_len, 1>{}([&](auto idx) { - element_op_(op_r_v.template AsType()(idx), - src_vector.template AsType()[idx]); + // get reference to src data + const auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { + using SrcData = remove_cvref_t>; + + using elem_op_vec_t = typename vector_type::type; + + return src_vectors[iSrc].template AsType()[idx]; + }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto) -> auto& { + using elem_op_vec_t = typename vector_type::type; + + return op_r_v.template AsType()(idx); + }, + Number<1>{}); + + // apply pointwise function + unpack2(element_op_, dst_data_refs, src_data_refs); }); // store result in dvgpr_ (static array holding loaded data). // At this point data is already converted to DstData type and // the elementwise operation has been applied - src_dvgpr_.template SetAsType(vgpr_data_idx_seq, - op_r_v.template AsType()[I0]); + src_dvgpr_(thread_scratch_id) + .template SetAsType(vgpr_data_idx_seq, + op_r_v.template AsType()[I0]); - // For each dimension move fwd, bwd or don't move - static_for<0, nDim, 1>{}([&](auto i) { - if constexpr(move_on_dim[i]) - { - if constexpr(forward_sweep[i]) + // Move each src coordinate + static_for<0, nSrc, 1>{}([&](auto iSrc) { + // make forward steps + // forward step for each iteration just add 1 + const auto src_forward_steps = generate_tuple( + [&](auto iDim) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = + (iDim.value == j.value) ? ordered_fwd_step[iDim] : 0; + }); + return make_tensor_coordinate_step(src_descs[iSrc], forward_step_idx); + }, + Number{}); + + // make backward steps + // backward step at the end of the dimension iteration subtract IterationLength - 1 + const auto src_backward_steps = generate_tuple( + [&](auto iDim) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = + (iDim.value == j.value) + ? (-src_access_lengths[iDim] + 1) * ordered_fwd_step[iDim] + : 0; + }); + return make_tensor_coordinate_step(src_descs[iSrc], backward_step_idx); + }, + Number{}); + + // For each dimension move fwd, bwd or don't move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) { - move_tensor_coordinate( - src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate(src_descs[iSrc], + src_coords_(iSrc), + src_forward_steps[src_dim_access_order[i]]); + } + else + { + move_tensor_coordinate(src_descs[iSrc], + src_coords_(iSrc), + src_backward_steps[src_dim_access_order[i]]); + } } - else - { - move_tensor_coordinate( - src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); - } - } + }); }); }); } - template - __device__ void RunWrite(const DstDesc& dst_desc, BlockBufferType& dst_buf) + template + __device__ void RunWrite(const DstDesc& dst_desc, + BlockBufferType& dst_buf, + Number thread_scratch_id = Number{}) { using dst_vector_type = vector_type_maker_t; using dst_vector_t = typename dst_vector_type::type; @@ -272,9 +340,10 @@ struct ThreadGroupTransferGlobal }, Number{}); - auto op_r = src_dvgpr_.template GetAsType(vgpr_data_idx_seq); + auto op_r = + src_dvgpr_(thread_scratch_id).template GetAsType(vgpr_data_idx_seq); const bool is_src_valid = - oob_thread_scratch_.template GetAsType(vgpr_data_idx_seq); + oob_thread_scratch_(thread_scratch_id).template GetAsType(vgpr_data_idx_seq); auto op_r_v = is_src_valid ? op_r : dst_vector_t(0); dst_dvgpr_.template SetAsType(vgpr_data_idx_seq, op_r_v); }); @@ -404,10 +473,12 @@ struct ThreadGroupTransferGlobal }); } - __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step) { - const auto adjusted_step = make_tensor_coordinate_step(src_desc, step); - move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + static_for<0, nSrc, 1>{}([&](auto iSrc) { + const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], step); + move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step); + }); } private: @@ -443,10 +514,10 @@ struct ThreadGroupTransferGlobal decltype(src_oob_thread_scratch_desc_), true>; - ThreadScratchData src_dvgpr_; + StaticallyIndexedArray src_dvgpr_; ThreadScratchData dst_dvgpr_; - OOBThreadScratch oob_thread_scratch_; - SrcCoord src_coord_; + StaticallyIndexedArray oob_thread_scratch_; + SrcCoords src_coords_; DstCoord dst_coord_; const ElementwiseOperation element_op_; }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp index 96387c6f64..4d5c052e02 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp @@ -488,6 +488,19 @@ struct ABTransferThreadTiles { return make_dynamic_buffer(p_shared_AB, size); } + + template + __device__ __forceinline__ static auto get_first_element_workaround(Type& array) + { + if constexpr(numElements > 1) + { + return array; + } + else + { + return array[I0]; + } + } }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles_preshuffle.hpp index ad9af92ae5..fb6d1451d3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles_preshuffle.hpp @@ -133,6 +133,19 @@ struct ABTransferThreadTilesPreShuffle { return make_static_buffer(size); } + + template + __device__ __forceinline__ static auto get_first_element_workaround(Type& array) + { + if constexpr(numElements > 1) + { + return array; + } + else + { + return array[I0]; + } + } }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp index caf468d6cb..63c0299750 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp @@ -318,43 +318,43 @@ struct ABTransferWaveTiles const index_t block_mn_id, const index_t) { - // Note: GlobalBufferNum is currently not used but it will be needed - // once we add other pipelines. It is currently needed only for - // consistency with the thread tiles approach - static_assert(GlobalBufferNum == 1, "single global buffer is only supported"); constexpr index_t NumABTensor = ABsDataType::Size(); - static_assert(NumABTensor == 1, "multiAB currently not supported"); - - using ABDataType = remove_cvref_t>; const auto wave_idx = GetWaveIdx(); index_t wave_idK = wave_idx[I1]; index_t wave_idMN = wave_idx[I0]; - const auto grid_lane_id = GetGridLaneIdx(); - index_t lane_group_grid = grid_lane_id[I0]; - index_t lane_local_id_grid = grid_lane_id[I1]; - const auto block_lane_id = GetBlockLaneIdx(); index_t lane_group_block = block_lane_id[I0]; index_t lane_local_id_block = block_lane_id[I1]; - return ThreadGroupTransferGlobal>; + const auto grid_lane_id = GetGridLaneIdx(); + index_t lane_group_grid = grid_lane_id[I0]; + index_t lane_local_id_grid = grid_lane_id[I1]; + return make_multi_index(block_mn_id * (MNRepeat_ * MNWaves_) + wave_idMN, + wave_idK, + lane_group_grid, + lane_local_id_grid); + }, + Number{}); + + return ThreadGroupTransferGlobal, Sequence, Sequence, ABK1Value, - ABDoTranspose>( - grid_descriptor[I0], + ABDoTranspose, + GlobalBufferNum>( + grid_descriptor, block_descriptor, - make_multi_index(block_mn_id * (MNRepeat_ * MNWaves_) + wave_idMN, - wave_idK, - lane_group_grid, - lane_local_id_grid), + idx_as_block_begin, make_multi_index(wave_idMN, wave_idK, lane_group_block, lane_local_id_block), ab_element_op); } @@ -398,6 +398,12 @@ struct ABTransferWaveTiles { return make_dynamic_buffer(p_shared_AB, size); } + + template + __device__ __forceinline__ static auto get_first_element_workaround(Type& array) + { + return array; + } }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp index bfe5b7bd08..e1ee47770b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp @@ -218,45 +218,46 @@ struct ABTransferWaveTilesInterleave : ABTransferWaveTiles>; const auto wave_idx = GetWaveIdx(); index_t wave_idK = wave_idx[I1]; index_t wave_idMN = wave_idx[I0]; - const auto grid_lane_id = Base::template GetGridLaneIdx(); - index_t lane_group_grid = grid_lane_id[I0]; - index_t lane_local_id_grid = grid_lane_id[I1]; - const auto block_lane_id = GetBlockLaneIdx(); index_t lane_group_block = block_lane_id[I0]; index_t lane_local_id_block = block_lane_id[I1]; constexpr index_t MNRepeatRatio = MNRepeat_Grid / MNRepeat_; - return ThreadGroupTransferGlobal>; + const auto grid_lane_id = Base::template GetGridLaneIdx(); + index_t lane_group_grid = grid_lane_id[I0]; + index_t lane_local_id_grid = grid_lane_id[I1]; + return make_multi_index(block_mn_id * MNWaves_Grid + wave_idMN / MNRepeatRatio, + wave_idK * KRepeat_Grid, + (wave_idMN % MNRepeatRatio) * MNRepeat_, + lane_group_grid, + lane_local_id_grid); + }, + Number{}); + + return ThreadGroupTransferGlobal, Sequence, Sequence, ABK1Value, - ABDoTranspose>( - grid_descriptor[I0], + ABDoTranspose, + GlobalBufferNum>( + grid_descriptor, block_descriptor, - make_multi_index(block_mn_id * MNWaves_Grid + wave_idMN / MNRepeatRatio, - wave_idK * KRepeat_Grid, - (wave_idMN % MNRepeatRatio) * MNRepeat_, - lane_group_grid, - lane_local_id_grid), + idx_as_block_begin, make_multi_index(wave_idMN / MNRepeatRatio, wave_idK * KRepeat_, (wave_idMN % MNRepeatRatio) * MNRepeat_, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index bcf131003c..03735bbc6a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -364,7 +364,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base __host__ __device__ static constexpr bool AWaveTransferApplicable() { - return !ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 && + return !ForceThreadTileTransfer && APackedSize == 1 && ABlockTransferSrcScalarPerVector == 8 && ABlockTransferDstScalarPerVector_AK1 == 8 && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && AK1Value == 8 && !IsBPreShuffled; @@ -372,13 +372,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_base __host__ __device__ static constexpr bool BWaveTransferApplicable() { - return !ForceThreadTileTransfer && NumBTensor == 1 && BPackedSize == 1 && + return !ForceThreadTileTransfer && BPackedSize == 1 && BBlockTransferSrcScalarPerVector == 8 && BBlockTransferDstScalarPerVector_BK1 == 8 && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8; } - // Limitations of the current implementation: - // - no multiAB #ifdef __gfx12__ static constexpr bool IsAWaveTransferApplicable = AWaveTransferApplicable(); @@ -1319,19 +1317,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base } } - template - __device__ __forceinline__ static auto get_first_element_workaround(Type& array) - { - if constexpr(numElements > 1) - { - return array; - } - else - { - return array[I0]; - } - } - // Note: arguments k_batch and k_id should be set if splitk is used // with implicit gemm (no pointer shift but shift using tensor descriptors) template ( - get_first_element_workaround(as_grid_desc_ak0_m_ak1), + ATransfer::template get_first_element_workaround(as_grid_desc_ak0_m_ak1), a_block_desc_ak0_m_ak1, a_blockwise_copy, - get_first_element_workaround(as_grid_buf), + ATransfer::template get_first_element_workaround(as_grid_buf), a_block_buf, a_block_slice_copy_step, - get_first_element_workaround(bs_grid_desc_bk0_n_bk1), + BTransfer::template get_first_element_workaround(bs_grid_desc_bk0_n_bk1), b_block_desc_bk0_n_bk1, b_blockwise_copy, - get_first_element_workaround(bs_grid_buf), + BTransfer::template get_first_element_workaround(bs_grid_buf), b_block_buf, b_block_slice_copy_step, c_thread_buf, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp index 4cd4403436..0dd666b3d9 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp @@ -73,14 +73,17 @@ template using device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances = std::tuple< // clang-format off - //###################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| - //###################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| | | - //###################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | | | - //###################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + //###################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| + //###################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| | | + //###################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | | | + //###################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp index 76a92a1971..3587c6700c 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -39,7 +39,7 @@ void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances( Multiply, PassThrough, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); add_device_operation_instances(instances, device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< ck::Tuple, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_common.hpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_common.hpp index 1607b240f6..7cb50cd954 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_common.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_common.hpp @@ -71,12 +71,15 @@ template using device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_comp_instances = std::tuple< // clang-format off - //###################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| - //###################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| | | - //###################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | | | - //###################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + //###################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| + //###################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| | | + //###################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | | | + //###################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp index 2a4aae98a5..731518257b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -39,7 +39,7 @@ void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances( Multiply, Add, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); add_device_operation_instances(instances, device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< ck::Tuple, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp index 477d6811d2..0a67f2357e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -39,7 +39,7 @@ void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances Multiply, AddFastGelu, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); add_device_operation_instances(instances, device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< ck::Tuple, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp index 71c04b3485..c0b4cf7b9a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp @@ -36,7 +36,7 @@ void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_v1_instances ck::Tuple, AddFastGelu, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); } void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_v1_instances( @@ -58,7 +58,7 @@ void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_v1_instances( ck::Tuple, Add, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); } void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_v1_instances( @@ -80,7 +80,7 @@ void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_v1_instances( ck::Tuple<>, PassThrough, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); } void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_v1_instances( @@ -102,7 +102,7 @@ void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_v1_instances( ck::Tuple<>, FastGelu, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp index 33422fc6db..9176910cea 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -39,7 +39,7 @@ void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances( Multiply, FastGelu, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); add_device_operation_instances(instances, device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp index 639bda6017..669eb4144a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -39,7 +39,7 @@ void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances( PassThrough, Multiply, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); add_device_operation_instances( instances, device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp index 7f8fea44c5..c6a812645b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -39,7 +39,7 @@ void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_insta PassThrough, MultiplyAdd, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); add_device_operation_instances(instances, device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< ck::Tuple, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp index b2bf995507..2d7ffd120d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -39,7 +39,7 @@ void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_ PassThrough, MultiplyAddFastGelu, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); add_device_operation_instances(instances, device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< ck::Tuple, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp index d2adc36dc3..ab49d2f1c9 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -39,7 +39,7 @@ void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_insta PassThrough, MultiplyFastGelu, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); add_device_operation_instances( instances, device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances, From 05ef93a69d8ccaf63f84b43b3dcb9b585f428051 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 29 Jan 2026 16:12:49 -0800 Subject: [PATCH 3/8] Add a flag to build CK libs required for HipTensor. (#3684) * create a filter to build only libs required by hiptensor * allow building libs for miopen and hiptensor at the same time * tweak the lib filtering logic one more time --- CMakeLists.txt | 8 +++--- .../gpu/CMakeLists.txt | 26 ++++++++++++++----- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 356491d9c1..610f9c9d2a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -41,6 +41,7 @@ include(CTest) option(ENABLE_CLANG_CPP_CHECKS "Enables clang tidy, cppcheck" ON) option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF) +option(HIPTENSOR_REQ_LIBS_ONLY "Build only the HipTensor required libraries" OFF) option(CK_EXPERIMENTAL_BUILDER "Enable experimental builder" OFF) option(BUILD_MHA_LIB "Build the static library for flash attention" OFF) option(FORCE_DISABLE_XDL "Skip compiling XDL specific instances (even if supported GPUs are included in GPU_TARGETS)" OFF) @@ -648,7 +649,7 @@ if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERS add_compile_options(-fdiagnostics-color=always) endif() -if(NOT MIOPEN_REQ_LIBS_ONLY) +if(NOT MIOPEN_REQ_LIBS_ONLY AND NOT HIPTENSOR_REQ_LIBS_ONLY) # make check runs the entire set of examples and tests add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} USES_TERMINAL) # make smoke runs the tests and examples that runs within 30 seconds on gfx90a @@ -706,6 +707,7 @@ ENDFOREACH() add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES}) option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF) +option(HIPTENSOR_REQ_LIBS_ONLY "Build only the HipTensor required libraries" OFF) option(DISABLE_OFFLOAD_COMPRESS "Disable offload compress compiler flag when building instances" OFF) option(BUILD_MHA_LIB "Build the static library for flash attention" OFF) @@ -716,7 +718,7 @@ if (CK_EXPERIMENTAL_BUILDER) add_subdirectory(experimental/grouped_convolution_tile_instances) endif() -if(NOT GPU_ARCHS AND USER_GPU_TARGETS AND NOT MIOPEN_REQ_LIBS_ONLY) +if(NOT GPU_ARCHS AND USER_GPU_TARGETS AND NOT MIOPEN_REQ_LIBS_ONLY AND NOT HIPTENSOR_REQ_LIBS_ONLY) rocm_package_setup_component(tests LIBRARY_NAME composablekernel PACKAGE_NAME tests # Prevent -static suffix on package name @@ -739,7 +741,7 @@ if(NOT GPU_ARCHS AND USER_GPU_TARGETS AND NOT MIOPEN_REQ_LIBS_ONLY) endif() endif() -if (NOT MIOPEN_REQ_LIBS_ONLY) +if (NOT MIOPEN_REQ_LIBS_ONLY AND NOT HIPTENSOR_REQ_LIBS_ONLY) rocm_package_setup_component(profiler LIBRARY_NAME composablekernel PACKAGE_NAME ckprofiler diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index 41fc8b740e..d5989e7a39 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -335,11 +335,23 @@ FOREACH(subdir_path ${dir_list}) endif() endif() + # Build the required pattern based on library settings + set(required_pattern "") + set(pattern_parts "") if(MIOPEN_REQ_LIBS_ONLY) message(STATUS "Removing all sources that are not required for MIOpen") - if(NOT "${cmake_instance}" MATCHES "conv") - set(add_inst 0) - endif() + list(APPEND pattern_parts "conv") + endif() + if(HIPTENSOR_REQ_LIBS_ONLY) + message(STATUS "Removing all sources that are not required for HipTensor") + list(APPEND pattern_parts "contract" "reduce" "element") + endif() + if(pattern_parts) + string(JOIN "|" required_pattern ${pattern_parts}) + endif() + # Apply the pattern if one was set + if(required_pattern AND NOT "${cmake_instance}" MATCHES "${required_pattern}") + set(add_inst 0) endif() if((add_inst EQUAL 1)) @@ -405,7 +417,7 @@ if(CK_DEVICE_OTHER_INSTANCES AND NOT MIOPEN_REQ_LIBS_ONLY) DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel ) endif() -if(CK_DEVICE_GEMM_INSTANCES AND NOT MIOPEN_REQ_LIBS_ONLY) +if(CK_DEVICE_GEMM_INSTANCES AND NOT MIOPEN_REQ_LIBS_ONLY AND NOT HIPTENSOR_REQ_LIBS_ONLY) add_library(device_gemm_operations ${CK_DEVICE_GEMM_INSTANCES}) add_library(composablekernels::device_gemm_operations ALIAS device_gemm_operations) target_compile_features(device_gemm_operations PUBLIC) @@ -426,7 +438,7 @@ if(CK_DEVICE_GEMM_INSTANCES AND NOT MIOPEN_REQ_LIBS_ONLY) DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel ) endif() -if(CK_DEVICE_CONV_INSTANCES) +if(CK_DEVICE_CONV_INSTANCES AND (NOT HIPTENSOR_REQ_LIBS_ONLY OR MIOPEN_REQ_LIBS_ONLY)) add_library(device_conv_operations ${CK_DEVICE_CONV_INSTANCES}) add_library(composablekernels::device_conv_operations ALIAS device_conv_operations) target_compile_features(device_conv_operations PUBLIC) @@ -451,7 +463,7 @@ if(CK_DEVICE_CONV_INSTANCES) DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel ) endif() -if(CK_DEVICE_MHA_INSTANCES AND NOT MIOPEN_REQ_LIBS_ONLY AND BUILD_MHA_LIB) +if(CK_DEVICE_MHA_INSTANCES AND NOT MIOPEN_REQ_LIBS_ONLY AND NOT HIPTENSOR_REQ_LIBS_ONLY AND BUILD_MHA_LIB) set(gpu_list ${INST_TARGETS}) if(gpu_list MATCHES "gfx94" OR gpu_list MATCHES "gfx90a" OR gpu_list MATCHES "gfx95") add_library(device_mha_operations ${CK_DEVICE_MHA_INSTANCES}) @@ -517,7 +529,7 @@ if(CK_DEVICE_REDUCTION_INSTANCES AND NOT MIOPEN_REQ_LIBS_ONLY) ) endif() -if(NOT MIOPEN_REQ_LIBS_ONLY) +if(NOT MIOPEN_REQ_LIBS_ONLY AND NOT HIPTENSOR_REQ_LIBS_ONLY) add_library(device_operations INTERFACE) target_link_libraries(device_operations INTERFACE device_contraction_operations From 83b61553548019eb9aa77a5efc72258a48dee42a Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Thu, 29 Jan 2026 17:20:22 -0800 Subject: [PATCH 4/8] Add ck-rocprof: GPU profiling tool for rocprof-compute (#3627) * Decouple configure/build/test tools from Docker Create a two-layer tool architecture: - Core tools (ck-configure, ck-build, ck-test): Environment-agnostic, work on any system with ROCm - no Docker dependency - Container tools (ck-docker): Manage Docker containers and delegate to core tools via docker exec Changes: - Add ck-configure: New CMake configuration tool with preset support, native GPU detection, and flexible options - Refactor ck-build: Remove Docker dependency, add --configure and --list options, call ninja directly - Refactor ck-test: Remove Docker dependency, add CTest integration with --smoke/--regression/--all options - Enhance common.sh: Add native GPU detection, build directory utils, and output helpers - Update ck-docker: Add configure/build/test/exec commands that delegate to core tools inside container This enables: - Native development on ROCm hosts without Docker - Simpler CI/CD integration - Consistent behavior inside and outside containers Co-Authored-By: Claude * Add ck-rocprof: GPU profiling tool for rocprof-compute Adds a command-line profiling tool to simplify GPU performance analysis workflow using AMD rocprof-compute. Features: - Easy setup with automatic Python venv configuration - Simple CLI: setup, run, analyze, compare, list - Automatic GPU architecture detection - Focus on LDS metrics (Block 12) for bank conflict analysis - Comprehensive documentation with examples and troubleshooting Usage: ck-rocprof setup # One-time environment setup ck-rocprof run # Profile executable ck-rocprof analyze [block] # Analyze metrics ck-rocprof compare # Compare two runs ck-rocprof list # List available runs * Make ck-rocprof documentation concise and improve Docker integration - Streamlined documentation from 416 to 157 lines (62% reduction) - Focused on essential commands, metrics, and workflows - Enhanced script to run all operations inside Docker containers - Fixed workload directory path and improved container management - Added automatic rocprofiler-compute installation and dependency handling * Add --no-roof flag to ck-rocprof profile command Skip roofline analysis by default to speed up profiling. Roofline analysis can add significant time to profiling runs but is not needed for most LDS bank conflict analysis workflows. * Make ck-rocprof work independently of Docker Add native execution mode that runs rocprof-compute directly on the host system when available, falling back to Docker mode when not. Key changes: - Auto-detect native mode when rocprof-compute is in PATH or common locations - Add execution mode wrappers (exec_cmd, file_exists, dir_exists, etc.) - Native mode stores venv at .ck-rocprof-venv in project root - Native mode stores workloads at build/workloads/ - Support user-installed rocprofiler-compute (e.g., ~/.local/rocprofiler-compute) - Add CK_FORCE_DOCKER env var to force Docker mode - Update help message to show current execution mode - Maintain full backward compatibility with existing Docker workflow Tested successfully with rocprofiler-compute 3.4.0 installed from source on MI300X GPU in native mode. Co-Authored-By: Claude * Add clean/status commands and improve ck-rocprof robustness - Add 'clean' command to remove profiling runs (supports --all) - Add 'status' command to show configuration and environment info - Add workload name validation to prevent path traversal attacks - Fix uv installation to use pip instead of curl for reliability - Add cross-platform stat support for macOS compatibility - Consolidate ROCPROF_CANDIDATES to avoid code duplication - Expand help documentation with all profiling block descriptions - Fix Docker wrapper script escaping issues Co-Authored-By: Claude * Fix analyze command to use correct workload path rocprof-compute stores results directly in the workload directory (pmc_perf.csv) rather than in a GPU architecture subdirectory. Updated find_workload_path to detect this correctly. Co-Authored-By: Claude * Address PR review security and robustness issues Security fixes: - Escape executable path in cmd_run to prevent shell injection - Add workload name validation to cmd_analyze and cmd_compare Robustness improvements: - Add error checking for uv package manager installation - Use consistent project root detection (find_project_root || get_project_root) - Use /opt/rocm instead of hardcoded /opt/rocm-7.0.1 in Docker mode - Derive ROCM_REQUIREMENTS path from ROCPROF_BIN for flexibility - Use gfx950 as fallback GPU consistent with common.sh Documentation updates: - Fix env var name GPU_TARGET -> CK_GPU_TARGET - Update storage layout to reflect current structure (workloads//) - Document clean and status commands - Clarify native vs Docker default paths Co-Authored-By: Claude * Simplify ck-rocprof to native-only mode Remove Docker mode from ck-rocprof. Docker users should run the tool via `ck-docker exec ck-rocprof ...` instead. This simplification: - Removes ~210 lines of Docker-specific code - Eliminates mode detection complexity - Makes the script easier to maintain - Provides clearer error messages when rocprof-compute is not found The setup command now lists all searched locations when rocprof-compute is not found, helping users understand how to install it. Co-Authored-By: Claude * Add rocprofiler-compute source installation fallback When rocprof-compute is not found in system locations, automatically install rocprofiler-compute 3.4.0 from source as a fallback. This eliminates the hard dependency on system ROCm packages. Implementation details: - Clone rocprofiler-compute from GitHub to ~/.local/ - Install dependencies via requirements.txt (not editable install) - Create wrapper that sets PYTHONPATH to source directory - Execute source script directly rather than importing as module This approach matches the project's development workflow and works around the incomplete pyproject.toml that prevents editable installs. Co-Authored-By: Claude --------- Co-authored-by: Claude --- script/tools/ck-build | 144 ++++--- script/tools/ck-configure | 187 +++++++++ script/tools/ck-docker | 168 +++----- script/tools/ck-rocprof | 806 +++++++++++++++++++++++++++++++++++++ script/tools/ck-rocprof.md | 167 ++++++++ script/tools/ck-test | 239 +++++++---- script/tools/common.sh | 92 ++++- 7 files changed, 1528 insertions(+), 275 deletions(-) create mode 100755 script/tools/ck-configure create mode 100755 script/tools/ck-rocprof create mode 100644 script/tools/ck-rocprof.md diff --git a/script/tools/ck-build b/script/tools/ck-build index 2c0bb24eda..a2a02387eb 100755 --- a/script/tools/ck-build +++ b/script/tools/ck-build @@ -2,7 +2,8 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# CK Build - Build Composable Kernel targets in Docker +# CK Build - Build Composable Kernel targets +# Environment-agnostic: works natively on ROCm hosts or inside containers set -e set -o pipefail @@ -12,46 +13,51 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" source "${SCRIPT_DIR}/common.sh" # Initialize configuration -PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") -CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") +PROJECT_ROOT=$(find_project_root "${SCRIPT_DIR}" || get_project_root "${SCRIPT_DIR}") +BUILD_DIR=$(get_build_dir "${PROJECT_ROOT}") # Help message show_help() { cat << EOF -CK Build - Build Composable Kernel targets in Docker +CK Build - Build Composable Kernel targets Usage: ck-build [options] [target...] Options: -h, --help Show this help message - --name Specify container name - --reconfigure Reconfigure CMake before building -j Parallel jobs (passed to ninja) + -v, --verbose Verbose output + --build-dir Build directory (default: ./build) --clean Clean before building + --configure Auto-configure if build.ninja missing + --list List available targets Arguments: target Target(s) to build (default: all) Environment: - CK_CONTAINER_NAME - Override default container name - GPU_TARGET - Override GPU target detection (e.g., gfx950, gfx942) + CK_BUILD_DIR - Override build directory + CK_GPU_TARGET - Override GPU target for auto-configure Examples: ck-build # Build all targets ck-build test_amdgcn_mma # Build specific target ck-build test_amdgcn_mma test_gemm # Build multiple targets - ck-build --reconfigure # Reconfigure CMake and build all + ck-build --configure # Auto-configure and build all ck-build --clean test_amdgcn_mma # Clean and build target ck-build -j 8 test_amdgcn_mma # Build with 8 parallel jobs + ck-build --list # List available targets EOF } # Parse arguments targets=() -reconfigure=false -clean=false parallel_jobs="" +verbose=false +clean=false +auto_configure=false +list_targets=false while [[ $# -gt 0 ]]; do case $1 in @@ -59,21 +65,35 @@ while [[ $# -gt 0 ]]; do show_help exit 0 ;; - --name) - CONTAINER_NAME="$2" + -j) + require_arg "$1" "${2:-}" + parallel_jobs="$2" shift 2 ;; - --reconfigure) - reconfigure=true + -j*) + parallel_jobs="${1#-j}" shift ;; + -v|--verbose) + verbose=true + shift + ;; + --build-dir) + require_arg "$1" "${2:-}" + BUILD_DIR="$2" + shift 2 + ;; --clean) clean=true shift ;; - -j) - parallel_jobs="-j $2" - shift 2 + --configure) + auto_configure=true + shift + ;; + --list) + list_targets=true + shift ;; *) targets+=("$1") @@ -82,62 +102,62 @@ while [[ $# -gt 0 ]]; do esac done -# Ensure container is running -if ! container_is_running "${CONTAINER_NAME}"; then - echo "Container '${CONTAINER_NAME}' not running. Starting..." - "${SCRIPT_DIR}/ck-start" "${CONTAINER_NAME}" +# Handle --list +if [ "$list_targets" = true ]; then + if ! is_build_configured "${BUILD_DIR}"; then + error "Build not configured. Run 'ck-configure' first or use --configure" + exit 1 + fi + info "Available targets:" + cd "${BUILD_DIR}" + ninja -t targets 2>/dev/null | grep -E '^[a-zA-Z_][a-zA-Z0-9_-]*:' | cut -d: -f1 | sort | head -100 echo "" + echo "(Showing first 100 targets. Use 'ninja -t targets' for full list)" + exit 0 fi -# Configure CMake if needed or requested -if [ "$reconfigure" = true ] || ! docker exec "${CONTAINER_NAME}" test -f /workspace/build/build.ninja 2>/dev/null; then - echo "Detecting GPU target..." - GPU_TARGET_DETECTED=$(detect_gpu_target "${CONTAINER_NAME}") - - if [ "$reconfigure" = true ]; then - echo "Reconfiguring CMake from scratch for GPU target: ${GPU_TARGET_DETECTED}" +# Auto-configure if needed +if ! is_build_configured "${BUILD_DIR}"; then + if [ "$auto_configure" = true ]; then + info "Build not configured. Running ck-configure..." + "${SCRIPT_DIR}/ck-configure" --build-dir "${BUILD_DIR}" + echo "" else - echo "Configuring build with CMake for GPU target: ${GPU_TARGET_DETECTED}" + error "Build not configured. Run 'ck-configure' first or use --configure" + exit 1 fi - - docker exec "${CONTAINER_NAME}" bash -c " - cd /workspace || exit 1 - rm -rf /workspace/build - mkdir /workspace/build - cd /workspace/build || exit 1 - cmake .. -GNinja \ - -DGPU_TARGETS=${GPU_TARGET_DETECTED} \ - -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \ - -DBUILD_TESTING=ON 2>&1 | tail -30 - " - echo "" fi # Clean if requested if [ "$clean" = true ]; then - echo "Cleaning build directory..." - docker exec "${CONTAINER_NAME}" bash -c " - cd /workspace/build || exit 1 - ninja clean - " + info "Cleaning build directory..." + cd "${BUILD_DIR}" + ninja clean echo "" fi -# Build targets -if [ ${#targets[@]} -eq 0 ]; then - echo "Building all configured targets..." - docker exec "${CONTAINER_NAME}" bash -c " - cd /workspace/build || exit 1 - ninja ${parallel_jobs} 2>&1 - " -else - echo "Building targets: ${targets[*]}" - docker exec "${CONTAINER_NAME}" bash -c " - cd /workspace/build || exit 1 - ninja ${parallel_jobs} ${targets[*]} 2>&1 - " +# Build ninja command +ninja_cmd=(ninja -C "${BUILD_DIR}") + +if [ -n "$parallel_jobs" ]; then + ninja_cmd+=("-j" "$parallel_jobs") fi +if [ "$verbose" = true ]; then + ninja_cmd+=(-v) +fi + +# Add targets +ninja_cmd+=("${targets[@]}") + +# Build targets +if [ ${#targets[@]} -eq 0 ]; then + info "Building all configured targets..." +else + info "Building targets: ${targets[*]}" +fi + +"${ninja_cmd[@]}" + echo "" -echo "Build complete ✓" +info "Build complete" diff --git a/script/tools/ck-configure b/script/tools/ck-configure new file mode 100755 index 0000000000..ffe5a4daca --- /dev/null +++ b/script/tools/ck-configure @@ -0,0 +1,187 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Configure - Configure CMake build for Composable Kernel +# Environment-agnostic: works natively on ROCm hosts or inside containers + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(find_project_root "${SCRIPT_DIR}" || get_project_root "${SCRIPT_DIR}") +BUILD_DIR=$(get_build_dir "${PROJECT_ROOT}") + +# Help message +show_help() { + cat << EOF +CK Configure - Configure CMake build for Composable Kernel + +Usage: ck-configure [options] + +Options: + -h, --help Show this help message + --preset Use CMake preset (dev, dev-gfx908, dev-gfx90a, dev-gfx942, dev-gfx950) + --gpu Override GPU_TARGETS (auto-detected if not specified) + --dtypes Set DTYPES (e.g., fp16,fp32,bf16) + --build-type CMAKE_BUILD_TYPE (default: Release) + --build-dir Build directory (default: ./build) + --clean Remove existing build directory before configuring + --list-presets List available CMake presets + -D = Pass additional CMake variable + +Environment: + CK_GPU_TARGET - Override GPU target detection (e.g., gfx950, gfx942) + CK_BUILD_DIR - Override build directory + +Examples: + ck-configure # Auto-detect GPU and configure + ck-configure --preset dev-gfx950 # Use CMake preset + ck-configure --gpu gfx942 # Configure for specific GPU + ck-configure --clean --preset dev # Clean and reconfigure + ck-configure -D BUILD_DEV=ON # Pass CMake variable + +EOF +} + +# Parse arguments +preset="" +gpu_target="" +dtypes="" +build_type="Release" +clean=false +list_presets=false +cmake_vars=() + +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + show_help + exit 0 + ;; + --preset) + require_arg "$1" "${2:-}" + preset="$2" + shift 2 + ;; + --gpu) + require_arg "$1" "${2:-}" + gpu_target="$2" + shift 2 + ;; + --dtypes) + require_arg "$1" "${2:-}" + dtypes="$2" + shift 2 + ;; + --build-type) + require_arg "$1" "${2:-}" + build_type="$2" + shift 2 + ;; + --build-dir) + require_arg "$1" "${2:-}" + BUILD_DIR="$2" + shift 2 + ;; + --clean) + clean=true + shift + ;; + --list-presets) + list_presets=true + shift + ;; + -D) + require_arg "$1" "${2:-}" + cmake_vars+=("-D$2") + shift 2 + ;; + -D*) + cmake_vars+=("$1") + shift + ;; + *) + error "Unknown option: $1" + echo "" + show_help + exit 1 + ;; + esac +done + +# Handle --list-presets +if [ "$list_presets" = true ]; then + echo "Available CMake presets:" + presets=$(list_cmake_presets "${PROJECT_ROOT}" 2>/dev/null) + if [ -n "$presets" ]; then + echo "$presets" | sed 's/^/ /' + else + echo " (No CMakePresets.json found or jq not available)" + fi + exit 0 +fi + +# Clean build directory if requested +if [ "$clean" = true ]; then + if [ -d "${BUILD_DIR}" ]; then + info "Removing existing build directory: ${BUILD_DIR}" + rm -rf "${BUILD_DIR}" + fi +fi + +# Create build directory +mkdir -p "${BUILD_DIR}" + +# Change to project root for CMake +cd "${PROJECT_ROOT}" + +# Build CMake command +cmake_cmd=(cmake -S . -B "${BUILD_DIR}" -GNinja) + +# Use preset if specified +if [ -n "$preset" ]; then + cmake_cmd+=(--preset "${preset}") + info "Using CMake preset: ${preset}" +else + # Manual configuration + + # Detect GPU target if not specified + if [ -z "$gpu_target" ]; then + gpu_target=$(detect_gpu_native) + info "Auto-detected GPU target: ${gpu_target}" + else + info "Using specified GPU target: ${gpu_target}" + fi + + cmake_cmd+=(-DGPU_TARGETS="${gpu_target}") + cmake_cmd+=(-DCMAKE_BUILD_TYPE="${build_type}") + cmake_cmd+=(-DCMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++) + cmake_cmd+=(-DBUILD_TESTING=ON) + + # Add DTYPES if specified + if [ -n "$dtypes" ]; then + cmake_cmd+=(-DDTYPES="${dtypes}") + info "Using DTYPES: ${dtypes}" + fi +fi + +# Add any additional CMake variables +for var in "${cmake_vars[@]}"; do + cmake_cmd+=("$var") +done + +# Run CMake +info "Configuring build in: ${BUILD_DIR}" +echo "Running: ${cmake_cmd[*]}" +echo "" + +"${cmake_cmd[@]}" + +echo "" +info "Configuration complete. Build directory: ${BUILD_DIR}" +info "Next: run 'ck-build' to build targets" diff --git a/script/tools/ck-docker b/script/tools/ck-docker index 82bf770011..6c118561b7 100755 --- a/script/tools/ck-docker +++ b/script/tools/ck-docker @@ -22,25 +22,29 @@ CK Docker Tool - Build and test composable_kernel in Docker Usage: ck-docker [options] -Commands: - start [name] Start Docker container - build [target] [--reconfigure] Build target (optionally reconfigure CMake) - test [options] Run test - shell [name] Open shell in container - status [name] Check container status - stop [name] Stop and remove container +Container Management: + start [name] Start Docker container + stop [name] Stop and remove container + status [name] Check container status + shell [name] Open shell in container + +Build/Test (delegates to core tools inside container): + configure [opts] Run ck-configure in container + build [opts] Run ck-build in container + test [opts] Run ck-test in container + exec Run arbitrary command in container Examples: ck-docker start + ck-docker configure --preset dev-gfx950 ck-docker build test_amdgcn_mma - ck-docker build --reconfigure test_amdgcn_mma - ck-docker test test_amdgcn_mma --gtest_filter=*Fp16* + ck-docker test test_amdgcn_mma --filter '*Fp16*' ck-docker shell + ck-docker exec rocminfo Environment: CK_CONTAINER_NAME - Override default container name (default: ck__) CK_DOCKER_IMAGE - Override Docker image (default: rocm/composable_kernel:ck_ub24.04_rocm7.0.1) - GPU_TARGET - Override GPU target detection (e.g., gfx950, gfx942) EOF } @@ -77,126 +81,38 @@ cmd_start() { docker exec "${name}" bash -c "echo 'Working directory:' && pwd" } -# Build target -cmd_build() { - local target="" - local name="${CONTAINER_NAME}" - local reconfigure=false - - while [[ $# -gt 0 ]]; do - case $1 in - --name) - name="$2" - shift 2 - ;; - --reconfigure) - reconfigure=true - shift - ;; - *) - target="$1" - shift - ;; - esac - done - - # Check if container is running - if ! container_is_running "${name}"; then - echo "Container '${name}' not running. Starting..." - cmd_start "${name}" - fi - - # Reconfigure CMake if requested or if build.ninja doesn't exist - if [ "$reconfigure" = true ] || ! docker exec "${name}" test -f /workspace/build/build.ninja 2>/dev/null; then - echo "Detecting GPU target..." - local gpu_target=$(detect_gpu_target "${name}") - - if [ "$reconfigure" = true ]; then - echo "Reconfiguring CMake from scratch for GPU target: ${gpu_target}" - else - echo "Configuring build with CMake for GPU target: ${gpu_target}" - fi - - docker exec "${name}" bash -c " - cd /workspace || exit 1 - rm -rf /workspace/build - mkdir /workspace/build - cd /workspace/build || exit 1 - cmake .. -GNinja \ - -DGPU_TARGETS=${gpu_target} \ - -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \ - -DBUILD_TESTING=ON 2>&1 | tail -30 - " - fi - - if [ -z "$target" ]; then - echo "Building all configured targets..." - else - echo "Building target: ${target}" - fi - - docker exec "${name}" bash -c " - cd /workspace/build || exit 1 - ninja ${target} 2>&1 - " - - echo "Build complete" +# Configure (delegate to ck-configure in container) +cmd_configure() { + ensure_container_running "${CONTAINER_NAME}" "${SCRIPT_DIR}" + docker exec "${CONTAINER_NAME}" /workspace/script/tools/ck-configure "$@" } -# Run test +# Build (delegate to ck-build in container) +cmd_build() { + ensure_container_running "${CONTAINER_NAME}" "${SCRIPT_DIR}" + docker exec "${CONTAINER_NAME}" /workspace/script/tools/ck-build "$@" +} + +# Test (delegate to ck-test in container) cmd_test() { - local test_name="" - local name="${CONTAINER_NAME}" - local -a test_options=() + ensure_container_running "${CONTAINER_NAME}" "${SCRIPT_DIR}" + docker exec "${CONTAINER_NAME}" /workspace/script/tools/ck-test "$@" +} - while [[ $# -gt 0 ]]; do - case $1 in - --name) - name="$2" - shift 2 - ;; - --gtest_*|--help) - test_options+=("$1") - shift - ;; - *) - if [ -z "$test_name" ]; then - test_name="$1" - else - test_options+=("$1") - fi - shift - ;; - esac - done - - if [ -z "$test_name" ]; then - echo "Error: test_name required" - echo "Usage: ck-docker test [--name container_name] [gtest_options]" +# Execute arbitrary command in container +cmd_exec() { + if [ $# -eq 0 ]; then + error "command required" + echo "Usage: ck-docker exec " return 1 fi - # Check if container is running - if ! container_is_running "${name}"; then - echo "Error: Container '${name}' not running" - echo "Start it with: ck-docker start --name ${name}" - return 1 - fi + ensure_container_running "${CONTAINER_NAME}" "${SCRIPT_DIR}" - if ! docker exec "${name}" test -f "/workspace/build/bin/${test_name}" 2>/dev/null; then - echo "Test executable not found. Building ${test_name}..." - cmd_build "${test_name}" --name "${name}" - fi + local docker_flags=() + [ -t 0 ] && [ -t 1 ] && docker_flags+=("-it") - echo "Running: ${test_name} ${test_options[*]}" - echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" - # Build the command with proper quoting - local cmd="cd /workspace/build && ./bin/${test_name}" - for opt in "${test_options[@]}"; do - cmd="${cmd} $(printf '%q' "$opt")" - done - docker exec "${name}" bash -c "${cmd}" + docker exec "${docker_flags[@]}" "${CONTAINER_NAME}" "$@" } # Shell @@ -220,7 +136,7 @@ cmd_status() { if [ -z "$name" ]; then echo "Composable Kernel Docker Containers:" - echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + echo "---" docker ps -a --filter "ancestor=${docker_image}" \ --format "table {{.Names}}\t{{.Status}}\t{{.CreatedAt}}" || echo "No containers found" else @@ -262,6 +178,10 @@ case "${1:-}" in shift cmd_start "$@" ;; + configure) + shift + cmd_configure "$@" + ;; build) shift cmd_build "$@" @@ -270,6 +190,10 @@ case "${1:-}" in shift cmd_test "$@" ;; + exec) + shift + cmd_exec "$@" + ;; shell) shift cmd_shell "$@" diff --git a/script/tools/ck-rocprof b/script/tools/ck-rocprof new file mode 100755 index 0000000000..2b41a7403c --- /dev/null +++ b/script/tools/ck-rocprof @@ -0,0 +1,806 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK ROCProf Tool - Profile CK applications with rocprof-compute +# Native-only tool. For Docker usage, run via: ck-docker exec ck-rocprof ... + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(find_project_root "${SCRIPT_DIR}" || get_project_root "${SCRIPT_DIR}") + +# ============================================================================ +# rocprof-compute detection +# ============================================================================ + +# Common rocprof-compute binary locations +# Order: user installs first, then system ROCm versions (newest first) +ROCPROF_CANDIDATES=( + "${HOME}/.local/rocprofiler-compute/3.4.0/bin/rocprof-compute" + "/opt/rocm/bin/rocprof-compute" + "/opt/rocm-7.2.0/bin/rocprof-compute" + "/opt/rocm-7.0.1/bin/rocprof-compute" + "/opt/rocm-6.2.0/bin/rocprof-compute" + "/opt/rocm-6.1.0/bin/rocprof-compute" +) + +# Find rocprof-compute binary +find_rocprof_bin() { + # Check CK_ROCPROF_BIN first + if [ -n "${CK_ROCPROF_BIN:-}" ] && [ -f "${CK_ROCPROF_BIN}" ]; then + echo "${CK_ROCPROF_BIN}" + return 0 + fi + + # Check PATH + if command -v rocprof-compute &>/dev/null; then + command -v rocprof-compute + return 0 + fi + + # Check common ROCm locations and user installations + for bin in "${ROCPROF_CANDIDATES[@]}"; do + if [ -f "$bin" ]; then + echo "$bin" + return 0 + fi + done + + return 1 +} + +# Find ROCm requirements file +find_rocm_requirements() { + local rocprof_bin="${1:-$(find_rocprof_bin)}" + if [ -z "$rocprof_bin" ]; then + return 1 + fi + + # Requirements file is typically at ../libexec/rocprofiler-compute/requirements.txt + local rocm_dir + rocm_dir=$(dirname "$(dirname "$rocprof_bin")") + local req_file="${rocm_dir}/libexec/rocprofiler-compute/requirements.txt" + + if [ -f "$req_file" ]; then + echo "$req_file" + return 0 + fi + + return 1 +} + +# ============================================================================ +# Configuration +# ============================================================================ + +ROCPROF_BIN="${CK_ROCPROF_BIN:-$(find_rocprof_bin || echo "")}" +VENV_PATH="${CK_PROFILE_VENV:-${PROJECT_ROOT}/.ck-rocprof-venv}" +WORKLOAD_DIR="${CK_WORKLOAD_DIR:-$(get_build_dir "${PROJECT_ROOT}")/workloads}" +ROCM_REQUIREMENTS="${CK_ROCM_REQUIREMENTS:-$(find_rocm_requirements "${ROCPROF_BIN}" || echo "")}" + +# ============================================================================ +# Helper functions +# ============================================================================ + +# Get file/directory size +get_size() { + local path="$1" + du -sh "$path" 2>/dev/null | cut -f1 +} + +# Get file modification date (cross-platform: Linux and macOS) +get_date() { + local path="$1" + # Try GNU stat first (Linux), fall back to BSD stat (macOS) + if stat --version &>/dev/null 2>&1; then + stat -c %y "$path" 2>/dev/null | cut -d' ' -f1 + else + stat -f %Sm -t %Y-%m-%d "$path" 2>/dev/null + fi +} + +# Help message +show_help() { + cat << EOF +CK ROCProf Tool - Profile CK applications with rocprof-compute + +Usage: ck-rocprof [options] + +Commands: + setup One-time setup: create Python venv and install dependencies + run [args] Profile executable and save results as + analyze [block] Analyze profiling results (default: block 12 - LDS metrics) + compare Compare two profiling runs + list List available profiling runs + clean Remove a profiling run (use --all for all runs) + status Show current configuration and status + help Show this help message + +Examples: + ck-rocprof setup + ck-rocprof run baseline ./bin/tile_example_gemm_universal + ck-rocprof analyze baseline + ck-rocprof analyze baseline 12 + ck-rocprof compare baseline optimized + ck-rocprof list + ck-rocprof clean baseline + ck-rocprof status + +Environment Variables: + CK_GPU_TARGET - Override GPU detection (e.g., gfx950, MI300X) + CK_PROFILE_VENV - Python venv path (default: \$PROJECT/.ck-rocprof-venv) + CK_ROCPROF_BIN - rocprof-compute binary path + CK_ROCM_REQUIREMENTS - Path to rocprofiler-compute requirements.txt + CK_WORKLOAD_DIR - Workload storage directory + +Profiling Blocks (use with 'analyze '): + Block 2: System Speed-of-Light (SOL) + Block 6: Shader Engine (SE) utilization + Block 7: L2 Cache metrics + Block 11: Vector L1D Cache metrics + Block 12: LDS (Local Data Share) - DEFAULT + Block 16: Instruction mix statistics + Block 17: Compute Unit (CU) metrics + +LDS Metrics (Block 12): + - 12.1.3: Bank Conflict Rate (% of peak) + - 12.2.9: Bank Conflicts/Access (conflicts/access) + - 12.2.12: Bank Conflict (cycles per kernel) + - 12.2.17: LDS Data FIFO Full Rate (cycles) + +Notes: + - Workload names must be alphanumeric with hyphens/underscores only + - Profiling skips roofline analysis (--no-roof) for faster execution + - Results stored in workloads// + - For Docker usage, run via: ck-docker exec ck-rocprof ... +EOF +} + +# Get rocprof-compute wrapper path +get_rocprof_wrapper() { + echo "${VENV_PATH}/bin/rocprof-compute" +} + +# Validate workload name to prevent path traversal and shell injection +# Allowed: alphanumeric, hyphens, underscores +validate_workload_name() { + local name="$1" + if [[ ! "$name" =~ ^[a-zA-Z0-9_-]+$ ]]; then + error "Invalid workload name: '$name'" + echo "Names must contain only letters, numbers, hyphens, and underscores" + return 1 + fi + # Prevent reserved names + if [[ "$name" == "." || "$name" == ".." ]]; then + error "Invalid workload name: '$name'" + return 1 + fi + return 0 +} + +# Check if setup is complete +is_setup_complete() { + local wrapper + wrapper=$(get_rocprof_wrapper) + [ -d "${VENV_PATH}" ] && [ -f "${wrapper}" ] +} + +# ============================================================================ +# Source installation +# ============================================================================ + +# rocprofiler-compute source installation location +ROCPROF_SOURCE_VERSION="3.4.0" +ROCPROF_SOURCE_DIR="${HOME}/.local/rocprofiler-compute/${ROCPROF_SOURCE_VERSION}" +ROCPROF_SOURCE_BIN="${ROCPROF_SOURCE_DIR}/bin/rocprof-compute" +ROCPROF_REPO_URL="https://github.com/ROCm/rocprofiler-compute.git" +ROCPROF_REPO_BRANCH="release/rocprofiler-compute-v${ROCPROF_SOURCE_VERSION}" + +# Install rocprofiler-compute from source +install_from_source() { + local install_dir="${ROCPROF_SOURCE_DIR}" + local src_dir="${install_dir}/src" + + info "Installing rocprofiler-compute ${ROCPROF_SOURCE_VERSION} from source..." + echo "Install location: ${install_dir}" + echo "" + + # Ensure uv is available + if ! command -v uv &>/dev/null; then + info "Installing uv package manager via pip..." + if ! python3 -m pip install --user uv; then + error "Failed to install uv package manager" + return 1 + fi + export PATH="${HOME}/.local/bin:${PATH}" + if ! command -v uv &>/dev/null; then + error "uv installed but not found in PATH" + return 1 + fi + fi + + # Create installation directory + mkdir -p "${install_dir}" + + # Clone repository + if [ -d "${src_dir}" ]; then + info "Source already exists, updating..." + git -C "${src_dir}" fetch --quiet + git -C "${src_dir}" checkout --quiet "${ROCPROF_REPO_BRANCH}" 2>/dev/null || \ + git -C "${src_dir}" checkout --quiet "amd-mainline" + else + info "Cloning rocprofiler-compute repository..." + if ! git clone --quiet --branch "${ROCPROF_REPO_BRANCH}" --depth 1 "${ROCPROF_REPO_URL}" "${src_dir}" 2>/dev/null; then + # Fall back to amd-mainline if release branch doesn't exist + info "Release branch not found, using amd-mainline..." + git clone --quiet --branch "amd-mainline" --depth 1 "${ROCPROF_REPO_URL}" "${src_dir}" + fi + fi + + # Create venv for source installation + local venv_dir="${install_dir}/venv" + if [ ! -d "${venv_dir}" ]; then + info "Creating Python virtual environment..." + uv venv "${venv_dir}" + fi + + # Install dependencies from requirements.txt + info "Installing dependencies (this may take a minute)..." + uv pip install --python "${venv_dir}/bin/python" -r "${src_dir}/requirements.txt" --quiet + # Pin pandas to avoid CSV conversion bug + uv pip install --python "${venv_dir}/bin/python" 'pandas<3.0' --quiet + + # Create bin directory and wrapper script + mkdir -p "${install_dir}/bin" + cat > "${ROCPROF_SOURCE_BIN}" << 'WRAPPER_EOF' +#!/bin/bash +# rocprof-compute wrapper for source installation +INSTALL_DIR="$(cd "$(dirname "$0")/.." && pwd)" +SRC_DIR="${INSTALL_DIR}/src/src" +VENV_DIR="${INSTALL_DIR}/venv" + +# Set PYTHONPATH to source directory for module imports +export PYTHONPATH="${SRC_DIR}:${PYTHONPATH}" + +# Execute rocprof-compute script with venv Python +exec "${VENV_DIR}/bin/python3" "${SRC_DIR}/rocprof-compute" "$@" +WRAPPER_EOF + chmod +x "${ROCPROF_SOURCE_BIN}" + + info "rocprofiler-compute installed successfully!" + echo " Binary: ${ROCPROF_SOURCE_BIN}" + echo "" +} + +# ============================================================================ +# Commands +# ============================================================================ + +# Setup: Create Python venv and install rocprof-compute dependencies +cmd_setup() { + echo "Setting up rocprof-compute profiling environment..." + echo "===========================================" + + # Check if rocprof-compute exists, install from source if not + if [ -z "${ROCPROF_BIN}" ] || [ ! -f "${ROCPROF_BIN}" ]; then + warn "rocprof-compute not found in standard locations" + echo "" + echo "Searched locations:" + for bin in "${ROCPROF_CANDIDATES[@]}"; do + echo " - $bin" + done + echo "" + + # Check if we can install from source + if ! command -v git &>/dev/null; then + error "git is required to install from source" + return 1 + fi + if ! command -v python3 &>/dev/null; then + error "python3 is required to install from source" + return 1 + fi + + echo "Installing rocprofiler-compute from source..." + echo "" + if ! install_from_source; then + error "Failed to install rocprofiler-compute from source" + return 1 + fi + + # Update configuration with source installation + ROCPROF_BIN="${ROCPROF_SOURCE_BIN}" + ROCM_REQUIREMENTS="${ROCPROF_SOURCE_DIR}/libexec/rocprofiler-compute/requirements.txt" + fi + info "Using rocprof-compute: ${ROCPROF_BIN}" + + # Check requirements file (only needed for non-source installs that use separate venv) + if [ -z "${ROCM_REQUIREMENTS}" ] || [ ! -f "${ROCM_REQUIREMENTS}" ]; then + # For source installs, requirements are bundled + if [[ "${ROCPROF_BIN}" == "${ROCPROF_SOURCE_BIN}" ]]; then + ROCM_REQUIREMENTS="${ROCPROF_SOURCE_DIR}/libexec/rocprofiler-compute/requirements.txt" + else + error "ROCm requirements file not found" + local expected_path + expected_path="$(dirname "$(dirname "${ROCPROF_BIN}")")/libexec/rocprofiler-compute/requirements.txt" + echo "Expected at: ${expected_path}" + echo "Set CK_ROCM_REQUIREMENTS to override" + return 1 + fi + fi + + # Check GPU access + if [ ! -r /dev/kfd ]; then + warn "No read access to /dev/kfd - GPU profiling may fail" + warn "Add user to video/render group: sudo usermod -a -G video,render \$USER" + fi + + # For source installations, the venv is already set up - just create wrapper + if [[ "${ROCPROF_BIN}" == "${ROCPROF_SOURCE_BIN}" ]]; then + # Source install already has everything set up + local wrapper + wrapper=$(get_rocprof_wrapper) + mkdir -p "$(dirname "${wrapper}")" + + # For source install, wrapper just calls the source binary + cat > "${wrapper}" << WRAPPER_EOF +#!/bin/bash +# rocprof-compute wrapper (using source installation) +exec "${ROCPROF_BIN}" "\$@" +WRAPPER_EOF + chmod +x "${wrapper}" + info "Wrapper created at ${wrapper}" + + # Create marker file for venv directory + mkdir -p "${VENV_PATH}/bin" + touch "${VENV_PATH}/.source-install" + else + # System install - need to set up venv with dependencies + # Install uv if needed + if ! command -v uv &>/dev/null; then + info "Installing uv package manager via pip..." + if ! python3 -m pip install --user uv; then + error "Failed to install uv package manager" + return 1 + fi + export PATH="${HOME}/.local/bin:${PATH}" + if ! command -v uv &>/dev/null; then + error "uv installed but not found in PATH" + echo "Try adding ~/.local/bin to your PATH" + return 1 + fi + fi + + # Create venv + if [ -d "${VENV_PATH}" ]; then + info "Python venv already exists at ${VENV_PATH}" + else + info "Creating Python venv at ${VENV_PATH}..." + uv venv "${VENV_PATH}" + fi + + # Install dependencies + info "Installing dependencies..." + uv pip install --python "${VENV_PATH}/bin/python" -r "${ROCM_REQUIREMENTS}" + uv pip install --python "${VENV_PATH}/bin/python" 'pandas<3.0' + + # Create wrapper script + local wrapper + wrapper=$(get_rocprof_wrapper) + mkdir -p "$(dirname "${wrapper}")" + cat > "${wrapper}" << WRAPPER_EOF +#!/bin/bash +# rocprof-compute wrapper using venv Python +VENV_DIR="\$(cd "\$(dirname "\$0")/.." && pwd)" +exec "\${VENV_DIR}/bin/python" "${ROCPROF_BIN}" "\$@" +WRAPPER_EOF + chmod +x "${wrapper}" + info "Wrapper created at ${wrapper}" + fi + + # Create workload directory + mkdir -p "${WORKLOAD_DIR}" + info "Workload directory: ${WORKLOAD_DIR}" + + echo "" + info "Setup complete! You can now use:" + echo " ck-rocprof run " +} + +# Detect GPU architecture +detect_gpu_arch() { + # Allow override via environment variable + if [ -n "${CK_GPU_TARGET:-}" ]; then + echo "${CK_GPU_TARGET}" + return 0 + fi + + if command -v rocminfo &>/dev/null; then + # Try marketing name first (MI350, MI300X) + local marketing_name + marketing_name=$(rocminfo 2>/dev/null | grep 'Marketing Name:' | grep -oE 'MI[0-9]+[A-Z]*' | head -1) + if [ -n "$marketing_name" ]; then + echo "$marketing_name" + return 0 + fi + + # Fallback to gfx name + local gfx_name + gfx_name=$(rocminfo 2>/dev/null | grep -oE 'gfx[0-9a-z]+' | head -1) + if [ -n "$gfx_name" ]; then + echo "$gfx_name" + return 0 + fi + fi + + # Try existing workload directories + if [ -d "${WORKLOAD_DIR}" ]; then + local first_dir + first_dir=$(find "${WORKLOAD_DIR}" -maxdepth 2 -type d \( -name 'gfx*' -o -name 'MI*' \) 2>/dev/null | head -1) + if [ -n "$first_dir" ]; then + basename "$first_dir" + return 0 + fi + fi + + # Final fallback - use gfx950 consistent with common.sh + echo "gfx950" +} + +# Run profiling +cmd_run() { + # Validate argument count before shifting + if [ $# -lt 2 ]; then + error "name and executable required" + echo "Usage: ck-rocprof run [args]" + return 1 + fi + + local name="$1" + local executable="$2" + shift 2 + local -a exe_args=("$@") + + # Validate workload name (prevents path traversal) + if ! validate_workload_name "$name"; then + return 1 + fi + + # Check setup + if ! is_setup_complete; then + error "Profiling environment not set up" + echo "Run: ck-rocprof setup" + return 1 + fi + + # Check if executable exists + if [ ! -f "$executable" ]; then + error "Executable not found: $executable" + return 1 + fi + + local wrapper + wrapper=$(get_rocprof_wrapper) + local gpu_arch + gpu_arch=$(detect_gpu_arch) + + echo "Profiling: $executable ${exe_args[*]}" + echo "Run name: $name" + echo "GPU arch: $gpu_arch" + echo "===========================================" + + # Build command with proper escaping to prevent shell injection + # --no-roof skips roofline analysis to speed up profiling + local escaped_executable + escaped_executable=$(printf '%q' "$executable") + local escaped_workload_dir + escaped_workload_dir=$(printf '%q' "${WORKLOAD_DIR}/${name}") + + local cmd="${wrapper} profile --no-roof --path ${escaped_workload_dir} --name ${name} -- ${escaped_executable}" + for arg in "${exe_args[@]}"; do + cmd="${cmd} $(printf '%q' "$arg")" + done + + # Run profiling + bash -c "${cmd}" + + echo "" + info "Profiling complete" + echo "Results saved to: ${WORKLOAD_DIR}/${name}/" + echo "" + echo "Analyze with: ck-rocprof analyze ${name}" +} + +# Find workload path for a given run name +find_workload_path() { + local name="$1" + local run_dir="${WORKLOAD_DIR}/${name}" + + if [ ! -d "$run_dir" ]; then + return 1 + fi + + # Check if profiling data exists + if [ -f "${run_dir}/pmc_perf.csv" ]; then + echo "$run_dir" + return 0 + fi + + return 1 +} + +# Analyze profiling results +cmd_analyze() { + local name="$1" + local block="${2:-12}" # Default to block 12 (LDS metrics) + + if [ -z "$name" ]; then + error "name required" + echo "Usage: ck-rocprof analyze [block]" + return 1 + fi + + # Validate workload name (prevents path traversal) + if ! validate_workload_name "$name"; then + return 1 + fi + + # Check setup + if ! is_setup_complete; then + error "Profiling environment not set up" + echo "Run: ck-rocprof setup" + return 1 + fi + + local wrapper + wrapper=$(get_rocprof_wrapper) + local workload_path + workload_path=$(find_workload_path "${name}") + + if [ -z "$workload_path" ]; then + error "Profiling results not found for '${name}'" + echo "" + echo "Available runs:" + cmd_list + return 1 + fi + + echo "Analyzing: ${name} (Block ${block})" + echo "===========================================" + echo "" + + "${wrapper}" analyze --path "${workload_path}" --block "${block}" +} + +# Compare two profiling runs +cmd_compare() { + local name1="$1" + local name2="$2" + + if [ -z "$name1" ] || [ -z "$name2" ]; then + error "two run names required" + echo "Usage: ck-rocprof compare " + return 1 + fi + + # Validate workload names (prevents path traversal) + if ! validate_workload_name "$name1"; then + return 1 + fi + if ! validate_workload_name "$name2"; then + return 1 + fi + + # Check setup + if ! is_setup_complete; then + error "Profiling environment not set up" + echo "Run: ck-rocprof setup" + return 1 + fi + + # Verify both runs exist + local path1 + path1=$(find_workload_path "${name1}") + local path2 + path2=$(find_workload_path "${name2}") + + if [ -z "$path1" ]; then + error "Profiling results not found for '${name1}'" + return 1 + fi + + if [ -z "$path2" ]; then + error "Profiling results not found for '${name2}'" + return 1 + fi + + echo "Comparing profiling runs:" + echo " Baseline: ${name1}" + echo " Optimized: ${name2}" + echo "===========================================" + echo "" + + echo "=== ${name1} - Block 12 (LDS) ===" + cmd_analyze "${name1}" 12 2>/dev/null | head -40 + + echo "" + echo "=== ${name2} - Block 12 (LDS) ===" + cmd_analyze "${name2}" 12 2>/dev/null | head -40 + + echo "" + echo "===========================================" + echo "For detailed analysis, run:" + echo " ck-rocprof analyze ${name1} 12" + echo " ck-rocprof analyze ${name2} 12" +} + +# List available profiling runs +cmd_list() { + if [ ! -d "${WORKLOAD_DIR}" ]; then + echo "No profiling runs found (workload directory doesn't exist)" + return 0 + fi + + local runs + runs=$(find "${WORKLOAD_DIR}" -maxdepth 1 -mindepth 1 -type d -exec basename {} \; 2>/dev/null | sort) + + if [ -z "$runs" ]; then + echo "No profiling runs found in ${WORKLOAD_DIR}" + return 0 + fi + + echo "Available profiling runs:" + echo "===========================================" + + while IFS= read -r run; do + local path + path=$(find_workload_path "$run") + + if [ -n "$path" ]; then + local size + size=$(get_size "$path") + local date + date=$(get_date "$path") + printf " %-25s [%s, %s]\n" "$run" "$size" "$date" + else + printf " %-25s [no data]\n" "$run" + fi + done <<< "$runs" + + echo "" + echo "Analyze with: ck-rocprof analyze " +} + +# Clean (remove) profiling runs +cmd_clean() { + local name="${1:-}" + + if [ -z "$name" ]; then + error "name required (or use --all to remove all runs)" + echo "Usage: ck-rocprof clean " + echo " ck-rocprof clean --all" + return 1 + fi + + if [ "$name" = "--all" ]; then + # Remove all profiling runs + if [ ! -d "${WORKLOAD_DIR}" ]; then + echo "No profiling runs to clean" + return 0 + fi + + echo "This will remove ALL profiling runs in ${WORKLOAD_DIR}" + read -r -p "Are you sure? [y/N] " confirm + if [[ ! "$confirm" =~ ^[Yy]$ ]]; then + echo "Cancelled" + return 0 + fi + + rm -rf "${WORKLOAD_DIR:?}"/* + info "All profiling runs removed" + else + # Validate name + if ! validate_workload_name "$name"; then + return 1 + fi + + local run_dir="${WORKLOAD_DIR}/${name}" + if [ ! -d "$run_dir" ]; then + error "Profiling run not found: ${name}" + return 1 + fi + + rm -rf "${run_dir}" + info "Removed profiling run: ${name}" + fi +} + +# Show status information +cmd_status() { + echo "CK ROCProf Status" + echo "===========================================" + echo "" + + # rocprof-compute binary + if [ -n "${ROCPROF_BIN}" ] && [ -f "${ROCPROF_BIN}" ]; then + echo "rocprof-compute: ${ROCPROF_BIN}" + else + echo "rocprof-compute: not found" + fi + echo "" + + # Paths + echo "Paths:" + echo " Venv: ${VENV_PATH}" + echo " Workloads: ${WORKLOAD_DIR}" + echo "" + + # Setup status + echo "Setup status:" + if is_setup_complete; then + echo " Profiling environment: ready" + else + echo " Profiling environment: not configured (run 'ck-rocprof setup')" + fi + echo "" + + # Workload count + if [ -d "${WORKLOAD_DIR}" ]; then + local count + count=$(find "${WORKLOAD_DIR}" -maxdepth 1 -mindepth 1 -type d 2>/dev/null | wc -l) + echo "Profiling runs: ${count}" + else + echo "Profiling runs: 0" + fi +} + +# ============================================================================ +# Main command dispatcher +# ============================================================================ + +case "${1:-}" in + setup) + cmd_setup + ;; + run) + shift + cmd_run "$@" + ;; + analyze) + shift + cmd_analyze "$@" + ;; + compare) + shift + cmd_compare "$@" + ;; + list) + cmd_list + ;; + clean) + shift + cmd_clean "$@" + ;; + status) + cmd_status + ;; + help|--help|-h) + show_help + ;; + *) + if [ -z "${1:-}" ]; then + show_help + else + echo "Unknown command: ${1}" + echo "" + show_help + exit 1 + fi + ;; +esac diff --git a/script/tools/ck-rocprof.md b/script/tools/ck-rocprof.md new file mode 100644 index 0000000000..0588846097 --- /dev/null +++ b/script/tools/ck-rocprof.md @@ -0,0 +1,167 @@ +# CK ROCProf Tool + +GPU performance profiling for Composable Kernel applications using AMD rocprof-compute. + +**Note:** This is a native-only tool. For Docker usage, run via `ck-docker exec ck-rocprof ...` + +## Quick Start + +```bash +# One-time setup (requires rocprofiler-compute installed) +./script/tools/ck-rocprof setup + +# Profile executable +cd build +../script/tools/ck-rocprof run baseline ./bin/tile_example_gemm_universal + +# Analyze LDS metrics +../script/tools/ck-rocprof analyze baseline + +# Compare optimizations +../script/tools/ck-rocprof run optimized ./bin/tile_example_gemm_universal +../script/tools/ck-rocprof compare baseline optimized +``` + +## Commands + +### `setup` +One-time setup: creates Python venv, installs dependencies, configures rocprof-compute. + +### `run [args]` +Profile executable and save results. + +```bash +# Basic profiling +ck-rocprof run baseline ./bin/gemm_example + +# With arguments +ck-rocprof run large_matrix ./bin/gemm_example -m 8192 -n 8192 -k 4096 + +# Test filtering +ck-rocprof run unit_test ./bin/test_gemm --gtest_filter="*Fp16*" +``` + +### `analyze [block]` +Display profiling metrics (default: Block 12 - LDS). + +```bash +ck-rocprof analyze baseline # LDS metrics +ck-rocprof analyze baseline 2 # L2 Cache +ck-rocprof analyze baseline 7 # Instruction Mix +``` + +### `compare ` +Side-by-side comparison of two runs. + +### `list` +List all profiling runs with size and date. + +### `clean ` / `clean --all` +Remove profiling runs. Use `--all` to remove all runs. + +### `status` +Show current configuration: mode (native/Docker), paths, setup status. + +## Key LDS Metrics (Block 12) + +**Target Values:** +- Bank Conflicts/Access: <0.01 (1% conflict rate) +- Bank Conflict Rate: >90% of peak bandwidth + +**Critical Metrics:** +- **12.2.9 Bank Conflicts/Access**: Direct conflict measure + - Baseline (naive): ~0.04 (4% conflicts) + - Optimized: <0.005 (<0.5% conflicts) +- **12.2.12 Bank Conflict Cycles**: Wasted cycles per kernel +- **12.2.17 LDS Data FIFO Full**: Memory system pressure + +## Optimization Workflow + +```bash +# 1. Baseline +ck-rocprof run baseline ./bin/my_kernel + +# 2. Check conflicts +ck-rocprof analyze baseline +# Look for Bank Conflicts/Access > 0.02 + +# 3. Optimize code (XOR transforms, padding, etc.) +# ... edit source ... + +# 4. Test optimization +ninja my_kernel +ck-rocprof run optimized ./bin/my_kernel + +# 5. Verify improvement +ck-rocprof compare baseline optimized +# Target: 8-10x reduction in conflicts +``` + +## Environment Variables + +- `CK_PROFILE_VENV`: Python venv path (default: `$PROJECT/.ck-rocprof-venv`) +- `CK_ROCPROF_BIN`: rocprof-compute binary path (auto-detected from PATH or /opt/rocm) +- `CK_ROCM_REQUIREMENTS`: Path to rocprofiler-compute requirements.txt (auto-detected) +- `CK_WORKLOAD_DIR`: Results directory (default: `$PROJECT/build/workloads`) +- `CK_GPU_TARGET`: Override GPU detection (e.g., `gfx950`, `MI300X`) + +## Interpreting Results + +**Good Performance:** +``` +Bank Conflicts/Access: <0.01 +Bank Conflict Rate: >90% of peak +LDS Data FIFO Full: Minimal cycles +``` + +**Needs Optimization:** +``` +Bank Conflicts/Access: >0.02 +Bank Conflict Cycles: High MAX values +LDS Data FIFO Full: High memory pressure +``` + +## Troubleshooting + +**"Profiling environment not set up"** +```bash +ck-rocprof setup +``` + +**"rocprof-compute not found"** +```bash +export CK_ROCPROF_BIN=/custom/path/rocprof-compute +ck-rocprof setup +``` + +**"Profiling results not found"** +```bash +ck-rocprof list # Check available runs +rocminfo | grep gfx # Verify GPU arch +export CK_GPU_TARGET=gfx950 # Override if needed +``` + +## Storage Layout + +Results stored in `workloads//`: +- `pmc_perf.csv`: Performance counters (primary data file) +- `perfmon/`: Input metric files +- `out/`: Raw output data from profiler runs +- `log.txt`: Profiling log + +## Technical Details + +- **Setup**: Creates isolated Python venv, installs dependencies +- **Profiling**: Runs `rocprof-compute profile --name -- ` +- **Analysis**: Runs `rocprof-compute analyze --path --block ` +- **GPU Support**: MI300/MI350 series, auto-detects architecture + +## Related Tools + +- `ck-docker`: Container management +- `rocprof-compute`: AMD GPU profiler v2 +- `rocm-smi`: System monitoring + +## License + +Copyright (c) Advanced Micro Devices, Inc. SPDX-License-Identifier: MIT diff --git a/script/tools/ck-test b/script/tools/ck-test index 712f904596..1ee8d0defd 100755 --- a/script/tools/ck-test +++ b/script/tools/ck-test @@ -2,7 +2,8 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# CK Test - Build and test Composable Kernel in Docker +# CK Test - Run Composable Kernel tests +# Environment-agnostic: works natively on ROCm hosts or inside containers set -e set -o pipefail @@ -12,155 +13,219 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" source "${SCRIPT_DIR}/common.sh" # Initialize configuration -PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") -CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") +PROJECT_ROOT=$(find_project_root "${SCRIPT_DIR}" || get_project_root "${SCRIPT_DIR}") +BUILD_DIR=$(get_build_dir "${PROJECT_ROOT}") # Help message show_help() { cat << EOF -CK Test - Build and test Composable Kernel in Docker +CK Test - Run Composable Kernel tests -Usage: ck-test [options] [test_options] +Usage: ck-test [options] [test_name] [-- gtest_options] Options: -h, --help Show this help message - --name Specify container name - --reconfigure Reconfigure CMake before building + --build-dir Build directory (default: ./build) --no-build Skip building, run test directly + --list List available tests + --smoke Run all smoke tests (via CTest -L SMOKE_TEST) + --regression Run all regression tests (via CTest -L REGRESSION_TEST) + --all Run all tests (via CTest) + --filter Shorthand for --gtest_filter= Arguments: - test_name Name of test executable (required) - test_options Additional options passed to test (e.g., --gtest_filter=*) + test_name Name of test executable (optional for --smoke/--regression/--all) + gtest_options Additional options passed to test (after --) Environment: - CK_CONTAINER_NAME - Override default container name - GPU_TARGET - Override GPU target detection (e.g., gfx950, gfx942) + CK_BUILD_DIR - Override build directory Examples: - ck-test test_amdgcn_mma - ck-test test_amdgcn_mma --gtest_filter=*Fp16* - ck-test --name my_container test_amdgcn_mma - ck-test --reconfigure test_amdgcn_mma + ck-test test_amdgcn_mma # Build and run specific test + ck-test test_amdgcn_mma --filter '*Fp16*' # Run with gtest filter + ck-test test_amdgcn_mma -- --gtest_filter=*Fp16* # Explicit gtest options + ck-test --no-build test_amdgcn_mma # Run without rebuilding + ck-test --list # List available tests + ck-test --smoke # Run all smoke tests + ck-test --regression # Run all regression tests + ck-test --all # Run all tests EOF } # Parse arguments test_name="" -reconfigure=false no_build=false -test_options=() +list_tests=false +run_smoke=false +run_regression=false +run_all=false +gtest_filter="" +gtest_options=() +parsing_gtest=false while [[ $# -gt 0 ]]; do + if [ "$parsing_gtest" = true ]; then + gtest_options+=("$1") + shift + continue + fi + case $1 in -h|--help) show_help exit 0 ;; - --name) - CONTAINER_NAME="$2" + --build-dir) + require_arg "$1" "${2:-}" + BUILD_DIR="$2" shift 2 ;; - --reconfigure) - reconfigure=true - shift - ;; --no-build) no_build=true shift ;; - --gtest_*|--help) - test_options+=("$1") + --list) + list_tests=true + shift + ;; + --smoke) + run_smoke=true + shift + ;; + --regression) + run_regression=true + shift + ;; + --all) + run_all=true + shift + ;; + --filter) + require_arg "$1" "${2:-}" + gtest_filter="$2" + shift 2 + ;; + --) + parsing_gtest=true + shift + ;; + --gtest_*) + gtest_options+=("$1") shift ;; *) if [ -z "$test_name" ]; then test_name="$1" else - test_options+=("$1") + gtest_options+=("$1") fi shift ;; esac done -# Validate test name +# Add filter to gtest options if specified +if [ -n "$gtest_filter" ]; then + gtest_options+=("--gtest_filter=${gtest_filter}") +fi + +# Validate mutual exclusivity of test suite options +suite_count=0 +[ "$run_smoke" = true ] && suite_count=$((suite_count + 1)) +[ "$run_regression" = true ] && suite_count=$((suite_count + 1)) +[ "$run_all" = true ] && suite_count=$((suite_count + 1)) + +if [ "$suite_count" -gt 1 ]; then + error "Options --smoke, --regression, and --all are mutually exclusive" + exit 1 +fi + +# Check build is configured +if ! is_build_configured "${BUILD_DIR}"; then + error "Build not configured. Run 'ck-configure' first" + exit 1 +fi + +# Handle --list +if [ "$list_tests" = true ]; then + info "Available tests:" + if [ -d "${BUILD_DIR}/bin" ]; then + ls -1 "${BUILD_DIR}/bin/" 2>/dev/null | grep -E '^test_' | sort || echo " (No test binaries found)" + else + echo " (No bin directory found)" + fi + echo "" + echo "CTest labels:" + cd "${BUILD_DIR}" + ctest -N 2>/dev/null | head -20 || echo " (Run 'ctest -N' for full list)" + exit 0 +fi + +# Handle CTest-based test suites +if [ "$run_smoke" = true ] || [ "$run_regression" = true ] || [ "$run_all" = true ]; then + cd "${BUILD_DIR}" + + ctest_cmd=(ctest --output-on-failure) + + if [ "$run_smoke" = true ]; then + ctest_cmd+=(-L SMOKE_TEST) + info "Running smoke tests..." + elif [ "$run_regression" = true ]; then + ctest_cmd+=(-L REGRESSION_TEST) + info "Running regression tests..." + else + info "Running all tests..." + fi + + "${ctest_cmd[@]}" + exit_code=$? + + echo "" + if [ $exit_code -eq 0 ]; then + info "Tests completed successfully" + else + error "Tests failed with exit code: ${exit_code}" + fi + exit $exit_code +fi + +# Validate test name for individual test runs if [ -z "$test_name" ]; then - echo "Error: test_name required" + error "test_name required (or use --smoke/--regression/--all for test suites)" echo "" show_help exit 1 fi -# Ensure container is running -if ! container_is_running "${CONTAINER_NAME}"; then - echo "Container '${CONTAINER_NAME}' not running. Starting..." - "${SCRIPT_DIR}/ck-start" "${CONTAINER_NAME}" - echo "" -fi - -# Configure CMake if needed or requested -if [ "$reconfigure" = true ] || ! docker exec "${CONTAINER_NAME}" test -f /workspace/build/build.ninja 2>/dev/null; then - echo "Detecting GPU target..." - GPU_TARGET_DETECTED=$(detect_gpu_target "${CONTAINER_NAME}") - - if [ "$reconfigure" = true ]; then - echo "Reconfiguring CMake from scratch for GPU target: ${GPU_TARGET_DETECTED}" - else - echo "Configuring build with CMake for GPU target: ${GPU_TARGET_DETECTED}" - fi - - docker exec "${CONTAINER_NAME}" bash -c " - cd /workspace || exit 1 - rm -rf /workspace/build - mkdir /workspace/build - cd /workspace/build || exit 1 - cmake .. -GNinja \ - -DGPU_TARGETS=${GPU_TARGET_DETECTED} \ - -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \ - -DBUILD_TESTING=ON 2>&1 | tail -30 - " - echo "" -fi - # Build test if needed (unless --no-build is specified) if [ "$no_build" = false ]; then - if ! docker exec "${CONTAINER_NAME}" test -f "/workspace/build/bin/${test_name}" 2>/dev/null; then - echo "Building ${test_name}..." - docker exec "${CONTAINER_NAME}" bash -c " - cd /workspace/build || exit 1 - ninja ${test_name} 2>&1 - " - echo "" - else - echo "Test executable found, rebuilding to ensure latest version..." - docker exec "${CONTAINER_NAME}" bash -c " - cd /workspace/build || exit 1 - ninja ${test_name} 2>&1 - " - echo "" - fi + info "Building ${test_name}..." + "${SCRIPT_DIR}/ck-build" --build-dir "${BUILD_DIR}" "${test_name}" + echo "" +fi + +# Verify test executable exists +test_binary="${BUILD_DIR}/bin/${test_name}" +if [ ! -f "$test_binary" ]; then + error "Test executable not found: ${test_binary}" + echo "Run 'ck-build ${test_name}' first" + exit 1 fi # Run test -echo "Running: ${test_name} ${test_options[*]}" -echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" +echo "Running: ${test_name} ${gtest_options[*]}" +echo "---" -# Build the command with proper quoting -cmd="cd /workspace/build && ./bin/${test_name}" -for opt in "${test_options[@]}"; do - cmd="${cmd} $(printf '%q' "$opt")" -done - -docker exec "${CONTAINER_NAME}" bash -c "${cmd}" +cd "${BUILD_DIR}" +"./bin/${test_name}" "${gtest_options[@]}" exit_code=$? -echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" +echo "---" if [ $exit_code -eq 0 ]; then - echo "Test completed successfully" + info "Test completed successfully" else - echo "Test failed with exit code: ${exit_code}" + error "Test failed with exit code: ${exit_code}" fi exit $exit_code diff --git a/script/tools/common.sh b/script/tools/common.sh index 6683572c0f..e5a39cea67 100644 --- a/script/tools/common.sh +++ b/script/tools/common.sh @@ -74,14 +74,14 @@ container_is_running() { detect_gpu_target() { local container="$1" - # Allow override via GPU_TARGET environment variable - if [ -n "${GPU_TARGET:-}" ]; then - echo "${GPU_TARGET}" + # Allow override via CK_GPU_TARGET environment variable + if [ -n "${CK_GPU_TARGET:-}" ]; then + echo "${CK_GPU_TARGET}" return 0 fi docker exec "${container}" bash -c " - rocminfo 2>/dev/null | grep -oP 'gfx[0-9a-z]+' | head -1 || echo 'gfx950' + rocminfo 2>/dev/null | grep -oE 'gfx[0-9a-z]+' | head -1 || echo 'gfx950' " | tr -d '\r\n' } @@ -95,3 +95,87 @@ ensure_container_running() { "${script_dir}/ck-docker" start "${container}" fi } + +# ============================================================================ +# Native (non-Docker) utilities +# ============================================================================ + +# Output utilities +info() { echo "[info] $*"; } +warn() { echo "[warn] $*" >&2; } +error() { echo "[error] $*" >&2; } + +# Require argument for option (validates $2 exists and is not another flag) +require_arg() { + local option="$1" + local value="$2" + if [ -z "$value" ] || [[ "$value" == -* ]]; then + error "Option $option requires an argument" + exit 1 + fi +} + +# Native GPU detection (no Docker required) +detect_gpu_native() { + # Allow override via CK_GPU_TARGET environment variable + if [ -n "${CK_GPU_TARGET:-}" ]; then + echo "${CK_GPU_TARGET}" + return 0 + fi + + # Try rocminfo if available + if command -v rocminfo &>/dev/null; then + local gpu + gpu=$(rocminfo 2>/dev/null | grep -oE 'gfx[0-9a-z]+' | head -1) + if [ -n "$gpu" ]; then + echo "$gpu" + return 0 + fi + fi + + # Fallback + echo "gfx950" +} + +# Get build directory (respects CK_BUILD_DIR env var) +get_build_dir() { + local project_root="${1:-$(get_project_root "$(dirname "${BASH_SOURCE[0]}")")}" + echo "${CK_BUILD_DIR:-${project_root}/build}" +} + +# Check if build is configured (build.ninja exists) +is_build_configured() { + local build_dir="${1:-$(get_build_dir)}" + [ -f "${build_dir}/build.ninja" ] +} + +# Find project root from any subdirectory (walks up to find .git) +find_project_root() { + local dir="${1:-$(pwd)}" + while [ "$dir" != "/" ]; do + if [ -d "$dir/.git" ]; then + echo "$dir" + return 0 + fi + dir=$(dirname "$dir") + done + return 1 +} + +# List available CMake presets +list_cmake_presets() { + local project_root="${1:-$(find_project_root)}" + local presets_file="${project_root}/CMakePresets.json" + + if [ ! -f "$presets_file" ]; then + return 1 + fi + + # Extract non-hidden preset names + if command -v jq &>/dev/null; then + jq -r '.configurePresets[] | select(.hidden != true) | .name' "$presets_file" 2>/dev/null + else + # Fallback: sed-based extraction (more portable than grep -P) + sed -n 's/.*"name"[[:space:]]*:[[:space:]]*"\([^"]*\)".*/\1/p' "$presets_file" | grep -v '^use-' + fi +} From 6ff073784321a55ee276f38af195532d8d812670 Mon Sep 17 00:00:00 2001 From: MHYangAMD Date: Fri, 30 Jan 2026 10:52:19 +0800 Subject: [PATCH 5/8] Fix redundant cast in model sensitive rmsnorm (#3681) * Fix redundant cast * Fix linting --- .../rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp index de27b15952..f94d220b94 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp @@ -181,12 +181,10 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass if constexpr(std::is_same_v) { - const auto tmp0 = - float_to_bf16(acc[idx] * inv_rms_[i_idx]); - const auto tmp1 = float_to_bf16( - type_convert(tmp0) * gamma_); - const auto rmsn_ = type_convert(tmp1); - rmsn(idx) = rmsn_; + const auto tmp = acc[idx] * inv_rms_[i_idx]; + const auto tmp_bf16 = float_to_bf16(tmp); + const auto rmsn_ = type_convert(tmp_bf16) * gamma_; + rmsn(idx) = rmsn_; } else { From f3d8b7210fb99827bcb1d1bdaf9672b3ae8fb209 Mon Sep 17 00:00:00 2001 From: vivienfanghuagood <89012307+vivienfanghuagood@users.noreply.github.com> Date: Fri, 30 Jan 2026 11:18:20 +0800 Subject: [PATCH 6/8] Extend CK fmha_batch_prefill kernel coverage to head_dim=256 (#3328) Co-authored-by: Po Yen Chen Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py | 1 + 1 file changed, 1 insertion(+) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 9a2d727253..42f686e0c0 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -630,6 +630,7 @@ class KernelComponentFactory: if dtype in ["fp16", "bf16"]: return { 128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + 256 : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], } # fmt: skip elif dtype in ["fp8bf16"]: return { From 565fea26455b8e4f78ac57ed64d6bd12e701a9c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zolt=C3=A1n=20Lakatos?= <153429852+zsotakal@users.noreply.github.com> Date: Fri, 30 Jan 2026 08:22:54 +0100 Subject: [PATCH 7/8] fix undefined behaviour in softmax kernel (#3683) Co-authored-by: root --- include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp index 96e13ac55c..a6fa04a824 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp @@ -26,7 +26,7 @@ __global__ void kernel_softmax(const GridDesc_M_K in_grid_desc_m_k, AccDataType alpha, const InDataType* const __restrict__ p_in_value_global, AccDataType beta, - OutDataType* const __restrict__ p_out_value_global) + OutDataType* p_out_value_global) { GridwiseReduction::Run(in_grid_desc_m_k, out_grid_desc_m_k, @@ -91,7 +91,7 @@ struct GridwiseSoftmax_mk_to_mk AccDataType alpha, const InDataType* const __restrict__ p_in_value_global, AccDataType beta, - OutDataType* const __restrict__ p_out_value_global) + OutDataType* p_out_value_global) { if constexpr(SweepOnce) { From 6a6177a246d6c81932fbb1061ad6a62e90b788a1 Mon Sep 17 00:00:00 2001 From: Erwin Terpstra Date: Fri, 30 Jan 2026 12:40:50 +0100 Subject: [PATCH 8/8] [CK_Tile] Support for a4w4 (fp4) in block scale gemm AB quant (#3603) * chore: split block scale example instances in more separate files to speed up compile times * wip: fp4 scaffolding for abquant * feat: add fp4 decoding-while-loading to abquant pipeline * feat: add support for fp4 CPU verification in abquant * chore: add time tracking to reference calculation * feat: add a4w4 test for blockscale gemm * feat: optimize reference calculation by preconverting values to AccType * feat: add fp4 to fp8 look-up table * fix: reference to wrong ComputeDataType field in QuantProblem * feat: type utilities for determining MFMA compute types * feat: packed fp4 for abquant weight preshuffle * feat: add separate tests for a4w4 base case, padding and preshuffleB * fix: fp4 conversion on gfx950 attempting to use non-supported method * fix: test case was using quant group sizes which don't work on gfx950 due to larger mfma tile size * chore: add fp4 preshuffleb mode to block scale example * chore: sanity check for packed types being 1 byte * chore: clarify tensor dimension indices with constants * chore: replace traits check with specialized check for packed types * style: some minor refactoring and cleanup * fix: correct conversion table for FNUZ fp8 * chore: add fp4 instances to main abquant instances again * chore: use same initialization branch for int4 and fp4 * chore: add missing initialization for fp4 in block scale gemm example --------- Co-authored-by: Thomas Ning --- .../gemm_abquant_quantgrouped.cpp | 30 +++++ .../38_block_scale_gemm/gemm_quant.cpp | 2 +- .../run_gemm_quant_example.inc | 124 +++++++++++------- include/ck_tile/core.hpp | 1 + .../core/arch/amd_buffer_addressing.hpp | 3 +- .../arch/amd_buffer_addressing_builtins.hpp | 2 +- include/ck_tile/core/numeric/pk_fp4.hpp | 88 ++++++++++++- include/ck_tile/core/numeric/pk_int4.hpp | 11 ++ include/ck_tile/core/numeric/vector_type.hpp | 1 + .../core/utility/mixed_prec_compute_type.hpp | 54 ++++++++ include/ck_tile/core/utility/type_traits.hpp | 17 +++ .../ck_tile/host/reference/reference_gemm.hpp | 82 ++++++------ .../ops/common/load_interleaved_pk_type.hpp | 19 ++- .../unary_element_wise_operation.hpp | 23 ++++ ..._pipeline_agmem_bgmem_creg_base_policy.hpp | 20 ++- ...versal_gemm_ar_aquant_flatbr_bquant_cr.hpp | 15 ++- ..._universal_gemm_as_aquant_bs_bquant_cr.hpp | 19 ++- .../gemm_abquant_pipeline_ag_bg_cr_policy.hpp | 5 +- .../gemm_abquant_pipeline_ag_bg_cr_v3.hpp | 36 ++--- .../pipeline/gemm_quant_pipeline_problem.hpp | 30 +++-- ..._abquant_pipeline_ag_bg_cr_base_policy.hpp | 9 +- .../gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp | 51 ++++--- test/ck_tile/gemm_block_scale/CMakeLists.txt | 16 +++ .../test_gemm_quant_abquant_a4w4_base.cpp | 44 +++++++ .../test_gemm_quant_abquant_a4w4_padding.cpp | 65 +++++++++ ...est_gemm_quant_abquant_a4w4_preshuffle.cpp | 44 +++++++ .../gemm_block_scale/test_gemm_quant_base.hpp | 2 +- .../test_gemm_quant_fixtures.hpp | 4 +- 28 files changed, 642 insertions(+), 175 deletions(-) create mode 100644 include/ck_tile/core/utility/mixed_prec_compute_type.hpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_base.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_padding.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_preshuffle.cpp diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp index a7cb88079b..e4e0503b5a 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp @@ -164,5 +164,35 @@ static auto _ = []() { BQuantGroupSize, ck_tile::QuantType::ABQuantGrouped>(arg_parser); }; + lut[hash_multiple_strings( + {"fp4", "abquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp4", "abquant", "preshuffleb", "non-preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; return 0; }(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp index 1fbe4d7b47..cc4302a992 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -32,7 +32,7 @@ auto create_args(int argc, char* argv[]) .insert("prec", "fp8", "Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, " - "or bf8i4; for ABQuant: fp8, bf8") + "or bf8i4; for ABQuant: fp8, bf8, fp4") .insert("warmup", "50", "Number of iterations before benchmarking the kernel") .insert("repeat", "1000", "Number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 665c7828ad..540d5725dd 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -9,6 +9,7 @@ #include #include #include +#include #include "ck_tile/core/config.hpp" #include "ck_tile/ops/common/utils.hpp" @@ -35,10 +36,9 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str static_assert(std::is_same_v); constexpr bool transpose_c = GemmConfig::TransposeC; // QuantMode == ck_tile::QuantType::ABQuantGrouped; - using ComputeDataType = std::conditional_t; + + // Use automatically determined compute type from + using ComputeDataType = void; using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -80,7 +80,10 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str std::conditional_t< QuantMode == ck_tile::QuantType::AQuantGrouped, ck_tile::BaseGemmPipelineAgBgCrMem, - ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2>>>; + std::conditional_t< + QuantMode == ck_tile::QuantType::ABQuantGrouped, + ck_tile::BaseGemmPipelineAgBgCrMem, + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2>>>>; const ck_tile::index_t K_split = ck_tile::integer_least_multiple(args.K, GemmConfig::K_Tile); const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); @@ -182,30 +185,28 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str printf( "TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, BQuantGroupSize::kN); } - using GemmEpilogue = ck_tile::CShuffleEpilogue, - typename TypeConfig::ADataType, - typename TypeConfig::BDataType>, - ck_tile::tuple<>, - typename TypeConfig::AccDataType, - typename TypeConfig::CDataType, - ck_tile::tuple<>, - CLayout, - CDEElementWise, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - transpose_c, - 1, - false, - 1, - TiledPermuteN>>; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + typename TypeConfig::AccDataType, + typename TypeConfig::CDataType, + ck_tile::tuple<>, + CLayout, + CDEElementWise, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + transpose_c, + 1, + false, + 1, + TiledPermuteN>>; using Kernel = ck_tile::QuantGemmKernel; @@ -557,8 +558,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, { if constexpr(std::is_same_v) { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - b_k_n); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( *bq_tensor_ptr); } @@ -594,18 +594,26 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v || + std::is_same_v) { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - a_m_k); - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - b_k_n); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); } else { ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); + } + + if constexpr(std::is_same_v || + std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); + } + else + { ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); } + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( *aq_tensor_ptr); ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( @@ -723,12 +731,11 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v || + std::is_same_v) { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - a_m_k); - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - b_k_n); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); } else { @@ -804,12 +811,11 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v || + std::is_same_v) { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - a_m_k); - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - b_k_n); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); } else { @@ -984,10 +990,14 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, if(arg_parser.get_int("v") == 1) { + std::cout << "Performing CPU verification..." << std::endl; + ck_tile::HostTensor c_m_n_host_ref( ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); c_m_n_host_ref.SetZero(); + // Track start time for reference operation + auto start_reference_tick = std::chrono::high_resolution_clock::now(); if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) { ck_tile::reference_gemm_quant( @@ -1061,6 +1074,9 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, rtol_atol.at(ck_tile::number<0>{}), rtol_atol.at(ck_tile::number<1>{})); + // "Stop" our timer + auto verification_finished_tick = std::chrono::high_resolution_clock::now(); + if(!pass) { std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) @@ -1068,6 +1084,21 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, << std::endl; } std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl; + + // Calculate and display reference timing + using DurationType = std::chrono::duration; + double reference_sec = std::chrono::duration_cast(verification_finished_tick - + start_reference_tick) + .count(); + double verification_sec = std::chrono::duration_cast( + verification_finished_tick - start_verification_tick) + .count(); + float reference_msec = static_cast(reference_sec * 1e3); + float verification_msec = static_cast(verification_sec * 1e3); + + std::cout << std::fixed << std::setprecision(1) << "CPU reference GEMM took " + << reference_msec << "ms, verification took " << verification_msec << "ms." + << std::endl; } else if(arg_parser.get_int("v") == 2) { @@ -1098,6 +1129,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser) } if constexpr(std::is_same_v || + std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v) diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index f3596df9bd..438e44f5f1 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -91,6 +91,7 @@ #include "ck_tile/core/utility/ignore.hpp" #include "ck_tile/core/utility/literals.hpp" #include "ck_tile/core/utility/magic_div.hpp" +#include "ck_tile/core/utility/mixed_prec_compute_type.hpp" #include "ck_tile/core/utility/persistent_async_input_scheduler.hpp" #include "ck_tile/core/utility/philox_rand.hpp" #include "ck_tile/core/utility/print.hpp" diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 7af2f558ad..8f9dd30bda 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1544,7 +1544,8 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)), "wrong! not implemented"); using rtn_type = thread_buffer; diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 9f9770df1b..42886b8ced 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -1414,7 +1414,7 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) || (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16))), + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32))), "wrong! not implemented"); using rtn_type = thread_buffer; diff --git a/include/ck_tile/core/numeric/pk_fp4.hpp b/include/ck_tile/core/numeric/pk_fp4.hpp index cc23ce71a8..d74db6b336 100644 --- a/include/ck_tile/core/numeric/pk_fp4.hpp +++ b/include/ck_tile/core/numeric/pk_fp4.hpp @@ -6,6 +6,7 @@ #include #include "ck_tile/core/config.hpp" #include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/mxfp_convert.hpp" #if defined(__gfx950__) @@ -23,6 +24,12 @@ using fp32x2_t = float __attribute__((ext_vector_type(2))); using fp16x2_t = _Float16 __attribute__((ext_vector_type(2))); using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2))); +#if CK_TILE_USE_CUSTOM_DATA_TYPE +using fp8x2_t = fp8_raw_t __attribute__((ext_vector_type(2))); +#else +using fp8x2_t = fp8_t __attribute__((ext_vector_type(2))); +#endif + // Helpers: constexpr-safe access to elements of ext_vector_type(2) // Some compilers don't allow operator[] in constant expressions for vector types. // We use bit_cast to a trivially copyable representation to extract lanes. @@ -98,6 +105,8 @@ struct pk_float4_e2m1_t CK_TILE_HOST_DEVICE constexpr fp16x2_t to_fp16x2(float scale = 1.f) const; CK_TILE_HOST_DEVICE constexpr bf16_t to_bf16(float scale = 1.f) const; CK_TILE_HOST_DEVICE constexpr bf16x2_t to_bf16x2(float scale = 1.f) const; + CK_TILE_HOST_DEVICE constexpr fp8_t to_fp8(float scale = 1.f) const; + CK_TILE_HOST_DEVICE constexpr fp8x2_t to_fp8x2(float scale = 1.f) const; CK_TILE_HOST_DEVICE constexpr operator float() const { return to_float(); } CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const { return to_fp32x2(); } @@ -105,6 +114,8 @@ struct pk_float4_e2m1_t CK_TILE_HOST_DEVICE constexpr operator fp16x2_t() const { return to_fp16x2(); } CK_TILE_HOST_DEVICE constexpr operator bf16_t() const { return to_bf16(); } CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const { return to_bf16x2(); } + CK_TILE_HOST_DEVICE constexpr operator fp8_t() const { return to_fp8(); } + CK_TILE_HOST_DEVICE constexpr operator fp8x2_t() const { return to_fp8x2(); } template CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t unpack(number) const @@ -145,6 +156,49 @@ struct pk_float4_e2m1_t bit_cast(static_cast(0xC400)), // -4 bit_cast(static_cast(0xC600)) // -6 }; + +#if CK_TILE_USE_OCP_FP8 + // FP8 EM4E3 (OCP) representation + static constexpr fp8_t e2m1_to_fp8_table[16] = { + fp8_t(static_cast(0x00)), // 0 + fp8_t(static_cast(0x30)), // 0.5 + fp8_t(static_cast(0x38)), // 1 + fp8_t(static_cast(0x3C)), // 1.5 + fp8_t(static_cast(0x40)), // 2 + fp8_t(static_cast(0x44)), // 3 + fp8_t(static_cast(0x48)), // 4 + fp8_t(static_cast(0x4C)), // 6 + fp8_t(static_cast(0x00)), // -0 + fp8_t(static_cast(0xB0)), // -0.5 + fp8_t(static_cast(0xB8)), // -1 + fp8_t(static_cast(0xBC)), // -1.5 + fp8_t(static_cast(0xC0)), // -2 + fp8_t(static_cast(0xC4)), // -3 + fp8_t(static_cast(0xC8)), // -4 + fp8_t(static_cast(0xCC)) // -6 + }; +#else // CK_TILE_USE_FNUZ_FP8 + // FP8 E4M3 FNUZ + static constexpr fp8_t e2m1_to_fp8_table[16] = { + fp8_t(static_cast(0x00)), // 0 + fp8_t(static_cast(0x38)), // 0.5 + fp8_t(static_cast(0x40)), // 1 + fp8_t(static_cast(0x44)), // 1.5 + fp8_t(static_cast(0x48)), // 2 + fp8_t(static_cast(0x4C)), // 3 + fp8_t(static_cast(0x50)), // 4 + fp8_t(static_cast(0x54)), // 6 + fp8_t(static_cast(0x00)), // -0 + fp8_t(static_cast(0xB8)), // -0.5 + fp8_t(static_cast(0xC0)), // -1 + fp8_t(static_cast(0xC4)), // -1.5 + fp8_t(static_cast(0xC4)), // -2 + fp8_t(static_cast(0xCC)), // -3 + fp8_t(static_cast(0xD0)), // -4 + fp8_t(static_cast(0xD4)) // -6 + }; +#endif + #endif }; @@ -408,6 +462,27 @@ CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const type_convert(convert_to_float(_unpack(number<1>{}), scale))}; #endif } +CK_TILE_HOST_DEVICE constexpr fp8_t pk_fp4_t::to_fp8(float scale) const +{ + // NOTE: No specialized fp4 to fp8 instructions are available. Unsure whether fp4 to fp16 to fp8 + // would be better than the naive implementation below + // #if CK_TILE_FP4_CVT_DEVICE + // return impl::_from_f4(data, scale); + // #else + return fp8_t{type_convert(convert_to_float(_unpack(number<0>{}), scale))}; + // #endif +} +CK_TILE_HOST_DEVICE constexpr fp8x2_t pk_fp4_t::to_fp8x2(float scale) const +{ + // NOTE: No specialized fp4 to fp8 instructions are available. Unsure whether fp4 to fp16 to fp8 + // would be better than the naive implementation below + // #if CK_TILE_FP4_CVT_DEVICE + // return impl::_from_f4(data, scale); + // #else + return fp8x2_t{type_convert(convert_to_float(_unpack(number<0>{}), scale)), + type_convert(convert_to_float(_unpack(number<1>{}), scale))}; + // #endif +} #else CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const { @@ -415,7 +490,8 @@ CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const } CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const { - return fp32x2_t{e2m1_to_fp32_table[_unpack(number<0>{})] * scale, e2m1_to_fp32_table[_unpack(number<1>{}] * scale}; + return fp32x2_t{e2m1_to_fp32_table[_unpack(number<0>{})] * scale, + e2m1_to_fp32_table[_unpack(number<1>{})] * scale}; } CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const { @@ -428,6 +504,16 @@ CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const type_convert(type_convert(e2m1_to_fp16_table[_unpack(number<1>{})]) * scale)}; } +CK_TILE_HOST_DEVICE constexpr fp8_t pk_fp4_t::to_fp8(float scale) const +{ + return type_convert(e2m1_to_fp8_table[_unpack(number<0>{})]) * scale; +} +CK_TILE_HOST_DEVICE constexpr fp8x2_t pk_fp4_t::to_fp8x2(float scale) const +{ + return fp8x2_t{ + type_convert(type_convert(e2m1_to_fp8_table[_unpack(number<0>{})]) * scale), + type_convert(type_convert(e2m1_to_fp8_table[_unpack(number<1>{})]) * scale)}; +} #endif } // namespace ck_tile diff --git a/include/ck_tile/core/numeric/pk_int4.hpp b/include/ck_tile/core/numeric/pk_int4.hpp index d5df4d1917..9eb62a6ec4 100644 --- a/include/ck_tile/core/numeric/pk_int4.hpp +++ b/include/ck_tile/core/numeric/pk_int4.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/numeric/numeric.hpp" +#include "ck_tile/core/numeric/pk_fp4.hpp" #include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/random.hpp" #include @@ -23,6 +24,11 @@ struct pk_int4_t type data; CK_TILE_HOST_DEVICE constexpr pk_int4_t() : data{type{}} {} CK_TILE_HOST_DEVICE constexpr pk_int4_t(type init) : data{init} {} + + // NOTE: added for interface compatibility with pk_fp4_t + // Other data types could be added for greater similarity + CK_TILE_HOST_DEVICE constexpr fp32x2_t to_fp32x2() const; + CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const { return to_fp32x2(); } }; // limits @@ -186,4 +192,9 @@ CK_TILE_HOST_DEVICE int8x2_t pk_int4_t_to_int8x2_t(const pk_int4_t& x) return res; } +CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_int4_t::to_fp32x2() const +{ + return pk_int4_t_to_fp32x2_t(*this); +} + } // namespace ck_tile diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index 90ddc2a56e..def054f415 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -11,6 +11,7 @@ #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/bfloat16.hpp" #include "ck_tile/core/numeric/pk_int4.hpp" +#include "ck_tile/core/numeric/pk_fp4.hpp" #include "ck_tile/core/numeric/e8m0.hpp" #include "ck_tile/core/utility/type_traits.hpp" diff --git a/include/ck_tile/core/utility/mixed_prec_compute_type.hpp b/include/ck_tile/core/utility/mixed_prec_compute_type.hpp new file mode 100644 index 0000000000..021763c108 --- /dev/null +++ b/include/ck_tile/core/utility/mixed_prec_compute_type.hpp @@ -0,0 +1,54 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/numeric.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +#include + +namespace ck_tile { + +namespace detail { + +// Helper method to automatically determine compute type +// Selects the largest type of the two. If both of them are packed data types, defaults to fp8. +template +struct auto_compute_type +{ + using LargestInputType = largest_type_t; + + // Sanity check: there are no packed types larger than 1 byte yet, but if we add them + // this logic should change + static_assert(!is_packed_type_v || sizeof(LargestInputType) == sizeof(fp8_t)); + + using type = std::conditional_t, fp8_t, LargestInputType>; +}; + +// Helper method to determine compute type, defaulting an explicitly passed-in compute type +template +struct mixed_prec_compute_type +{ + using type = std::conditional_t, + typename auto_compute_type::type, + ComputeDataType>; +}; + +} // namespace detail + +template +using mixed_prec_compute_type_t = + typename detail::mixed_prec_compute_type::type; + +// Helper method to determine compute type, defaulting to input data type +// If "ThisDataType" is packed (4-bit), will default to "OtherDataType". If both are packed, +// ComputeDataType is used. +template +using mixed_prec_compute_type_from_input_t = std::conditional_t< + is_packed_type_v, + std::conditional_t, ComputeDataType, OtherDataType>, + ThisDataType>; + +} // namespace ck_tile diff --git a/include/ck_tile/core/utility/type_traits.hpp b/include/ck_tile/core/utility/type_traits.hpp index f07e25e19c..c11d180839 100644 --- a/include/ck_tile/core/utility/type_traits.hpp +++ b/include/ck_tile/core/utility/type_traits.hpp @@ -4,6 +4,8 @@ #pragma once #include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/numeric.hpp" + #include #include #include @@ -187,4 +189,19 @@ template using tuple_element_or_default_t = typename tuple_element_or_default::type; +// Helper struct to determine if a type is packed (more than 1 element per byte) +template +struct is_packed_type +{ + static constexpr bool value = numeric_traits::PackedSize > 1; +}; + +template +static constexpr bool is_packed_type_v = is_packed_type::value; + +// Helper definition to take the largest sizes type +template +using largest_type_t = + std::conditional_t= sizeof(BDataType), ADataType, BDataType>; + } // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index 9ad5af8264..7830150b63 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -137,47 +137,55 @@ CK_TILE_HOST void reference_gemm_abquant(const HostTensor& a_m_k, const BElementOp& b_element_op = {}, const ACCElementOp& acc_element_op = {}) { - const std::size_t M = a_m_k.get_length(0); - const std::size_t N = b_k_n.get_length(1); - const std::size_t K = a_m_k.get_length(1); + constexpr auto A_TENSOR_M_DIM = 0; + constexpr auto A_TENSOR_K_DIM = 1; + constexpr auto B_TENSOR_K_DIM = 0; + constexpr auto B_TENSOR_N_DIM = 1; + + const std::size_t M = a_m_k.get_length(A_TENSOR_M_DIM); + const std::size_t N = b_k_n.get_length(B_TENSOR_N_DIM); + const std::size_t K = a_m_k.get_length(A_TENSOR_K_DIM); + + // Pre-convert A/B tensors to AccData type + // This prevents doing slow reconversions for each row/column + HostTensor a_acc(a_m_k.mDesc); + HostTensor b_acc(b_k_n.mDesc); + + a_acc.ForEach([&](auto& self, auto index) { + if constexpr(std::is_same_v || std::is_same_v) + { + const ADataType pk_val = a_element_op(a_m_k(index)); + const fp32x2_t fp32_val = pk_val.to_fp32x2(); + self(index) = (index[A_TENSOR_K_DIM] & 1) ? fp32_val.hi : fp32_val.lo; + } + else + { + self(index) = ck_tile::type_convert(a_element_op(a_m_k(index))); + } + }); + + b_acc.ForEach([&](auto& self, auto index) { + if constexpr(std::is_same_v || std::is_same_v) + { + const BDataType pk_val = b_element_op(b_k_n(index)); + const fp32x2_t fp32_val = pk_val.to_fp32x2(); + self(index) = (index[B_TENSOR_K_DIM] & 1) ? fp32_val.hi : fp32_val.lo; + } + else if constexpr(std::is_same_v) + { + self(index) = fp8_to_float_raw(b_element_op(b_k_n(index))); + } + else + { + self(index) = ck_tile::type_convert(b_element_op(b_k_n(index))); + } + }); auto f_mn = [&](auto m, auto n) { AccDataType v_acc = 0; constexpr std::size_t kGroupK = BQuantGroupSize::kK; - // ---- A loader: dequant A(m,k) into AccDataType ---- - auto load_a = [&](std::size_t k) -> AccDataType { - if constexpr(std::is_same_v) - { - const pk_int4_t pk_val = a_element_op(a_m_k(m, k)); - const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val); - return (k & 1) ? fp32_val.hi : fp32_val.lo; - } - else - { - return ck_tile::type_convert(a_element_op(a_m_k(m, k))); - } - }; - - // ---- B loader: dequant B(k,n) into AccDataType ---- - auto load_b = [&](std::size_t k) -> AccDataType { - if constexpr(std::is_same_v) - { - const pk_int4_t pk_val = b_element_op(b_k_n(k, n)); - const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val); - return (k & 1) ? fp32_val.hi : fp32_val.lo; - } - else if constexpr(std::is_same_v) - { - return fp8_to_float_raw(b_element_op(b_k_n(k, n))); - } - else - { - return ck_tile::type_convert(b_element_op(b_k_n(k, n))); - } - }; - // ---- a scale loader for a given K-group index ---- auto load_scale_a = [&](ck_tile::index_t k_group) -> float { const ck_tile::index_t outer_dim = m / AQuantGroupSize::kM; @@ -224,8 +232,8 @@ CK_TILE_HOST void reference_gemm_abquant(const HostTensor& a_m_k, // unscaled accumulation within this K-group for(std::size_t k = k_begin; k < k_end; ++k) { - const AccDataType v_a = load_a(k); - const AccDataType v_b = load_b(k); + const AccDataType v_a = a_acc(m, k); + const AccDataType v_b = b_acc(k, n); v_block_acc += v_a * v_b; } diff --git a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp b/include/ck_tile/ops/common/load_interleaved_pk_type.hpp index 10c2a1e4df..3f1a3b8f1c 100644 --- a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp +++ b/include/ck_tile/ops/common/load_interleaved_pk_type.hpp @@ -4,11 +4,12 @@ #pragma once #include "ck_tile/core/config.hpp" +#include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" namespace ck_tile { -template +template struct InterleavedPKTypeLoader { template @@ -21,10 +22,15 @@ struct InterleavedPKTypeLoader constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; const auto in_dstr_tensors = load_tile(warp_window); - using DstVectorType = DstDataType __attribute__((ext_vector_type(UnaryOpSize))); + // NOTE: we rely on types packing neatly here + using RawSrcType = typename SrcDataType::type; + constexpr auto PackedSize = numeric_traits::PackedSize; + + using SrcVectorType = ext_vector_t; + using DstVectorType = ext_vector_t; static_for<0, thread_buffer_size, 1>{}([&](auto i) { elementwise_op(warp_tile.get_thread_buffer().template get_as()(i), - in_dstr_tensors.get_thread_buffer().template get_as()[i]); + in_dstr_tensors.get_thread_buffer().template get_as()[i]); }); } }; @@ -37,10 +43,11 @@ template CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src) { - if constexpr(std::is_same_v) + if constexpr(is_packed_type_v) { - static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t"); - InterleavedPKTypeLoader::load_interleaved_pk_type(dst, src); + static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t or pk_fp4_t"); + InterleavedPKTypeLoader::load_interleaved_pk_type( + dst, src); } else if constexpr(LoadTranspose) { diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index ca9af0a7a8..3f58eceb33 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -397,6 +397,29 @@ struct PassThroughPack8 y.hi = i4_to_bf8x4(bit_cast(x) >> 8); #endif } + + CK_TILE_HOST_DEVICE constexpr void operator()(fp8x8_t& y, const pk_fp4x4_t& x) const + { + pk_fp4_t f0 = pk_fp4_t{x[0]}; + pk_fp4_t f1 = pk_fp4_t{x[1]}; + pk_fp4_t f2 = pk_fp4_t{x[2]}; + pk_fp4_t f3 = pk_fp4_t{x[3]}; + + fp8x2_t x0 = f0.to_fp8x2(); + fp8x2_t x1 = f1.to_fp8x2(); + fp8x2_t x2 = f2.to_fp8x2(); + fp8x2_t x3 = f3.to_fp8x2(); + + y[0] = x0[0]; + y[1] = x0[1]; + y[2] = x1[0]; + y[3] = x1[1]; + y[4] = x2[0]; + y[5] = x2[1]; + y[6] = x3[0]; + y[7] = x3[1]; + } + constexpr const static bool is_pack8_invocable = true; }; diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp index 1784436f87..0044b412ec 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/core/numeric/numeric.hpp" #include "ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" @@ -255,17 +256,26 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy { using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; - using BTypeToUse = - std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>; + + // Determine compute types to use + // This logic defaults to A/B DataType, but if one of them is packed falls back to the other + // If both are packed, it falls back to the explicitly defined ComputeDataType in the + // problem It might be a good idea to use ComputeDataType anyway, but that would break how + // this behaviour used to work + using ATypeToUse = mixed_prec_compute_type_from_input_t; + using BTypeToUse = mixed_prec_compute_type_from_input_t; + constexpr index_t WaveSize = get_warp_size(); constexpr index_t KLane = WarpTile::at(I2) * WarpTile::at(I0) / WaveSize; using BDataType = typename Problem::BDataType; constexpr index_t KLaneBytes = KLane / numeric_traits::PackedSize * sizeof(BDataType); constexpr auto NumAccess = static_cast(max(1, KLaneBytes / 16)); - using WarpGemm = WarpGemmDispatcher f32 static_assert( (std::is_same_v || std::is_same_v || - std::is_same_v) && + std::is_same_v || + std::is_same_v) && (std::is_same_v || std::is_same_v || - std::is_same_v) && + std::is_same_v || + std::is_same_v) && (std::is_same_v || std::is_same_v || std::is_same_v) && (std::is_same_v || std::is_same_v || @@ -189,7 +191,8 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg typename BFlatBlockTensor, typename AQBlockTensor, typename BQBlockTensor, - typename ABlockWindow> + typename ABlockWindow, + index_t UnaryOpSize = 8> CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, ABlockTensor& a_warp_tensor, BFlatBlockTensor& b_warp_tensor, @@ -249,8 +252,10 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg { constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows(number{})(number{})); + + load_int4_tile( + a_warp_tensor(number{}), + a_warp_windows(number{})(number{})); } // barrier // Could be deleted diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp index 2d28b813bf..d79bd31489 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp @@ -108,9 +108,11 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase // 4. i4, bf8, (fp8/fp32) -> f32 static_assert( (std::is_same_v || std::is_same_v || - std::is_same_v) && + std::is_same_v || + std::is_same_v) && (std::is_same_v || std::is_same_v || - std::is_same_v) && + std::is_same_v || + std::is_same_v) && (std::is_same_v || std::is_same_v || std::is_same_v) && (std::is_same_v || std::is_same_v || @@ -135,12 +137,9 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase using ComputeDataType = remove_cvref_t; using CDataType = remove_cvref_t; - // BDataType gets converted from PkInt4 during loading - using OverrideBDataType = std::conditional_t< - std::is_same_v && - std::is_same_v, - ADataType, - BDataType>; + // A/B DataType get converted from PkInt4/PkFp4 during loading + using OverrideADataType = ComputeDataType; + using OverrideBDataType = ComputeDataType; using Base = BlockGemmQuantBase; using WarpGemm = remove_cvref_t; @@ -268,9 +267,9 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase bool_constant = {}, bool_constant = {}) { - load_int4_tile( + // If A/B datatype were pkint4/pkfp4 it would be converted prior to storing in LDS + load_int4_tile( a_warp_tile_, a_block_window); - // If B datatype were pkint4 it would be converted prior to storing in LDS load_int4_tile( b_warp_tile_, b_block_window); } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_policy.hpp index 095275e60b..b636bfa4b7 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_policy.hpp @@ -10,9 +10,10 @@ namespace ck_tile { -struct GemmABQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgCrPolicy +struct GemmABQuantPipelineAgBgCrDefaultPolicy + : public UniversalGemmBasePolicy { - using Base = UniversalGemmPipelineAgBgCrPolicy; + using Base = UniversalGemmBasePolicy; using Base::I0; using Base::I1; using Base::I2; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp index 5902dd0c4f..cfd12313e8 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp @@ -34,9 +34,6 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3; using AQuantGroupSize = remove_cvref_t; using BQuantGroupSize = remove_cvref_t; - // BDataType gets converted from PkInt4 during loading - using OverrideBDataType = - std::conditional_t, ADataType, BDataType>; static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!"); static_assert(AQuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!"); @@ -67,6 +64,10 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3())>; + // A/B DataType gets converted from PkInt4/PkFp4 during loading + using OverrideADataType = BlockGemm::OverrideADataType; + using OverrideBDataType = BlockGemm::OverrideBDataType; + static constexpr index_t BlockSize = Problem::kBlockSize; static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t NPerBlock = BlockGemmShape::kN; @@ -281,9 +282,9 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(p_smem); + Base::template GetABLdsTensorViews(p_smem); constexpr auto a_lds_load_tile_distr = make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); @@ -303,9 +304,9 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(ABlockTileDistr{})); + decltype(make_static_distributed_tensor(ABlockTileDistr{})); using BBlockTile = - decltype(make_static_distributed_tensor(BBlockTileDistr{})); + decltype(make_static_distributed_tensor(BBlockTileDistr{})); using AQBlockTile = decltype(make_static_distributed_tensor(AQBlockTileDistr{})); using BQBlockTile = @@ -361,7 +362,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); @@ -373,7 +374,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); @@ -409,7 +410,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + // Note: ABDataType PkInt4/PkFp4 gets converted during loading earlier + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); @@ -420,7 +422,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); @@ -493,7 +495,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + // Note: ADataType gets converted during loading from PkInt4/PkFp4 + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); @@ -543,9 +546,9 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}.template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](const OverrideADataType& a) { return a; }, b_dram_block_window_tmp, - [](const BDataType& b) { return b; }, + [](const OverrideBDataType& b) { return b; }, aq_dram_block_window_tmp, bq_dram_block_window_tmp, m, @@ -593,9 +596,10 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}.template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + // Note: ADataType PkInt4/PkFp4 gets converted during loading + [](const OverrideADataType& a) { return a; }, b_dram_block_window_tmp, - // Note: BDataType PkInt4 gets converted during loading + // Note: BDataType PkInt4/PkFp4 gets converted during loading [](const OverrideBDataType& b) { return b; }, aq_dram_block_window_tmp, bq_dram_block_window_tmp, diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp index 1edbe9ac16..9b02585e69 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp @@ -21,23 +21,27 @@ template -struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase +struct GemmQuantPipelineProblemBase + : public GemmPipelineProblemBase< + ADataType_, + BDataType_, + CDataType_, + BlockGemmShape_, + Traits_, + mixed_prec_compute_type_t> { - using Base = GemmPipelineProblemBase; + + using Base = GemmPipelineProblemBase< + ADataType_, + BDataType_, + CDataType_, + BlockGemmShape_, + Traits_, + mixed_prec_compute_type_t>; using Traits = typename Base::Traits; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp index ae2a601f8a..f136b86314 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp @@ -95,11 +95,6 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; - using BTypeToUse = - std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>; - constexpr index_t WaveSize = get_warp_size(); constexpr index_t KLane = WarpTile::at(I2) * WarpTile::at(I0) / WaveSize; using BDataType = typename Problem::BDataType; @@ -107,8 +102,8 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel KLane / numeric_traits::PackedSize * sizeof(BDataType); constexpr auto NumAccess = static_cast(max(1, KLaneBytes / 16)); - using WarpGemm = WarpGemmDispatcher #include "ck_tile/core.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp" @@ -239,36 +240,42 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe make_tensor_view(p_a_lds_pong, a_lds_block_desc); // A DRAM tile window for load + auto a_dram_tile_distribution = + PipelinePolicy::template MakeADramTileDistribution(); + auto a_copy_dram_window = make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), a_dram_block_window_tmp.get_window_origin(), - PipelinePolicy::template MakeADramTileDistribution()); + a_dram_tile_distribution); auto a_copy_lds_window_ping = make_tile_window(a_lds_block_ping, make_tuple(number{}, number{}), {0, 0}, - PipelinePolicy::template MakeADramTileDistribution()); + a_dram_tile_distribution); auto a_copy_lds_window_pong = make_tile_window(a_lds_block_pong, make_tuple(number{}, number{}), {0, 0}, - PipelinePolicy::template MakeADramTileDistribution()); + a_dram_tile_distribution); // ping-pong window for A LDS + auto a_warp_tile_distribution = + make_static_tile_distribution(typename WG::AWarpDstrEncoding{}); + auto a_warp_window_ping_tmp = make_tile_window(a_lds_block_ping, make_tuple(number{}, number{}), {iMWarp * WG::kM, 0}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + a_warp_tile_distribution); auto a_warp_window_pong_tmp = make_tile_window(a_lds_block_pong, make_tuple(number{}, number{}), {iMWarp * WG::kM, 0}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + a_warp_tile_distribution); statically_indexed_array< statically_indexed_array, @@ -314,7 +321,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe b_flat_distribution); using BTypeToUse = - std::conditional_t, ADataType, BDataType>; + mixed_prec_compute_type_from_input_t; using BTileType = decltype(make_static_distributed_tensor(b_flat_distribution)); // pingpong buffer for B @@ -354,7 +361,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_int4_tile( b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -393,15 +400,17 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe block_sync_lds(); // preload A00,A10 from lds - statically_indexed_array{})(number<0>{}))), - m_preload> - a_warp_tensor; + using ATypeToUse = + mixed_prec_compute_type_from_input_t; + using ATileType = + decltype(make_static_distributed_tensor(a_warp_tile_distribution)); + statically_indexed_array a_warp_tensor; static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_ping(number{})(number{})); + load_int4_tile( + a_warp_tensor(loadIter), a_warp_windows_ping(number{})(number{})); }); __builtin_amdgcn_sched_barrier(0); @@ -434,7 +443,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_int4_tile( b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -450,8 +459,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_pong(number{})(number{})); + load_int4_tile( + a_warp_tensor(loadIter), a_warp_windows_pong(number{})(number{})); }); // Next K @@ -463,7 +472,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_int4_tile( b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -495,8 +504,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_ping(number{})(number{})); + load_int4_tile( + a_warp_tensor(loadIter), a_warp_windows_ping(number{})(number{})); }); iCounter--; HotLoopScheduler(); @@ -513,7 +522,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_int4_tile( b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -535,8 +544,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_pong(number{})(number{})); + load_int4_tile( + a_warp_tensor(loadIter), a_warp_windows_pong(number{})(number{})); }); // GEMM loopK diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index 9dd9670ff5..8e005d588e 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -76,6 +76,22 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") ) target_compile_options(test_tile_gemm_quant_abquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_abquant_a4w4_base + test_gemm_quant_abquant_a4w4_base.cpp + ) + target_compile_options(test_tile_gemm_quant_abquant_a4w4_base PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_abquant_a4w4_padding + test_gemm_quant_abquant_a4w4_padding.cpp + ) + target_compile_options(test_tile_gemm_quant_abquant_a4w4_padding PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_abquant_a4w4_preshuffle + test_gemm_quant_abquant_a4w4_preshuffle.cpp + ) + target_compile_options(test_tile_gemm_quant_abquant_a4w4_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_tile_gemm_quant_abquant_preshuffleQuant test_gemm_quant_abquant_preshuffleQuant.cpp ) diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_base.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_base.cpp new file mode 100644 index 0000000000..5e2403f7d1 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_base.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using Half = ck_tile::half_t; +using PkFP4 = ck_tile::pk_fp4_t; +using ABQuantGrouped = + std::integral_constant; + +// 1d block sizes for AQuant +using GroupSize1D = ck_tile::QuantGroupShape>; + +// 2d block sizes for BQuant +using GroupSize2D = ck_tile::QuantGroupShape>; + +// Type combinations for ABQuant tests +// Tuple format: +// clang-format off +using ABQuantTypes = ::testing::Types< + // PreshuffleQuant = false && TransposeC = false + // RCR layout with RowMajor AQ, ColumnMajor BQ + std::tuple +>; +// clang-format on + +// Test suite for ABQuant +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_padding.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_padding.cpp new file mode 100644 index 0000000000..1e496d5b64 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_padding.cpp @@ -0,0 +1,65 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using Half = ck_tile::half_t; +using PkFP4 = ck_tile::pk_fp4_t; +using ABQuantGrouped = + std::integral_constant; + +// 1d block sizes for AQuant +using GroupSize1D = ck_tile::QuantGroupShape>; + +// 2d block sizes for BQuant +using GroupSize2D = ck_tile::QuantGroupShape>; + +// Type combinations for ABQuant tests +// Tuple format: +// clang-format off +using ABQuantTypes = ::testing::Types< + // PreshuffleQuant = false && TransposeC = false + // RCR layout with RowMajor AQ, ColumnMajor BQ + std::tuple +>; +// clang-format on + +// Test suite for ABQuant +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantTypes); + +// AQuant tests + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadK) +{ + this->run_test_with_validation(1024, 1024, 832); +} + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadN) +{ + this->run_test_with_validation(1024, 832, 1024); +} + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadM) +{ + this->run_test_with_validation(832, 1024, 1024); +} + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadMNK) +{ + this->run_test_with_validation(832, 832, 832); +} + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadNK) +{ + this->run_test_with_validation(1024, 832, 832); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_preshuffle.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_preshuffle.cpp new file mode 100644 index 0000000000..43051c8d08 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_preshuffle.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using Half = ck_tile::half_t; +using PkFP4 = ck_tile::pk_fp4_t; +using ABQuantGrouped = + std::integral_constant; + +// 1d block sizes for AQuant +using GroupSize1D = ck_tile::QuantGroupShape>; + +// 2d block sizes for BQuant +using GroupSize2D = ck_tile::QuantGroupShape>; + +// Type combinations for ABQuant tests +// Tuple format: +// clang-format off +using ABQuantTypes = ::testing::Types< + // RCR layout with RowMajor AQ, ColumnMajor BQ + // PreshuffleB = true && TransposeC = false + std::tuple +>; +// clang-format on + +// Test suite for ABQuant +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp index 7be4131db4..5937b44229 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -209,7 +209,7 @@ template <> struct QuantTypeTraits { template - using ComputeDataType = BDataType; // For AQuant, compute type is BDataType + using ComputeDataType = void; // Use automatically determined compute type static constexpr const char* name = "abquant"; }; diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 9683fa98aa..0033bb42a8 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -1174,8 +1174,8 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase>; using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, AccDataType, CDataType,