diff --git a/CMakeLists.txt b/CMakeLists.txt index 356491d9c1..610f9c9d2a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -41,6 +41,7 @@ include(CTest) option(ENABLE_CLANG_CPP_CHECKS "Enables clang tidy, cppcheck" ON) option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF) +option(HIPTENSOR_REQ_LIBS_ONLY "Build only the HipTensor required libraries" OFF) option(CK_EXPERIMENTAL_BUILDER "Enable experimental builder" OFF) option(BUILD_MHA_LIB "Build the static library for flash attention" OFF) option(FORCE_DISABLE_XDL "Skip compiling XDL specific instances (even if supported GPUs are included in GPU_TARGETS)" OFF) @@ -648,7 +649,7 @@ if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERS add_compile_options(-fdiagnostics-color=always) endif() -if(NOT MIOPEN_REQ_LIBS_ONLY) +if(NOT MIOPEN_REQ_LIBS_ONLY AND NOT HIPTENSOR_REQ_LIBS_ONLY) # make check runs the entire set of examples and tests add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} USES_TERMINAL) # make smoke runs the tests and examples that runs within 30 seconds on gfx90a @@ -706,6 +707,7 @@ ENDFOREACH() add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES}) option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF) +option(HIPTENSOR_REQ_LIBS_ONLY "Build only the HipTensor required libraries" OFF) option(DISABLE_OFFLOAD_COMPRESS "Disable offload compress compiler flag when building instances" OFF) option(BUILD_MHA_LIB "Build the static library for flash attention" OFF) @@ -716,7 +718,7 @@ if (CK_EXPERIMENTAL_BUILDER) add_subdirectory(experimental/grouped_convolution_tile_instances) endif() -if(NOT GPU_ARCHS AND USER_GPU_TARGETS AND NOT MIOPEN_REQ_LIBS_ONLY) +if(NOT GPU_ARCHS AND USER_GPU_TARGETS AND NOT MIOPEN_REQ_LIBS_ONLY AND NOT HIPTENSOR_REQ_LIBS_ONLY) rocm_package_setup_component(tests LIBRARY_NAME composablekernel PACKAGE_NAME tests # Prevent -static 
suffix on package name @@ -739,7 +741,7 @@ if(NOT GPU_ARCHS AND USER_GPU_TARGETS AND NOT MIOPEN_REQ_LIBS_ONLY) endif() endif() -if (NOT MIOPEN_REQ_LIBS_ONLY) +if (NOT MIOPEN_REQ_LIBS_ONLY AND NOT HIPTENSOR_REQ_LIBS_ONLY) rocm_package_setup_component(profiler LIBRARY_NAME composablekernel PACKAGE_NAME ckprofiler diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp index cf8dd31c3f..78d98e92ce 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp @@ -96,11 +96,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm 8, 8, 0, - S<8, 32, 1>, + S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, - 1, + 8, 8, 0, 1, @@ -108,7 +108,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm S<1, 32, 1, 8>, S<8, 8, 8>, ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3>; + ck::BlockGemmPipelineVersion::v1>; int main(int argc, char* argv[]) { @@ -174,6 +174,29 @@ int main(int argc, char* argv[]) } }; + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1 || stride == 0) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, A0Layout{}); + StrideB = f_get_default_stride(K, N, StrideB, B0Layout{}); + StrideD = f_get_default_stride(M, N, StrideD, D0Layout{}); + StrideE = f_get_default_stride(M, N, StrideE, ELayout{}); + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); Tensor b1_k_n(f_host_tensor_descriptor(K, N, StrideB, 
B1Layout{})); diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp index e4033e5bac..089404757a 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp @@ -94,11 +94,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm 8, 8, 0, - S<8, 32, 1>, + S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, - 1, + 8, 8, 0, 1, @@ -106,7 +106,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm S<1, 32, 1, 8>, S<8, 8, 8>, ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3>; + ck::BlockGemmPipelineVersion::v1>; int main(int argc, char* argv[]) { @@ -133,7 +133,7 @@ int main(int argc, char* argv[]) init_method = std::stoi(argv[2]); time_kernel = std::stoi(argv[3]); } - else if(argc == 11) + else if(argc == 10) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); @@ -170,6 +170,28 @@ int main(int argc, char* argv[]) } }; + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1 || stride == 0) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, A0Layout{}); + StrideB = f_get_default_stride(K, N, StrideB, B0Layout{}); + StrideE = f_get_default_stride(M, N, StrideE, ELayout{}); + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); Tensor b1_k_n(f_host_tensor_descriptor(K, N, StrideB, B1Layout{})); diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fp16.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fp16.cpp index 
5817269fdf..d5ccf7eb59 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fp16.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fp16.cpp @@ -141,11 +141,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm 8, 8, 0, - S<4, 64, 1>, + S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, - 1, + 8, 8, 0, 1, @@ -233,6 +233,29 @@ int main(int argc, char* argv[]) } }; + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1 || stride == 0) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideD = f_get_default_stride(M, N, StrideD, DLayout{}); + StrideE = f_get_default_stride(M, N, StrideE, ELayout{}); + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor a1_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp index 4fb1a5ab4e..2d07bc480d 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp @@ -95,11 +95,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm 8, 8, 0, - S<8, 32, 1>, + S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, - 1, + 8, 8, 0, 1, @@ -107,7 +107,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm S<1, 32, 1, 8>, S<8, 8, 8>, ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3>; + 
ck::BlockGemmPipelineVersion::v1>; int main(int argc, char* argv[]) { @@ -173,6 +173,29 @@ int main(int argc, char* argv[]) } }; + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1 || stride == 0) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, A0Layout{}); + StrideB = f_get_default_stride(K, N, StrideB, B0Layout{}); + StrideD = f_get_default_stride(M, N, StrideD, D0Layout{}); + StrideE = f_get_default_stride(M, N, StrideE, ELayout{}); + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 9a2d727253..42f686e0c0 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -630,6 +630,7 @@ class KernelComponentFactory: if dtype in ["fp16", "bf16"]: return { 128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + 256 : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], } # fmt: skip elif dtype in ["fp8bf16"]: return { diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp index a7cb88079b..e4e0503b5a 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp @@ -164,5 +164,35 @@ static auto _ = []() { BQuantGroupSize, 
ck_tile::QuantType::ABQuantGrouped>(arg_parser); }; + lut[hash_multiple_strings( + {"fp4", "abquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp4", "abquant", "preshuffleb", "non-preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; return 0; }(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp index 1fbe4d7b47..cc4302a992 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -32,7 +32,7 @@ auto create_args(int argc, char* argv[]) .insert("prec", "fp8", "Data type. 
For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, " - "or bf8i4; for ABQuant: fp8, bf8") + "or bf8i4; for ABQuant: fp8, bf8, fp4") .insert("warmup", "50", "Number of iterations before benchmarking the kernel") .insert("repeat", "1000", "Number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 665c7828ad..540d5725dd 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -9,6 +9,7 @@ #include #include #include +#include #include "ck_tile/core/config.hpp" #include "ck_tile/ops/common/utils.hpp" @@ -35,10 +36,9 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str static_assert(std::is_same_v); constexpr bool transpose_c = GemmConfig::TransposeC; // QuantMode == ck_tile::QuantType::ABQuantGrouped; - using ComputeDataType = std::conditional_t; + + // Use automatically determined compute type from + using ComputeDataType = void; using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -80,7 +80,10 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str std::conditional_t< QuantMode == ck_tile::QuantType::AQuantGrouped, ck_tile::BaseGemmPipelineAgBgCrMem, - ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2>>>; + std::conditional_t< + QuantMode == ck_tile::QuantType::ABQuantGrouped, + ck_tile::BaseGemmPipelineAgBgCrMem, + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2>>>>; const ck_tile::index_t K_split = ck_tile::integer_least_multiple(args.K, GemmConfig::K_Tile); const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); @@ -182,30 +185,28 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str printf( "TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", 
TiledPermuteN, BQuantGroupSize::kN); } - using GemmEpilogue = ck_tile::CShuffleEpilogue, - typename TypeConfig::ADataType, - typename TypeConfig::BDataType>, - ck_tile::tuple<>, - typename TypeConfig::AccDataType, - typename TypeConfig::CDataType, - ck_tile::tuple<>, - CLayout, - CDEElementWise, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - transpose_c, - 1, - false, - 1, - TiledPermuteN>>; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + typename TypeConfig::AccDataType, + typename TypeConfig::CDataType, + ck_tile::tuple<>, + CLayout, + CDEElementWise, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + transpose_c, + 1, + false, + 1, + TiledPermuteN>>; using Kernel = ck_tile::QuantGemmKernel; @@ -557,8 +558,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, { if constexpr(std::is_same_v) { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - b_k_n); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( *bq_tensor_ptr); } @@ -594,18 +594,26 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v || + std::is_same_v) { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - a_m_k); - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - b_k_n); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); } else { ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); + } + + if constexpr(std::is_same_v || + std::is_same_v) + { + 
ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); + } + else + { ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); } + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( *aq_tensor_ptr); ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( @@ -723,12 +731,11 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v || + std::is_same_v) { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - a_m_k); - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - b_k_n); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); } else { @@ -804,12 +811,11 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v || + std::is_same_v) { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - a_m_k); - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - b_k_n); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); } else { @@ -984,10 +990,14 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, if(arg_parser.get_int("v") == 1) { + std::cout << "Performing CPU verification..." 
<< std::endl; + ck_tile::HostTensor c_m_n_host_ref( ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); c_m_n_host_ref.SetZero(); + // Track start time for reference operation + auto start_reference_tick = std::chrono::high_resolution_clock::now(); if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) { ck_tile::reference_gemm_quant( @@ -1061,6 +1074,9 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, rtol_atol.at(ck_tile::number<0>{}), rtol_atol.at(ck_tile::number<1>{})); + // "Stop" our timer + auto verification_finished_tick = std::chrono::high_resolution_clock::now(); + if(!pass) { std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) @@ -1068,6 +1084,21 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, << std::endl; } std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl; + + // Calculate and display reference timing + using DurationType = std::chrono::duration; + double reference_sec = std::chrono::duration_cast(verification_finished_tick - + start_reference_tick) + .count(); + double verification_sec = std::chrono::duration_cast( + verification_finished_tick - start_verification_tick) + .count(); + float reference_msec = static_cast(reference_sec * 1e3); + float verification_msec = static_cast(verification_sec * 1e3); + + std::cout << std::fixed << std::setprecision(1) << "CPU reference GEMM took " + << reference_msec << "ms, verification took " << verification_msec << "ms." 
+ << std::endl; } else if(arg_parser.get_int("v") == 2) { @@ -1098,6 +1129,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser) } if constexpr(std::is_same_v || + std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v) diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp index 1c322fe4a7..d1c6f30a14 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp @@ -12,16 +12,17 @@ namespace ck { -template + bool DoTranspose, + index_t NumThreadScratch = 1> struct ThreadGroupTransferGlobal { static constexpr auto I0 = Number<0>{}; @@ -32,24 +33,57 @@ struct ThreadGroupTransferGlobal static constexpr auto I5 = Number<5>{}; static constexpr auto I6 = Number<6>{}; - static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); - using Index = MultiIndex; - using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); - using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + // return a tuple of coordiantes for a tuple of tensor + template = false> + static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices) + { + return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); }, + Number{}); + } - __device__ ThreadGroupTransferGlobal(const SrcDesc& src_desc, - const DstDesc& dst_desc, - const Index& src_block_slice_origin, - const Index& dst_block_slice_origin, - const ElementwiseOperation& element_op) - : src_coord_(make_tensor_coordinate(src_desc, src_block_slice_origin)), + static constexpr index_t nDim = + remove_cvref_t>::GetNumOfDimension(); + static constexpr index_t nSrc = SrcDescs::Size(); + using Index = MultiIndex; + using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, 
StaticallyIndexedArray{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + __device__ + ThreadGroupTransferGlobal(const SrcDescs& src_descs, + const DstDesc& dst_desc, + const StaticallyIndexedArray& src_block_slice_origins, + const Index& dst_block_slice_origin, + const ElementwiseOperation& element_op) + : src_coords_(MakeCoordinates(src_descs, src_block_slice_origins)), dst_coord_(make_tensor_coordinate(dst_desc, dst_block_slice_origin)), element_op_(element_op) { } - template - __device__ void RunRead(const SrcDesc& src_desc, const GridBufferType& grid_buf) + template + __device__ static auto generate_vectors() + { + auto data_types = DataTypes_{}; + + constexpr index_t num = data_types.Size(); + + return generate_tuple( + [&](auto i) { + using DataType = remove_cvref_t; + + return vector_type_maker_t{}; + }, + Number{}); + } + + template = false> + __device__ void RunRead(SrcDescs& src_descs, + const GridBufferTypes& grid_bufs, + Number thread_scratch_id = Number{}) { constexpr auto src_access_lengths = NumberOfIterations{}; constexpr auto src_dim_access_order = IterationOrder{}; @@ -57,36 +91,6 @@ struct ThreadGroupTransferGlobal container_reorder_given_new2old(src_access_lengths, src_dim_access_order); constexpr auto ordered_fwd_step = StepsPerIteration{}; - // make forward steps - // forward step for each iteration just add 1 - const auto src_forward_steps = generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = (i.value == j.value) ? ordered_fwd_step[i] : 0; - }); - - return make_tensor_coordinate_step(src_desc, forward_step_idx); - }, - Number{}); - - // make backward steps - // backward step at the end of the dimension iteration subtract IterationLength - 1 - const auto src_backward_steps = generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) - ? 
(-src_access_lengths[i] + 1) * ordered_fwd_step[i] - : 0; - }); - - return make_tensor_coordinate_step(src_desc, backward_step_idx); - }, - Number{}); - static_ford{}([&](auto ordered_src_access_idx) { // judge move forward or move backward constexpr auto forward_sweep = [&]() { @@ -157,10 +161,26 @@ struct ThreadGroupTransferGlobal }, Number{}); - // check if src element is valid - const bool is_src_valid = - coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); - oob_thread_scratch_.template SetAsType(vgpr_data_idx_seq, is_src_valid); + auto src_vectors = generate_vectors(); + bool oob_val = true; + + static_for<0, nSrc, 1>{}([&](auto i) { + using src_vector_t = typename remove_cvref_t::type; + // check if src element is valid + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i], + src_coords_[i]); + + oob_val = oob_val & is_src_valid; + + // Load data from memory in src_vector first + auto index = is_src_valid || !DoTranspose ? src_coords_[i].GetOffset() : 0; + src_vectors(i).template AsType()(I0) = + grid_bufs[i].template Get(index, true); + }); + + oob_thread_scratch_(thread_scratch_id) + .template SetAsType(vgpr_data_idx_seq, oob_val); // Vector length of elementwise operation constexpr auto get_elem_op_vec_len = []() { @@ -185,57 +205,105 @@ struct ThreadGroupTransferGlobal } }; - // This is 1 for pass through because internally it's doing type conversion constexpr index_t elem_op_vec_len = get_elem_op_vec_len(); - using src_vector_container = vector_type_maker_t; - using src_vector_container_t = typename src_vector_container::type; - - using elem_op_vec_t = typename vector_type::type; - using dst_vector_type = vector_type_maker_t; using dst_vector_t = typename dst_vector_type::type; - dst_vector_type op_r_v; - // Load data from memory in src_vector first - auto index = is_src_valid || !DoTranspose ? 
src_coord_.GetOffset() : 0; - src_vector_container src_vector = src_vector_container{ - grid_buf.template Get(index, true)}; - // apply the src elementwise op and convert to DstData under the hood if needed static_for<0, VectorSize / elem_op_vec_len, 1>{}([&](auto idx) { - element_op_(op_r_v.template AsType()(idx), - src_vector.template AsType()[idx]); + // get reference to src data + const auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { + using SrcData = remove_cvref_t>; + + using elem_op_vec_t = typename vector_type::type; + + return src_vectors[iSrc].template AsType()[idx]; + }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto) -> auto& { + using elem_op_vec_t = typename vector_type::type; + + return op_r_v.template AsType()(idx); + }, + Number<1>{}); + + // apply pointwise function + unpack2(element_op_, dst_data_refs, src_data_refs); }); // store result in dvgpr_ (static array holding loaded data). // At this point data is already converted to DstData type and // the elementwise operation has been applied - src_dvgpr_.template SetAsType(vgpr_data_idx_seq, - op_r_v.template AsType()[I0]); + src_dvgpr_(thread_scratch_id) + .template SetAsType(vgpr_data_idx_seq, + op_r_v.template AsType()[I0]); - // For each dimension move fwd, bwd or don't move - static_for<0, nDim, 1>{}([&](auto i) { - if constexpr(move_on_dim[i]) - { - if constexpr(forward_sweep[i]) + // Move each src coordinate + static_for<0, nSrc, 1>{}([&](auto iSrc) { + // make forward steps + // forward step for each iteration just add 1 + const auto src_forward_steps = generate_tuple( + [&](auto iDim) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = + (iDim.value == j.value) ? 
ordered_fwd_step[iDim] : 0; + }); + return make_tensor_coordinate_step(src_descs[iSrc], forward_step_idx); + }, + Number{}); + + // make backward steps + // backward step at the end of the dimension iteration subtract IterationLength - 1 + const auto src_backward_steps = generate_tuple( + [&](auto iDim) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = + (iDim.value == j.value) + ? (-src_access_lengths[iDim] + 1) * ordered_fwd_step[iDim] + : 0; + }); + return make_tensor_coordinate_step(src_descs[iSrc], backward_step_idx); + }, + Number{}); + + // For each dimension move fwd, bwd or don't move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) { - move_tensor_coordinate( - src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate(src_descs[iSrc], + src_coords_(iSrc), + src_forward_steps[src_dim_access_order[i]]); + } + else + { + move_tensor_coordinate(src_descs[iSrc], + src_coords_(iSrc), + src_backward_steps[src_dim_access_order[i]]); + } } - else - { - move_tensor_coordinate( - src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); - } - } + }); }); }); } - template - __device__ void RunWrite(const DstDesc& dst_desc, BlockBufferType& dst_buf) + template + __device__ void RunWrite(const DstDesc& dst_desc, + BlockBufferType& dst_buf, + Number thread_scratch_id = Number{}) { using dst_vector_type = vector_type_maker_t; using dst_vector_t = typename dst_vector_type::type; @@ -272,9 +340,10 @@ struct ThreadGroupTransferGlobal }, Number{}); - auto op_r = src_dvgpr_.template GetAsType(vgpr_data_idx_seq); + auto op_r = + src_dvgpr_(thread_scratch_id).template GetAsType(vgpr_data_idx_seq); const bool is_src_valid = - oob_thread_scratch_.template GetAsType(vgpr_data_idx_seq); + oob_thread_scratch_(thread_scratch_id).template GetAsType(vgpr_data_idx_seq); auto op_r_v = is_src_valid ? 
op_r : dst_vector_t(0); dst_dvgpr_.template SetAsType(vgpr_data_idx_seq, op_r_v); }); @@ -404,10 +473,12 @@ struct ThreadGroupTransferGlobal }); } - __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step) { - const auto adjusted_step = make_tensor_coordinate_step(src_desc, step); - move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + static_for<0, nSrc, 1>{}([&](auto iSrc) { + const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], step); + move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step); + }); } private: @@ -443,10 +514,10 @@ struct ThreadGroupTransferGlobal decltype(src_oob_thread_scratch_desc_), true>; - ThreadScratchData src_dvgpr_; + StaticallyIndexedArray src_dvgpr_; ThreadScratchData dst_dvgpr_; - OOBThreadScratch oob_thread_scratch_; - SrcCoord src_coord_; + StaticallyIndexedArray oob_thread_scratch_; + SrcCoords src_coords_; DstCoord dst_coord_; const ElementwiseOperation element_op_; }; diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp index 58da96e2f0..eadfa29c9f 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp @@ -11,8 +11,6 @@ namespace ck { namespace tensor_operation { namespace device { -#define DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS 1 - template ()) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index bc072a7019..f662ff834f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ 
b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -22,6 +22,7 @@ #include #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/host_utility/device_prop.hpp" @@ -524,6 +525,44 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 decltype(GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( CGridDesc_M_N{}, 1, 1)); + struct ActiveWorkgroupsPerCU + { + ActiveWorkgroupsPerCU() + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return; + } + constexpr int dynamic_smem_size = 0; + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; + int max_occupancy = 0; + + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // TODO: implement + } + else + { + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, + kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>, + BlockSize, + dynamic_smem_size)); + } + max_occupancy_ = std::max(1, max_occupancy); + } + int max_occupancy_; + }; + struct Argument : public BaseArgument, public ArgumentSplitK { Argument( @@ -574,6 +613,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 input_left_pads_{input_left_pads}, input_right_pads_{input_right_pads} { + static ActiveWorkgroupsPerCU active_workgroups_per_cu; + constexpr index_t spatial_offset = 3; std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset, end(b_g_n_c_wis_lengths), @@ -585,7 +626,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 
end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); -#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS if(split_k < 0) { ck::index_t gemmM, gemmN, gemmK; @@ -602,6 +642,9 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock); k_batch_ = std::min(k_batch_, k_batch_max); + // Cap k_batch_ to 128 to avoid accuracy issues + k_batch_ = std::min(k_batch_, 128); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max @@ -611,7 +654,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 } } else -#endif { k_batch_ = split_k; } @@ -988,13 +1030,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 static bool IsSupportedArgument(const Argument& arg) { -#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS - if(arg.k_batch_ < 0) - { - return false; - } -#endif - const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index 51dc56e306..1e23fef191 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -677,7 +677,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); -#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS if(split_k < 0) { ck::index_t gemmM, gemmN; @@ -688,9 +687,11 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle calculate_mn_grid_size(gemmM, gemmN) * Conv_G_; k_batch_ = 
get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_, grid_size); + + // Cap k_batch_ to 128 to avoid accuracy issues + k_batch_ = std::min(k_batch_, 128); } else -#endif { k_batch_ = split_k; } @@ -947,12 +948,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { -#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS - if(arg.k_batch_ < 0) - { - return false; - } -#endif if(!ck::is_xdl_wmma_supported()) { return false; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index 3f8093afe1..b2ae092c27 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -511,7 +511,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 std::copy(begin(a_g_n_k_wos_lengths) + spatial_offset, end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); -#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS + if(split_k < 0) { ck::index_t gemmM, gemmN, gemmK; @@ -528,6 +528,9 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock); k_batch_ = std::min(k_batch_, k_batch_max); + // Cap k_batch_ to 128 to avoid accuracy issues + k_batch_ = std::min(k_batch_, 128); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max @@ -537,7 +540,6 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 } } else -#endif { k_batch_ = split_k; } @@ -1040,12 +1042,6 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 static bool IsSupportedArgument(const Argument& arg) { -#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS - if(arg.k_batch_ < 0) - { - return false; - } -#endif const index_t 
GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 0ea94806d0..1f6f2fb789 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -651,7 +651,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths, e_g_k_c_xs_strides); -#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS if(split_k < 0) { ck::index_t gemmM, gemmN; @@ -662,9 +661,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle calculate_mn_grid_size(gemmM, gemmN) * Conv_G_; k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_, grid_size); + + // Cap k_batch_ to 128 to avoid accuracy issues + k_batch_ = std::min(k_batch_, 128); } else -#endif { k_batch_ = split_k; } @@ -1083,12 +1084,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { -#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS - if(arg.k_batch_ < 0) - { - return false; - } -#endif if(!ck::is_xdl_wmma_supported()) { return false; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 26cf586017..ac83cee251 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -594,7 +594,6 @@ struct 
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); -#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS if(split_k < 0) { ck::index_t gemmM, gemmN, gemmK; @@ -611,6 +610,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 const auto k_batch_max = static_cast((gemmK - 1) / K0PerBlock); k_batch_ = std::max(std::min(k_batch_, k_batch_max), 1); + // Cap k_batch_ to 128 to avoid accuracy issues + k_batch_ = std::min(k_batch_, 128); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max @@ -620,7 +622,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 } } else -#endif { k_batch_ = split_k; } @@ -1399,13 +1400,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 static bool IsSupportedArgument(const Argument& arg) { -#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS - if(arg.k_batch_ < 0) - { - return false; - } -#endif - // check device if constexpr(DirectLoad) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp index 96387c6f64..4d5c052e02 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp @@ -488,6 +488,19 @@ struct ABTransferThreadTiles { return make_dynamic_buffer(p_shared_AB, size); } + + template + __device__ __forceinline__ static auto get_first_element_workaround(Type& array) + { + if constexpr(numElements > 1) + { + return array; + } + else + { + return array[I0]; + } + } }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles_preshuffle.hpp index ad9af92ae5..fb6d1451d3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles_preshuffle.hpp +++ 
b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles_preshuffle.hpp @@ -133,6 +133,19 @@ struct ABTransferThreadTilesPreShuffle { return make_static_buffer(size); } + + template + __device__ __forceinline__ static auto get_first_element_workaround(Type& array) + { + if constexpr(numElements > 1) + { + return array; + } + else + { + return array[I0]; + } + } }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp index caf468d6cb..63c0299750 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp @@ -318,43 +318,43 @@ struct ABTransferWaveTiles const index_t block_mn_id, const index_t) { - // Note: GlobalBufferNum is currently not used but it will be needed - // once we add other pipelines. It is currently needed only for - // consistency with the thread tiles approach - static_assert(GlobalBufferNum == 1, "single global buffer is only supported"); constexpr index_t NumABTensor = ABsDataType::Size(); - static_assert(NumABTensor == 1, "multiAB currently not supported"); - - using ABDataType = remove_cvref_t>; const auto wave_idx = GetWaveIdx(); index_t wave_idK = wave_idx[I1]; index_t wave_idMN = wave_idx[I0]; - const auto grid_lane_id = GetGridLaneIdx(); - index_t lane_group_grid = grid_lane_id[I0]; - index_t lane_local_id_grid = grid_lane_id[I1]; - const auto block_lane_id = GetBlockLaneIdx(); index_t lane_group_block = block_lane_id[I0]; index_t lane_local_id_block = block_lane_id[I1]; - return ThreadGroupTransferGlobal>; + const auto grid_lane_id = GetGridLaneIdx(); + index_t lane_group_grid = grid_lane_id[I0]; + index_t lane_local_id_grid = grid_lane_id[I1]; + return make_multi_index(block_mn_id * (MNRepeat_ * MNWaves_) + wave_idMN, + wave_idK, + lane_group_grid, + lane_local_id_grid); + }, + Number{}); + + 
return ThreadGroupTransferGlobal, Sequence, Sequence, ABK1Value, - ABDoTranspose>( - grid_descriptor[I0], + ABDoTranspose, + GlobalBufferNum>( + grid_descriptor, block_descriptor, - make_multi_index(block_mn_id * (MNRepeat_ * MNWaves_) + wave_idMN, - wave_idK, - lane_group_grid, - lane_local_id_grid), + idx_as_block_begin, make_multi_index(wave_idMN, wave_idK, lane_group_block, lane_local_id_block), ab_element_op); } @@ -398,6 +398,12 @@ struct ABTransferWaveTiles { return make_dynamic_buffer(p_shared_AB, size); } + + template + __device__ __forceinline__ static auto get_first_element_workaround(Type& array) + { + return array; + } }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp index bfe5b7bd08..e1ee47770b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp @@ -218,45 +218,46 @@ struct ABTransferWaveTilesInterleave : ABTransferWaveTiles>; const auto wave_idx = GetWaveIdx(); index_t wave_idK = wave_idx[I1]; index_t wave_idMN = wave_idx[I0]; - const auto grid_lane_id = Base::template GetGridLaneIdx(); - index_t lane_group_grid = grid_lane_id[I0]; - index_t lane_local_id_grid = grid_lane_id[I1]; - const auto block_lane_id = GetBlockLaneIdx(); index_t lane_group_block = block_lane_id[I0]; index_t lane_local_id_block = block_lane_id[I1]; constexpr index_t MNRepeatRatio = MNRepeat_Grid / MNRepeat_; - return ThreadGroupTransferGlobal>; + const auto grid_lane_id = Base::template GetGridLaneIdx(); + index_t lane_group_grid = grid_lane_id[I0]; + index_t lane_local_id_grid = grid_lane_id[I1]; + return make_multi_index(block_mn_id * MNWaves_Grid + wave_idMN / MNRepeatRatio, + wave_idK * KRepeat_Grid, + (wave_idMN % MNRepeatRatio) * MNRepeat_, + lane_group_grid, + lane_local_id_grid); + }, + 
Number{}); + + return ThreadGroupTransferGlobal, Sequence, Sequence, ABK1Value, - ABDoTranspose>( - grid_descriptor[I0], + ABDoTranspose, + GlobalBufferNum>( + grid_descriptor, block_descriptor, - make_multi_index(block_mn_id * MNWaves_Grid + wave_idMN / MNRepeatRatio, - wave_idK * KRepeat_Grid, - (wave_idMN % MNRepeatRatio) * MNRepeat_, - lane_group_grid, - lane_local_id_grid), + idx_as_block_begin, make_multi_index(wave_idMN / MNRepeatRatio, wave_idK * KRepeat_, (wave_idMN % MNRepeatRatio) * MNRepeat_, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index bcf131003c..03735bbc6a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -364,7 +364,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base __host__ __device__ static constexpr bool AWaveTransferApplicable() { - return !ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 && + return !ForceThreadTileTransfer && APackedSize == 1 && ABlockTransferSrcScalarPerVector == 8 && ABlockTransferDstScalarPerVector_AK1 == 8 && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && AK1Value == 8 && !IsBPreShuffled; @@ -372,13 +372,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_base __host__ __device__ static constexpr bool BWaveTransferApplicable() { - return !ForceThreadTileTransfer && NumBTensor == 1 && BPackedSize == 1 && + return !ForceThreadTileTransfer && BPackedSize == 1 && BBlockTransferSrcScalarPerVector == 8 && BBlockTransferDstScalarPerVector_BK1 == 8 && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8; } - // Limitations of the current implementation: - // - no multiAB #ifdef __gfx12__ static constexpr bool IsAWaveTransferApplicable = AWaveTransferApplicable(); @@ -1319,19 +1317,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base } } - template - 
__device__ __forceinline__ static auto get_first_element_workaround(Type& array) - { - if constexpr(numElements > 1) - { - return array; - } - else - { - return array[I0]; - } - } - // Note: arguments k_batch and k_id should be set if splitk is used // with implicit gemm (no pointer shift but shift using tensor descriptors) template ( - get_first_element_workaround(as_grid_desc_ak0_m_ak1), + ATransfer::template get_first_element_workaround(as_grid_desc_ak0_m_ak1), a_block_desc_ak0_m_ak1, a_blockwise_copy, - get_first_element_workaround(as_grid_buf), + ATransfer::template get_first_element_workaround(as_grid_buf), a_block_buf, a_block_slice_copy_step, - get_first_element_workaround(bs_grid_desc_bk0_n_bk1), + BTransfer::template get_first_element_workaround(bs_grid_desc_bk0_n_bk1), b_block_desc_bk0_n_bk1, b_blockwise_copy, - get_first_element_workaround(bs_grid_buf), + BTransfer::template get_first_element_workaround(bs_grid_buf), b_block_buf, b_block_slice_copy_step, c_thread_buf, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp index 96e13ac55c..a6fa04a824 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp @@ -26,7 +26,7 @@ __global__ void kernel_softmax(const GridDesc_M_K in_grid_desc_m_k, AccDataType alpha, const InDataType* const __restrict__ p_in_value_global, AccDataType beta, - OutDataType* const __restrict__ p_out_value_global) + OutDataType* p_out_value_global) { GridwiseReduction::Run(in_grid_desc_m_k, out_grid_desc_m_k, @@ -91,7 +91,7 @@ struct GridwiseSoftmax_mk_to_mk AccDataType alpha, const InDataType* const __restrict__ p_in_value_global, AccDataType beta, - OutDataType* const __restrict__ p_out_value_global) + OutDataType* p_out_value_global) { if constexpr(SweepOnce) { diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index f3596df9bd..438e44f5f1 100644 --- 
a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -91,6 +91,7 @@ #include "ck_tile/core/utility/ignore.hpp" #include "ck_tile/core/utility/literals.hpp" #include "ck_tile/core/utility/magic_div.hpp" +#include "ck_tile/core/utility/mixed_prec_compute_type.hpp" #include "ck_tile/core/utility/persistent_async_input_scheduler.hpp" #include "ck_tile/core/utility/philox_rand.hpp" #include "ck_tile/core/utility/print.hpp" diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 7af2f558ad..8f9dd30bda 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1544,7 +1544,8 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)), "wrong! not implemented"); using rtn_type = thread_buffer; diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 9f9770df1b..42886b8ced 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -1414,7 +1414,7 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) || (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16))), + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32))), "wrong! 
not implemented"); using rtn_type = thread_buffer; diff --git a/include/ck_tile/core/numeric/pk_fp4.hpp b/include/ck_tile/core/numeric/pk_fp4.hpp index cc23ce71a8..d74db6b336 100644 --- a/include/ck_tile/core/numeric/pk_fp4.hpp +++ b/include/ck_tile/core/numeric/pk_fp4.hpp @@ -6,6 +6,7 @@ #include #include "ck_tile/core/config.hpp" #include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/mxfp_convert.hpp" #if defined(__gfx950__) @@ -23,6 +24,12 @@ using fp32x2_t = float __attribute__((ext_vector_type(2))); using fp16x2_t = _Float16 __attribute__((ext_vector_type(2))); using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2))); +#if CK_TILE_USE_CUSTOM_DATA_TYPE +using fp8x2_t = fp8_raw_t __attribute__((ext_vector_type(2))); +#else +using fp8x2_t = fp8_t __attribute__((ext_vector_type(2))); +#endif + // Helpers: constexpr-safe access to elements of ext_vector_type(2) // Some compilers don't allow operator[] in constant expressions for vector types. // We use bit_cast to a trivially copyable representation to extract lanes. 
@@ -98,6 +105,8 @@ struct pk_float4_e2m1_t CK_TILE_HOST_DEVICE constexpr fp16x2_t to_fp16x2(float scale = 1.f) const; CK_TILE_HOST_DEVICE constexpr bf16_t to_bf16(float scale = 1.f) const; CK_TILE_HOST_DEVICE constexpr bf16x2_t to_bf16x2(float scale = 1.f) const; + CK_TILE_HOST_DEVICE constexpr fp8_t to_fp8(float scale = 1.f) const; + CK_TILE_HOST_DEVICE constexpr fp8x2_t to_fp8x2(float scale = 1.f) const; CK_TILE_HOST_DEVICE constexpr operator float() const { return to_float(); } CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const { return to_fp32x2(); } @@ -105,6 +114,8 @@ struct pk_float4_e2m1_t CK_TILE_HOST_DEVICE constexpr operator fp16x2_t() const { return to_fp16x2(); } CK_TILE_HOST_DEVICE constexpr operator bf16_t() const { return to_bf16(); } CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const { return to_bf16x2(); } + CK_TILE_HOST_DEVICE constexpr operator fp8_t() const { return to_fp8(); } + CK_TILE_HOST_DEVICE constexpr operator fp8x2_t() const { return to_fp8x2(); } template CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t unpack(number) const @@ -145,6 +156,49 @@ struct pk_float4_e2m1_t bit_cast(static_cast(0xC400)), // -4 bit_cast(static_cast(0xC600)) // -6 }; + +#if CK_TILE_USE_OCP_FP8 + // FP8 E4M3 (OCP) representation + static constexpr fp8_t e2m1_to_fp8_table[16] = { + fp8_t(static_cast(0x00)), // 0 + fp8_t(static_cast(0x30)), // 0.5 + fp8_t(static_cast(0x38)), // 1 + fp8_t(static_cast(0x3C)), // 1.5 + fp8_t(static_cast(0x40)), // 2 + fp8_t(static_cast(0x44)), // 3 + fp8_t(static_cast(0x48)), // 4 + fp8_t(static_cast(0x4C)), // 6 + fp8_t(static_cast(0x00)), // -0 + fp8_t(static_cast(0xB0)), // -0.5 + fp8_t(static_cast(0xB8)), // -1 + fp8_t(static_cast(0xBC)), // -1.5 + fp8_t(static_cast(0xC0)), // -2 + fp8_t(static_cast(0xC4)), // -3 + fp8_t(static_cast(0xC8)), // -4 + fp8_t(static_cast(0xCC)) // -6 + }; +#else // CK_TILE_USE_FNUZ_FP8 + // FP8 E4M3 FNUZ + static constexpr fp8_t e2m1_to_fp8_table[16] = { + fp8_t(static_cast(0x00)), // 0 +
fp8_t(static_cast(0x38)), // 0.5 + fp8_t(static_cast(0x40)), // 1 + fp8_t(static_cast(0x44)), // 1.5 + fp8_t(static_cast(0x48)), // 2 + fp8_t(static_cast(0x4C)), // 3 + fp8_t(static_cast(0x50)), // 4 + fp8_t(static_cast(0x54)), // 6 + fp8_t(static_cast(0x00)), // -0 + fp8_t(static_cast(0xB8)), // -0.5 + fp8_t(static_cast(0xC0)), // -1 + fp8_t(static_cast(0xC4)), // -1.5 + fp8_t(static_cast(0xC8)), // -2 + fp8_t(static_cast(0xCC)), // -3 + fp8_t(static_cast(0xD0)), // -4 + fp8_t(static_cast(0xD4)) // -6 + }; +#endif + #endif }; @@ -408,6 +462,27 @@ CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const type_convert(convert_to_float(_unpack(number<1>{}), scale))}; #endif } +CK_TILE_HOST_DEVICE constexpr fp8_t pk_fp4_t::to_fp8(float scale) const +{ + // NOTE: No specialized fp4 to fp8 instructions are available. Unsure whether fp4 to fp16 to fp8 + // would be better than the naive implementation below + // #if CK_TILE_FP4_CVT_DEVICE + // return impl::_from_f4(data, scale); + // #else + return fp8_t{type_convert(convert_to_float(_unpack(number<0>{}), scale))}; + // #endif +} +CK_TILE_HOST_DEVICE constexpr fp8x2_t pk_fp4_t::to_fp8x2(float scale) const +{ + // NOTE: No specialized fp4 to fp8 instructions are available.
Unsure whether fp4 to fp16 to fp8 + // would be better than the naive implementation below + // #if CK_TILE_FP4_CVT_DEVICE + // return impl::_from_f4(data, scale); + // #else + return fp8x2_t{type_convert(convert_to_float(_unpack(number<0>{}), scale)), + type_convert(convert_to_float(_unpack(number<1>{}), scale))}; + // #endif +} #else CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const { @@ -415,7 +490,8 @@ CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const } CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const { - return fp32x2_t{e2m1_to_fp32_table[_unpack(number<0>{})] * scale, e2m1_to_fp32_table[_unpack(number<1>{}] * scale}; + return fp32x2_t{e2m1_to_fp32_table[_unpack(number<0>{})] * scale, + e2m1_to_fp32_table[_unpack(number<1>{})] * scale}; } CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const { @@ -428,6 +504,16 @@ CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const type_convert(type_convert(e2m1_to_fp16_table[_unpack(number<1>{})]) * scale)}; } +CK_TILE_HOST_DEVICE constexpr fp8_t pk_fp4_t::to_fp8(float scale) const +{ + return type_convert(e2m1_to_fp8_table[_unpack(number<0>{})]) * scale; +} +CK_TILE_HOST_DEVICE constexpr fp8x2_t pk_fp4_t::to_fp8x2(float scale) const +{ + return fp8x2_t{ + type_convert(type_convert(e2m1_to_fp8_table[_unpack(number<0>{})]) * scale), + type_convert(type_convert(e2m1_to_fp8_table[_unpack(number<1>{})]) * scale)}; +} #endif } // namespace ck_tile diff --git a/include/ck_tile/core/numeric/pk_int4.hpp b/include/ck_tile/core/numeric/pk_int4.hpp index d5df4d1917..9eb62a6ec4 100644 --- a/include/ck_tile/core/numeric/pk_int4.hpp +++ b/include/ck_tile/core/numeric/pk_int4.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/numeric/numeric.hpp" +#include "ck_tile/core/numeric/pk_fp4.hpp" #include "ck_tile/core/utility/bit_cast.hpp" 
#include "ck_tile/core/utility/random.hpp" #include @@ -23,6 +24,11 @@ struct pk_int4_t type data; CK_TILE_HOST_DEVICE constexpr pk_int4_t() : data{type{}} {} CK_TILE_HOST_DEVICE constexpr pk_int4_t(type init) : data{init} {} + + // NOTE: added for interface compatibility with pk_fp4_t + // Other data types could be added for greater similarity + CK_TILE_HOST_DEVICE constexpr fp32x2_t to_fp32x2() const; + CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const { return to_fp32x2(); } }; // limits @@ -186,4 +192,9 @@ CK_TILE_HOST_DEVICE int8x2_t pk_int4_t_to_int8x2_t(const pk_int4_t& x) return res; } +CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_int4_t::to_fp32x2() const +{ + return pk_int4_t_to_fp32x2_t(*this); +} + } // namespace ck_tile diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index 90ddc2a56e..def054f415 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -11,6 +11,7 @@ #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/bfloat16.hpp" #include "ck_tile/core/numeric/pk_int4.hpp" +#include "ck_tile/core/numeric/pk_fp4.hpp" #include "ck_tile/core/numeric/e8m0.hpp" #include "ck_tile/core/utility/type_traits.hpp" diff --git a/include/ck_tile/core/utility/mixed_prec_compute_type.hpp b/include/ck_tile/core/utility/mixed_prec_compute_type.hpp new file mode 100644 index 0000000000..021763c108 --- /dev/null +++ b/include/ck_tile/core/utility/mixed_prec_compute_type.hpp @@ -0,0 +1,54 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/numeric.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +#include + +namespace ck_tile { + +namespace detail { + +// Helper method to automatically determine compute type +// Selects the largest type of the two. 
If both of them are packed data types, defaults to fp8. +template +struct auto_compute_type +{ + using LargestInputType = largest_type_t; + + // Sanity check: there are no packed types larger than 1 byte yet, but if we add them + // this logic should change + static_assert(!is_packed_type_v || sizeof(LargestInputType) == sizeof(fp8_t)); + + using type = std::conditional_t, fp8_t, LargestInputType>; +}; + +// Helper method to determine compute type, defaulting an explicitly passed-in compute type +template +struct mixed_prec_compute_type +{ + using type = std::conditional_t, + typename auto_compute_type::type, + ComputeDataType>; +}; + +} // namespace detail + +template +using mixed_prec_compute_type_t = + typename detail::mixed_prec_compute_type::type; + +// Helper method to determine compute type, defaulting to input data type +// If "ThisDataType" is packed (4-bit), will default to "OtherDataType". If both are packed, +// ComputeDataType is used. +template +using mixed_prec_compute_type_from_input_t = std::conditional_t< + is_packed_type_v, + std::conditional_t, ComputeDataType, OtherDataType>, + ThisDataType>; + +} // namespace ck_tile diff --git a/include/ck_tile/core/utility/type_traits.hpp b/include/ck_tile/core/utility/type_traits.hpp index f07e25e19c..c11d180839 100644 --- a/include/ck_tile/core/utility/type_traits.hpp +++ b/include/ck_tile/core/utility/type_traits.hpp @@ -4,6 +4,8 @@ #pragma once #include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/numeric.hpp" + #include #include #include @@ -187,4 +189,19 @@ template using tuple_element_or_default_t = typename tuple_element_or_default::type; +// Helper struct to determine if a type is packed (more than 1 element per byte) +template +struct is_packed_type +{ + static constexpr bool value = numeric_traits::PackedSize > 1; +}; + +template +static constexpr bool is_packed_type_v = is_packed_type::value; + +// Helper definition to take the largest sizes type +template +using largest_type_t = + 
std::conditional_t= sizeof(BDataType), ADataType, BDataType>; + } // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index 9ad5af8264..7830150b63 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -137,47 +137,55 @@ CK_TILE_HOST void reference_gemm_abquant(const HostTensor& a_m_k, const BElementOp& b_element_op = {}, const ACCElementOp& acc_element_op = {}) { - const std::size_t M = a_m_k.get_length(0); - const std::size_t N = b_k_n.get_length(1); - const std::size_t K = a_m_k.get_length(1); + constexpr auto A_TENSOR_M_DIM = 0; + constexpr auto A_TENSOR_K_DIM = 1; + constexpr auto B_TENSOR_K_DIM = 0; + constexpr auto B_TENSOR_N_DIM = 1; + + const std::size_t M = a_m_k.get_length(A_TENSOR_M_DIM); + const std::size_t N = b_k_n.get_length(B_TENSOR_N_DIM); + const std::size_t K = a_m_k.get_length(A_TENSOR_K_DIM); + + // Pre-convert A/B tensors to AccData type + // This prevents doing slow reconversions for each row/column + HostTensor a_acc(a_m_k.mDesc); + HostTensor b_acc(b_k_n.mDesc); + + a_acc.ForEach([&](auto& self, auto index) { + if constexpr(std::is_same_v || std::is_same_v) + { + const ADataType pk_val = a_element_op(a_m_k(index)); + const fp32x2_t fp32_val = pk_val.to_fp32x2(); + self(index) = (index[A_TENSOR_K_DIM] & 1) ? fp32_val.hi : fp32_val.lo; + } + else + { + self(index) = ck_tile::type_convert(a_element_op(a_m_k(index))); + } + }); + + b_acc.ForEach([&](auto& self, auto index) { + if constexpr(std::is_same_v || std::is_same_v) + { + const BDataType pk_val = b_element_op(b_k_n(index)); + const fp32x2_t fp32_val = pk_val.to_fp32x2(); + self(index) = (index[B_TENSOR_K_DIM] & 1) ? 
fp32_val.hi : fp32_val.lo; + } + else if constexpr(std::is_same_v) + { + self(index) = fp8_to_float_raw(b_element_op(b_k_n(index))); + } + else + { + self(index) = ck_tile::type_convert(b_element_op(b_k_n(index))); + } + }); auto f_mn = [&](auto m, auto n) { AccDataType v_acc = 0; constexpr std::size_t kGroupK = BQuantGroupSize::kK; - // ---- A loader: dequant A(m,k) into AccDataType ---- - auto load_a = [&](std::size_t k) -> AccDataType { - if constexpr(std::is_same_v) - { - const pk_int4_t pk_val = a_element_op(a_m_k(m, k)); - const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val); - return (k & 1) ? fp32_val.hi : fp32_val.lo; - } - else - { - return ck_tile::type_convert(a_element_op(a_m_k(m, k))); - } - }; - - // ---- B loader: dequant B(k,n) into AccDataType ---- - auto load_b = [&](std::size_t k) -> AccDataType { - if constexpr(std::is_same_v) - { - const pk_int4_t pk_val = b_element_op(b_k_n(k, n)); - const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val); - return (k & 1) ? fp32_val.hi : fp32_val.lo; - } - else if constexpr(std::is_same_v) - { - return fp8_to_float_raw(b_element_op(b_k_n(k, n))); - } - else - { - return ck_tile::type_convert(b_element_op(b_k_n(k, n))); - } - }; - // ---- a scale loader for a given K-group index ---- auto load_scale_a = [&](ck_tile::index_t k_group) -> float { const ck_tile::index_t outer_dim = m / AQuantGroupSize::kM; @@ -224,8 +232,8 @@ CK_TILE_HOST void reference_gemm_abquant(const HostTensor& a_m_k, // unscaled accumulation within this K-group for(std::size_t k = k_begin; k < k_end; ++k) { - const AccDataType v_a = load_a(k); - const AccDataType v_b = load_b(k); + const AccDataType v_a = a_acc(m, k); + const AccDataType v_b = b_acc(k, n); v_block_acc += v_a * v_b; } diff --git a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp b/include/ck_tile/ops/common/load_interleaved_pk_type.hpp index 10c2a1e4df..3f1a3b8f1c 100644 --- a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp +++ 
b/include/ck_tile/ops/common/load_interleaved_pk_type.hpp @@ -4,11 +4,12 @@ #pragma once #include "ck_tile/core/config.hpp" +#include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" namespace ck_tile { -template +template struct InterleavedPKTypeLoader { template @@ -21,10 +22,15 @@ struct InterleavedPKTypeLoader constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; const auto in_dstr_tensors = load_tile(warp_window); - using DstVectorType = DstDataType __attribute__((ext_vector_type(UnaryOpSize))); + // NOTE: we rely on types packing neatly here + using RawSrcType = typename SrcDataType::type; + constexpr auto PackedSize = numeric_traits::PackedSize; + + using SrcVectorType = ext_vector_t; + using DstVectorType = ext_vector_t; static_for<0, thread_buffer_size, 1>{}([&](auto i) { elementwise_op(warp_tile.get_thread_buffer().template get_as()(i), - in_dstr_tensors.get_thread_buffer().template get_as()[i]); + in_dstr_tensors.get_thread_buffer().template get_as()[i]); }); } }; @@ -37,10 +43,11 @@ template CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src) { - if constexpr(std::is_same_v) + if constexpr(is_packed_type_v) { - static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t"); - InterleavedPKTypeLoader::load_interleaved_pk_type(dst, src); + static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t or pk_fp4_t"); + InterleavedPKTypeLoader::load_interleaved_pk_type( + dst, src); } else if constexpr(LoadTranspose) { diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index ca9af0a7a8..3f58eceb33 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -397,6 +397,29 @@ struct PassThroughPack8 y.hi = i4_to_bf8x4(bit_cast(x) >> 8); 
#endif } + + CK_TILE_HOST_DEVICE constexpr void operator()(fp8x8_t& y, const pk_fp4x4_t& x) const + { + pk_fp4_t f0 = pk_fp4_t{x[0]}; + pk_fp4_t f1 = pk_fp4_t{x[1]}; + pk_fp4_t f2 = pk_fp4_t{x[2]}; + pk_fp4_t f3 = pk_fp4_t{x[3]}; + + fp8x2_t x0 = f0.to_fp8x2(); + fp8x2_t x1 = f1.to_fp8x2(); + fp8x2_t x2 = f2.to_fp8x2(); + fp8x2_t x3 = f3.to_fp8x2(); + + y[0] = x0[0]; + y[1] = x0[1]; + y[2] = x1[0]; + y[3] = x1[1]; + y[4] = x2[0]; + y[5] = x2[1]; + y[6] = x3[0]; + y[7] = x3[1]; + } + constexpr const static bool is_pack8_invocable = true; }; diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp index 1784436f87..0044b412ec 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/core/numeric/numeric.hpp" #include "ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" @@ -255,17 +256,26 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy { using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; - using BTypeToUse = - std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>; + + // Determine compute types to use + // This logic defaults to A/B DataType, but if one of them is packed falls back to the other + // If both are packed, it falls back to the explicitly defined ComputeDataType in the + // problem It might be a good idea to use ComputeDataType anyway, but that would break how + // this behaviour used to work + using ATypeToUse = mixed_prec_compute_type_from_input_t; + using BTypeToUse = mixed_prec_compute_type_from_input_t; + constexpr index_t WaveSize = get_warp_size(); constexpr index_t KLane = 
WarpTile::at(I2) * WarpTile::at(I0) / WaveSize; using BDataType = typename Problem::BDataType; constexpr index_t KLaneBytes = KLane / numeric_traits::PackedSize * sizeof(BDataType); constexpr auto NumAccess = static_cast(max(1, KLaneBytes / 16)); - using WarpGemm = WarpGemmDispatcher f32 static_assert( (std::is_same_v || std::is_same_v || - std::is_same_v) && + std::is_same_v || + std::is_same_v) && (std::is_same_v || std::is_same_v || - std::is_same_v) && + std::is_same_v || + std::is_same_v) && (std::is_same_v || std::is_same_v || std::is_same_v) && (std::is_same_v || std::is_same_v || @@ -189,7 +191,8 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg typename BFlatBlockTensor, typename AQBlockTensor, typename BQBlockTensor, - typename ABlockWindow> + typename ABlockWindow, + index_t UnaryOpSize = 8> CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, ABlockTensor& a_warp_tensor, BFlatBlockTensor& b_warp_tensor, @@ -249,8 +252,10 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg { constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows(number{})(number{})); + + load_int4_tile( + a_warp_tensor(number{}), + a_warp_windows(number{})(number{})); } // barrier // Could be deleted diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp index 2d28b813bf..d79bd31489 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp @@ -108,9 +108,11 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase // 4. 
i4, bf8, (fp8/fp32) -> f32 static_assert( (std::is_same_v || std::is_same_v || - std::is_same_v) && + std::is_same_v || + std::is_same_v) && (std::is_same_v || std::is_same_v || - std::is_same_v) && + std::is_same_v || + std::is_same_v) && (std::is_same_v || std::is_same_v || std::is_same_v) && (std::is_same_v || std::is_same_v || @@ -135,12 +137,9 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase using ComputeDataType = remove_cvref_t; using CDataType = remove_cvref_t; - // BDataType gets converted from PkInt4 during loading - using OverrideBDataType = std::conditional_t< - std::is_same_v && - std::is_same_v, - ADataType, - BDataType>; + // A/B DataType get converted from PkInt4/PkFp4 during loading + using OverrideADataType = ComputeDataType; + using OverrideBDataType = ComputeDataType; using Base = BlockGemmQuantBase; using WarpGemm = remove_cvref_t; @@ -268,9 +267,9 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase bool_constant = {}, bool_constant = {}) { - load_int4_tile( + // If A/B datatype were pkint4/pkfp4 it would be converted prior to storing in LDS + load_int4_tile( a_warp_tile_, a_block_window); - // If B datatype were pkint4 it would be converted prior to storing in LDS load_int4_tile( b_warp_tile_, b_block_window); } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_policy.hpp index 095275e60b..b636bfa4b7 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_policy.hpp @@ -10,9 +10,10 @@ namespace ck_tile { -struct GemmABQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgCrPolicy +struct GemmABQuantPipelineAgBgCrDefaultPolicy + : public UniversalGemmBasePolicy { - using Base = UniversalGemmPipelineAgBgCrPolicy; + using Base = UniversalGemmBasePolicy; using 
Base::I0; using Base::I1; using Base::I2; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp index 5902dd0c4f..cfd12313e8 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp @@ -34,9 +34,6 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3; using AQuantGroupSize = remove_cvref_t; using BQuantGroupSize = remove_cvref_t; - // BDataType gets converted from PkInt4 during loading - using OverrideBDataType = - std::conditional_t, ADataType, BDataType>; static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!"); static_assert(AQuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!"); @@ -67,6 +64,10 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3())>; + // A/B DataType gets converted from PkInt4/PkFp4 during loading + using OverrideADataType = BlockGemm::OverrideADataType; + using OverrideBDataType = BlockGemm::OverrideBDataType; + static constexpr index_t BlockSize = Problem::kBlockSize; static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t NPerBlock = BlockGemmShape::kN; @@ -281,9 +282,9 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(p_smem); + Base::template GetABLdsTensorViews(p_smem); constexpr auto a_lds_load_tile_distr = make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); @@ -303,9 +304,9 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(ABlockTileDistr{})); + decltype(make_static_distributed_tensor(ABlockTileDistr{})); using BBlockTile = - decltype(make_static_distributed_tensor(BBlockTileDistr{})); + decltype(make_static_distributed_tensor(BBlockTileDistr{})); using AQBlockTile = 
decltype(make_static_distributed_tensor(AQBlockTileDistr{})); using BQBlockTile = @@ -361,7 +362,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); @@ -373,7 +374,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); @@ -409,7 +410,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + // Note: ABDataType PkInt4/PkFp4 gets converted during loading earlier + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); @@ -420,7 +422,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); @@ -493,7 +495,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + // Note: ADataType gets converted during loading from PkInt4/PkFp4 + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); @@ -543,9 +546,9 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}.template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](const OverrideADataType& a) { return a; }, b_dram_block_window_tmp, - [](const BDataType& 
b) { return b; }, + [](const OverrideBDataType& b) { return b; }, aq_dram_block_window_tmp, bq_dram_block_window_tmp, m, @@ -593,9 +596,10 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}.template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + // Note: ADataType PkInt4/PkFp4 gets converted during loading + [](const OverrideADataType& a) { return a; }, b_dram_block_window_tmp, - // Note: BDataType PkInt4 gets converted during loading + // Note: BDataType PkInt4/PkFp4 gets converted during loading [](const OverrideBDataType& b) { return b; }, aq_dram_block_window_tmp, bq_dram_block_window_tmp, diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp index 1edbe9ac16..9b02585e69 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp @@ -21,23 +21,27 @@ template -struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase +struct GemmQuantPipelineProblemBase + : public GemmPipelineProblemBase< + ADataType_, + BDataType_, + CDataType_, + BlockGemmShape_, + Traits_, + mixed_prec_compute_type_t> { - using Base = GemmPipelineProblemBase; + + using Base = GemmPipelineProblemBase< + ADataType_, + BDataType_, + CDataType_, + BlockGemmShape_, + Traits_, + mixed_prec_compute_type_t>; using Traits = typename Base::Traits; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp index ae2a601f8a..f136b86314 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp @@ -95,11 +95,6 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : 
public UniversalWeightPreshufflePipel using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; - using BTypeToUse = - std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>; - constexpr index_t WaveSize = get_warp_size(); constexpr index_t KLane = WarpTile::at(I2) * WarpTile::at(I0) / WaveSize; using BDataType = typename Problem::BDataType; @@ -107,8 +102,8 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel KLane / numeric_traits::PackedSize * sizeof(BDataType); constexpr auto NumAccess = static_cast(max(1, KLaneBytes / 16)); - using WarpGemm = WarpGemmDispatcher #include "ck_tile/core.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp" @@ -239,36 +240,42 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe make_tensor_view(p_a_lds_pong, a_lds_block_desc); // A DRAM tile window for load + auto a_dram_tile_distribution = + PipelinePolicy::template MakeADramTileDistribution(); + auto a_copy_dram_window = make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), a_dram_block_window_tmp.get_window_origin(), - PipelinePolicy::template MakeADramTileDistribution()); + a_dram_tile_distribution); auto a_copy_lds_window_ping = make_tile_window(a_lds_block_ping, make_tuple(number{}, number{}), {0, 0}, - PipelinePolicy::template MakeADramTileDistribution()); + a_dram_tile_distribution); auto a_copy_lds_window_pong = make_tile_window(a_lds_block_pong, make_tuple(number{}, number{}), {0, 0}, - PipelinePolicy::template MakeADramTileDistribution()); + a_dram_tile_distribution); // ping-pong window for A LDS + auto a_warp_tile_distribution = + 
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}); + auto a_warp_window_ping_tmp = make_tile_window(a_lds_block_ping, make_tuple(number{}, number{}), {iMWarp * WG::kM, 0}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + a_warp_tile_distribution); auto a_warp_window_pong_tmp = make_tile_window(a_lds_block_pong, make_tuple(number{}, number{}), {iMWarp * WG::kM, 0}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + a_warp_tile_distribution); statically_indexed_array< statically_indexed_array, @@ -314,7 +321,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe b_flat_distribution); using BTypeToUse = - std::conditional_t, ADataType, BDataType>; + mixed_prec_compute_type_from_input_t; using BTileType = decltype(make_static_distributed_tensor(b_flat_distribution)); // pingpong buffer for B @@ -354,7 +361,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_int4_tile( b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -393,15 +400,17 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe block_sync_lds(); // preload A00,A10 from lds - statically_indexed_array{})(number<0>{}))), - m_preload> - a_warp_tensor; + using ATypeToUse = + mixed_prec_compute_type_from_input_t; + using ATileType = + decltype(make_static_distributed_tensor(a_warp_tile_distribution)); + statically_indexed_array a_warp_tensor; static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_ping(number{})(number{})); + load_int4_tile( + a_warp_tensor(loadIter), a_warp_windows_ping(number{})(number{})); }); __builtin_amdgcn_sched_barrier(0); @@ -434,7 +443,7 @@ struct 
WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_int4_tile( b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -450,8 +459,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_pong(number{})(number{})); + load_int4_tile( + a_warp_tensor(loadIter), a_warp_windows_pong(number{})(number{})); }); // Next K @@ -463,7 +472,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_int4_tile( b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -495,8 +504,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_ping(number{})(number{})); + load_int4_tile( + a_warp_tensor(loadIter), a_warp_windows_ping(number{})(number{})); }); iCounter--; HotLoopScheduler(); @@ -513,7 +522,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_int4_tile( b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -535,8 +544,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; 
constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_pong(number{})(number{})); + load_int4_tile( + a_warp_tensor(loadIter), a_warp_windows_pong(number{})(number{})); }); // GEMM loopK diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp index de27b15952..f94d220b94 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp @@ -181,12 +181,10 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass if constexpr(std::is_same_v) { - const auto tmp0 = - float_to_bf16(acc[idx] * inv_rms_[i_idx]); - const auto tmp1 = float_to_bf16( - type_convert(tmp0) * gamma_); - const auto rmsn_ = type_convert(tmp1); - rmsn(idx) = rmsn_; + const auto tmp = acc[idx] * inv_rms_[i_idx]; + const auto tmp_bf16 = float_to_bf16(tmp); + const auto rmsn_ = type_convert(tmp_bf16) * gamma_; + rmsn(idx) = rmsn_; } else { diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index 41fc8b740e..d5989e7a39 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -335,11 +335,23 @@ FOREACH(subdir_path ${dir_list}) endif() endif() + # Build the required pattern based on library settings + set(required_pattern "") + set(pattern_parts "") if(MIOPEN_REQ_LIBS_ONLY) message(STATUS "Removing all sources that are not required for MIOpen") - if(NOT "${cmake_instance}" MATCHES "conv") - set(add_inst 0) - endif() + list(APPEND pattern_parts "conv") + endif() + if(HIPTENSOR_REQ_LIBS_ONLY) + message(STATUS "Removing all sources that are not required for HipTensor") + list(APPEND pattern_parts "contract" "reduce" "element") + endif() + if(pattern_parts) + 
string(JOIN "|" required_pattern ${pattern_parts}) + endif() + # Apply the pattern if one was set + if(required_pattern AND NOT "${cmake_instance}" MATCHES "${required_pattern}") + set(add_inst 0) endif() if((add_inst EQUAL 1)) @@ -405,7 +417,7 @@ if(CK_DEVICE_OTHER_INSTANCES AND NOT MIOPEN_REQ_LIBS_ONLY) DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel ) endif() -if(CK_DEVICE_GEMM_INSTANCES AND NOT MIOPEN_REQ_LIBS_ONLY) +if(CK_DEVICE_GEMM_INSTANCES AND NOT MIOPEN_REQ_LIBS_ONLY AND NOT HIPTENSOR_REQ_LIBS_ONLY) add_library(device_gemm_operations ${CK_DEVICE_GEMM_INSTANCES}) add_library(composablekernels::device_gemm_operations ALIAS device_gemm_operations) target_compile_features(device_gemm_operations PUBLIC) @@ -426,7 +438,7 @@ if(CK_DEVICE_GEMM_INSTANCES AND NOT MIOPEN_REQ_LIBS_ONLY) DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel ) endif() -if(CK_DEVICE_CONV_INSTANCES) +if(CK_DEVICE_CONV_INSTANCES AND (NOT HIPTENSOR_REQ_LIBS_ONLY OR MIOPEN_REQ_LIBS_ONLY)) add_library(device_conv_operations ${CK_DEVICE_CONV_INSTANCES}) add_library(composablekernels::device_conv_operations ALIAS device_conv_operations) target_compile_features(device_conv_operations PUBLIC) @@ -451,7 +463,7 @@ if(CK_DEVICE_CONV_INSTANCES) DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel ) endif() -if(CK_DEVICE_MHA_INSTANCES AND NOT MIOPEN_REQ_LIBS_ONLY AND BUILD_MHA_LIB) +if(CK_DEVICE_MHA_INSTANCES AND NOT MIOPEN_REQ_LIBS_ONLY AND NOT HIPTENSOR_REQ_LIBS_ONLY AND BUILD_MHA_LIB) set(gpu_list ${INST_TARGETS}) if(gpu_list MATCHES "gfx94" OR gpu_list MATCHES "gfx90a" OR gpu_list MATCHES "gfx95") add_library(device_mha_operations ${CK_DEVICE_MHA_INSTANCES}) @@ -517,7 +529,7 @@ if(CK_DEVICE_REDUCTION_INSTANCES AND NOT MIOPEN_REQ_LIBS_ONLY) ) endif() -if(NOT MIOPEN_REQ_LIBS_ONLY) +if(NOT MIOPEN_REQ_LIBS_ONLY AND NOT HIPTENSOR_REQ_LIBS_ONLY) add_library(device_operations INTERFACE) target_link_libraries(device_operations INTERFACE device_contraction_operations diff --git 
a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp index 4cd4403436..0dd666b3d9 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp @@ -73,14 +73,17 @@ template using device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances = std::tuple< // clang-format off - //###################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| - //###################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| | | - //###################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | | | - //###################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | 
| | | | | | | | | | | | | | | | - DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + //###################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| 
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| + //###################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| | | + //###################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | | | + //###################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, 
BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 256, 64, 
8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp index 76a92a1971..3587c6700c 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -39,7 +39,7 @@ void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances( Multiply, PassThrough, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); add_device_operation_instances(instances, device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< ck::Tuple, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_common.hpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_common.hpp index 1607b240f6..7cb50cd954 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_common.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_common.hpp @@ -71,12 +71,15 @@ template using device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_comp_instances = std::tuple< // clang-format off - //###################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| 
ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| - //###################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| | | - //###################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | | | - //###################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + 
//###################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| + //###################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| | | + //###################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | | | + //###################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, 
BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp index 2a4aae98a5..731518257b 100644 --- 
a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -39,7 +39,7 @@ void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances( Multiply, Add, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); add_device_operation_instances(instances, device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< ck::Tuple, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp index 477d6811d2..0a67f2357e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -39,7 +39,7 @@ void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances Multiply, AddFastGelu, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); add_device_operation_instances(instances, device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< ck::Tuple, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp index 71c04b3485..c0b4cf7b9a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp +++ 
b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp @@ -36,7 +36,7 @@ void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_v1_instances ck::Tuple, AddFastGelu, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); } void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_v1_instances( @@ -58,7 +58,7 @@ void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_v1_instances( ck::Tuple, Add, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); } void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_v1_instances( @@ -80,7 +80,7 @@ void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_v1_instances( ck::Tuple<>, PassThrough, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); } void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_v1_instances( @@ -102,7 +102,7 @@ void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_v1_instances( ck::Tuple<>, FastGelu, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp index 33422fc6db..9176910cea 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -39,7 +39,7 @@ void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances( Multiply, FastGelu, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); add_device_operation_instances(instances, device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< diff --git 
a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp index 639bda6017..669eb4144a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -39,7 +39,7 @@ void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances( PassThrough, Multiply, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); add_device_operation_instances( instances, device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp index 7f8fea44c5..c6a812645b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -39,7 +39,7 @@ void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_insta PassThrough, MultiplyAdd, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); add_device_operation_instances(instances, device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< ck::Tuple, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp 
b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp index b2bf995507..2d7ffd120d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -39,7 +39,7 @@ void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_ PassThrough, MultiplyAddFastGelu, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); add_device_operation_instances(instances, device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< ck::Tuple, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp index d2adc36dc3..ab49d2f1c9 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -39,7 +39,7 @@ void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_insta PassThrough, MultiplyFastGelu, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); add_device_operation_instances( instances, device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances, diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp index 3a9f14e595..afc88150ed 100644 --- a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp 
@@ -364,26 +364,39 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, using AccDataType = std::conditional_t, int32_t, float>; - // Calculate number of accumulations accounting for split_k - const int num_accums = - static_cast(output.GetElementSize() / conv_param.K_ / split_k_value); - - // Additional tolerance for split_k accumulation if needed - int total_accums = num_accums; - if(split_k_value > 1) - { - total_accums = std::max(num_accums, static_cast(split_k_value)); - } - - // Perform GPU verification (max value computed internally on GPU) + const index_t num_accums = output.GetElementSize() / conv_param.K_; + const index_t num_accums_split_k = split_k_value; + // Get maximum accumulated value from reference const std::size_t tensor_size = weight_device_result.mDesc.GetElementSpaceSize(); + max_accumulated_value = + gpu_reduce_max(gpu_ref_wei_buf.GetDeviceBuffer(), tensor_size); + // Calculate thresholds + auto rtol = + ck::utils::get_relative_threshold( + num_accums / num_accums_split_k); + auto atol = + ck::utils::get_absolute_threshold( + max_accumulated_value / num_accums_split_k, + num_accums / num_accums_split_k); + // Calculate error due to split_k accumulation + auto rtol_split_k = + ck::utils::get_relative_threshold( + num_accums_split_k); + auto atol_split_k = + ck::utils::get_absolute_threshold( + max_accumulated_value, num_accums_split_k); + // Use higher threshold + rtol = std::max(rtol, rtol_split_k); + atol = std::max(atol, atol_split_k); + + // Perform GPU verification auto gpu_result = - ck::profiler::gpu_verify( - wei_device_buf.GetDeviceBuffer(), - gpu_ref_wei_buf.GetDeviceBuffer(), - total_accums, - tensor_size); + ck::profiler::gpu_verify(wei_device_buf.GetDeviceBuffer(), + gpu_ref_wei_buf.GetDeviceBuffer(), + rtol, + atol, + tensor_size); if(!gpu_result) { diff --git a/script/tools/ck-build b/script/tools/ck-build index 2c0bb24eda..a2a02387eb 100755 --- a/script/tools/ck-build +++ b/script/tools/ck-build @@ -2,7 +2,8 @@ # 
Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# CK Build - Build Composable Kernel targets in Docker +# CK Build - Build Composable Kernel targets +# Environment-agnostic: works natively on ROCm hosts or inside containers set -e set -o pipefail @@ -12,46 +13,51 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" source "${SCRIPT_DIR}/common.sh" # Initialize configuration -PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") -CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") +PROJECT_ROOT=$(find_project_root "${SCRIPT_DIR}" || get_project_root "${SCRIPT_DIR}") +BUILD_DIR=$(get_build_dir "${PROJECT_ROOT}") # Help message show_help() { cat << EOF -CK Build - Build Composable Kernel targets in Docker +CK Build - Build Composable Kernel targets Usage: ck-build [options] [target...] Options: -h, --help Show this help message - --name Specify container name - --reconfigure Reconfigure CMake before building -j Parallel jobs (passed to ninja) + -v, --verbose Verbose output + --build-dir Build directory (default: ./build) --clean Clean before building + --configure Auto-configure if build.ninja missing + --list List available targets Arguments: target Target(s) to build (default: all) Environment: - CK_CONTAINER_NAME - Override default container name - GPU_TARGET - Override GPU target detection (e.g., gfx950, gfx942) + CK_BUILD_DIR - Override build directory + CK_GPU_TARGET - Override GPU target for auto-configure Examples: ck-build # Build all targets ck-build test_amdgcn_mma # Build specific target ck-build test_amdgcn_mma test_gemm # Build multiple targets - ck-build --reconfigure # Reconfigure CMake and build all + ck-build --configure # Auto-configure and build all ck-build --clean test_amdgcn_mma # Clean and build target ck-build -j 8 test_amdgcn_mma # Build with 8 parallel jobs + ck-build --list # List available targets EOF } # Parse arguments targets=() -reconfigure=false -clean=false parallel_jobs="" 
+verbose=false +clean=false +auto_configure=false +list_targets=false while [[ $# -gt 0 ]]; do case $1 in @@ -59,21 +65,35 @@ while [[ $# -gt 0 ]]; do show_help exit 0 ;; - --name) - CONTAINER_NAME="$2" + -j) + require_arg "$1" "${2:-}" + parallel_jobs="$2" shift 2 ;; - --reconfigure) - reconfigure=true + -j*) + parallel_jobs="${1#-j}" shift ;; + -v|--verbose) + verbose=true + shift + ;; + --build-dir) + require_arg "$1" "${2:-}" + BUILD_DIR="$2" + shift 2 + ;; --clean) clean=true shift ;; - -j) - parallel_jobs="-j $2" - shift 2 + --configure) + auto_configure=true + shift + ;; + --list) + list_targets=true + shift ;; *) targets+=("$1") @@ -82,62 +102,62 @@ while [[ $# -gt 0 ]]; do esac done -# Ensure container is running -if ! container_is_running "${CONTAINER_NAME}"; then - echo "Container '${CONTAINER_NAME}' not running. Starting..." - "${SCRIPT_DIR}/ck-start" "${CONTAINER_NAME}" +# Handle --list +if [ "$list_targets" = true ]; then + if ! is_build_configured "${BUILD_DIR}"; then + error "Build not configured. Run 'ck-configure' first or use --configure" + exit 1 + fi + info "Available targets:" + cd "${BUILD_DIR}" + ninja -t targets 2>/dev/null | grep -E '^[a-zA-Z_][a-zA-Z0-9_-]*:' | cut -d: -f1 | sort | head -100 echo "" + echo "(Showing first 100 targets. Use 'ninja -t targets' for full list)" + exit 0 fi -# Configure CMake if needed or requested -if [ "$reconfigure" = true ] || ! docker exec "${CONTAINER_NAME}" test -f /workspace/build/build.ninja 2>/dev/null; then - echo "Detecting GPU target..." - GPU_TARGET_DETECTED=$(detect_gpu_target "${CONTAINER_NAME}") - - if [ "$reconfigure" = true ]; then - echo "Reconfiguring CMake from scratch for GPU target: ${GPU_TARGET_DETECTED}" +# Auto-configure if needed +if ! is_build_configured "${BUILD_DIR}"; then + if [ "$auto_configure" = true ]; then + info "Build not configured. Running ck-configure..." 
+ "${SCRIPT_DIR}/ck-configure" --build-dir "${BUILD_DIR}" + echo "" else - echo "Configuring build with CMake for GPU target: ${GPU_TARGET_DETECTED}" + error "Build not configured. Run 'ck-configure' first or use --configure" + exit 1 fi - - docker exec "${CONTAINER_NAME}" bash -c " - cd /workspace || exit 1 - rm -rf /workspace/build - mkdir /workspace/build - cd /workspace/build || exit 1 - cmake .. -GNinja \ - -DGPU_TARGETS=${GPU_TARGET_DETECTED} \ - -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \ - -DBUILD_TESTING=ON 2>&1 | tail -30 - " - echo "" fi # Clean if requested if [ "$clean" = true ]; then - echo "Cleaning build directory..." - docker exec "${CONTAINER_NAME}" bash -c " - cd /workspace/build || exit 1 - ninja clean - " + info "Cleaning build directory..." + cd "${BUILD_DIR}" + ninja clean echo "" fi -# Build targets -if [ ${#targets[@]} -eq 0 ]; then - echo "Building all configured targets..." - docker exec "${CONTAINER_NAME}" bash -c " - cd /workspace/build || exit 1 - ninja ${parallel_jobs} 2>&1 - " -else - echo "Building targets: ${targets[*]}" - docker exec "${CONTAINER_NAME}" bash -c " - cd /workspace/build || exit 1 - ninja ${parallel_jobs} ${targets[*]} 2>&1 - " +# Build ninja command +ninja_cmd=(ninja -C "${BUILD_DIR}") + +if [ -n "$parallel_jobs" ]; then + ninja_cmd+=("-j" "$parallel_jobs") fi +if [ "$verbose" = true ]; then + ninja_cmd+=(-v) +fi + +# Add targets +ninja_cmd+=("${targets[@]}") + +# Build targets +if [ ${#targets[@]} -eq 0 ]; then + info "Building all configured targets..." +else + info "Building targets: ${targets[*]}" +fi + +"${ninja_cmd[@]}" + echo "" -echo "Build complete ✓" +info "Build complete" diff --git a/script/tools/ck-configure b/script/tools/ck-configure new file mode 100755 index 0000000000..ffe5a4daca --- /dev/null +++ b/script/tools/ck-configure @@ -0,0 +1,187 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT + +# CK Configure - Configure CMake build for Composable Kernel +# Environment-agnostic: works natively on ROCm hosts or inside containers + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(find_project_root "${SCRIPT_DIR}" || get_project_root "${SCRIPT_DIR}") +BUILD_DIR=$(get_build_dir "${PROJECT_ROOT}") + +# Help message +show_help() { + cat << EOF +CK Configure - Configure CMake build for Composable Kernel + +Usage: ck-configure [options] + +Options: + -h, --help Show this help message + --preset Use CMake preset (dev, dev-gfx908, dev-gfx90a, dev-gfx942, dev-gfx950) + --gpu Override GPU_TARGETS (auto-detected if not specified) + --dtypes Set DTYPES (e.g., fp16,fp32,bf16) + --build-type CMAKE_BUILD_TYPE (default: Release) + --build-dir Build directory (default: ./build) + --clean Remove existing build directory before configuring + --list-presets List available CMake presets + -D = Pass additional CMake variable + +Environment: + CK_GPU_TARGET - Override GPU target detection (e.g., gfx950, gfx942) + CK_BUILD_DIR - Override build directory + +Examples: + ck-configure # Auto-detect GPU and configure + ck-configure --preset dev-gfx950 # Use CMake preset + ck-configure --gpu gfx942 # Configure for specific GPU + ck-configure --clean --preset dev # Clean and reconfigure + ck-configure -D BUILD_DEV=ON # Pass CMake variable + +EOF +} + +# Parse arguments +preset="" +gpu_target="" +dtypes="" +build_type="Release" +clean=false +list_presets=false +cmake_vars=() + +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + show_help + exit 0 + ;; + --preset) + require_arg "$1" "${2:-}" + preset="$2" + shift 2 + ;; + --gpu) + require_arg "$1" "${2:-}" + gpu_target="$2" + shift 2 + ;; + --dtypes) + require_arg "$1" "${2:-}" + dtypes="$2" + shift 2 + ;; + --build-type) + require_arg 
"$1" "${2:-}" + build_type="$2" + shift 2 + ;; + --build-dir) + require_arg "$1" "${2:-}" + BUILD_DIR="$2" + shift 2 + ;; + --clean) + clean=true + shift + ;; + --list-presets) + list_presets=true + shift + ;; + -D) + require_arg "$1" "${2:-}" + cmake_vars+=("-D$2") + shift 2 + ;; + -D*) + cmake_vars+=("$1") + shift + ;; + *) + error "Unknown option: $1" + echo "" + show_help + exit 1 + ;; + esac +done + +# Handle --list-presets +if [ "$list_presets" = true ]; then + echo "Available CMake presets:" + presets=$(list_cmake_presets "${PROJECT_ROOT}" 2>/dev/null) + if [ -n "$presets" ]; then + echo "$presets" | sed 's/^/ /' + else + echo " (No CMakePresets.json found or jq not available)" + fi + exit 0 +fi + +# Clean build directory if requested +if [ "$clean" = true ]; then + if [ -d "${BUILD_DIR}" ]; then + info "Removing existing build directory: ${BUILD_DIR}" + rm -rf "${BUILD_DIR}" + fi +fi + +# Create build directory +mkdir -p "${BUILD_DIR}" + +# Change to project root for CMake +cd "${PROJECT_ROOT}" + +# Build CMake command +cmake_cmd=(cmake -S . 
-B "${BUILD_DIR}" -GNinja) + +# Use preset if specified +if [ -n "$preset" ]; then + cmake_cmd+=(--preset "${preset}") + info "Using CMake preset: ${preset}" +else + # Manual configuration + + # Detect GPU target if not specified + if [ -z "$gpu_target" ]; then + gpu_target=$(detect_gpu_native) + info "Auto-detected GPU target: ${gpu_target}" + else + info "Using specified GPU target: ${gpu_target}" + fi + + cmake_cmd+=(-DGPU_TARGETS="${gpu_target}") + cmake_cmd+=(-DCMAKE_BUILD_TYPE="${build_type}") + cmake_cmd+=(-DCMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++) + cmake_cmd+=(-DBUILD_TESTING=ON) + + # Add DTYPES if specified + if [ -n "$dtypes" ]; then + cmake_cmd+=(-DDTYPES="${dtypes}") + info "Using DTYPES: ${dtypes}" + fi +fi + +# Add any additional CMake variables +for var in "${cmake_vars[@]}"; do + cmake_cmd+=("$var") +done + +# Run CMake +info "Configuring build in: ${BUILD_DIR}" +echo "Running: ${cmake_cmd[*]}" +echo "" + +"${cmake_cmd[@]}" + +echo "" +info "Configuration complete. 
Build directory: ${BUILD_DIR}" +info "Next: run 'ck-build' to build targets" diff --git a/script/tools/ck-docker b/script/tools/ck-docker index 82bf770011..6c118561b7 100755 --- a/script/tools/ck-docker +++ b/script/tools/ck-docker @@ -22,25 +22,29 @@ CK Docker Tool - Build and test composable_kernel in Docker Usage: ck-docker [options] -Commands: - start [name] Start Docker container - build [target] [--reconfigure] Build target (optionally reconfigure CMake) - test [options] Run test - shell [name] Open shell in container - status [name] Check container status - stop [name] Stop and remove container +Container Management: + start [name] Start Docker container + stop [name] Stop and remove container + status [name] Check container status + shell [name] Open shell in container + +Build/Test (delegates to core tools inside container): + configure [opts] Run ck-configure in container + build [opts] Run ck-build in container + test [opts] Run ck-test in container + exec Run arbitrary command in container Examples: ck-docker start + ck-docker configure --preset dev-gfx950 ck-docker build test_amdgcn_mma - ck-docker build --reconfigure test_amdgcn_mma - ck-docker test test_amdgcn_mma --gtest_filter=*Fp16* + ck-docker test test_amdgcn_mma --filter '*Fp16*' ck-docker shell + ck-docker exec rocminfo Environment: CK_CONTAINER_NAME - Override default container name (default: ck__) CK_DOCKER_IMAGE - Override Docker image (default: rocm/composable_kernel:ck_ub24.04_rocm7.0.1) - GPU_TARGET - Override GPU target detection (e.g., gfx950, gfx942) EOF } @@ -77,126 +81,38 @@ cmd_start() { docker exec "${name}" bash -c "echo 'Working directory:' && pwd" } -# Build target -cmd_build() { - local target="" - local name="${CONTAINER_NAME}" - local reconfigure=false - - while [[ $# -gt 0 ]]; do - case $1 in - --name) - name="$2" - shift 2 - ;; - --reconfigure) - reconfigure=true - shift - ;; - *) - target="$1" - shift - ;; - esac - done - - # Check if container is running - if ! 
container_is_running "${name}"; then - echo "Container '${name}' not running. Starting..." - cmd_start "${name}" - fi - - # Reconfigure CMake if requested or if build.ninja doesn't exist - if [ "$reconfigure" = true ] || ! docker exec "${name}" test -f /workspace/build/build.ninja 2>/dev/null; then - echo "Detecting GPU target..." - local gpu_target=$(detect_gpu_target "${name}") - - if [ "$reconfigure" = true ]; then - echo "Reconfiguring CMake from scratch for GPU target: ${gpu_target}" - else - echo "Configuring build with CMake for GPU target: ${gpu_target}" - fi - - docker exec "${name}" bash -c " - cd /workspace || exit 1 - rm -rf /workspace/build - mkdir /workspace/build - cd /workspace/build || exit 1 - cmake .. -GNinja \ - -DGPU_TARGETS=${gpu_target} \ - -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \ - -DBUILD_TESTING=ON 2>&1 | tail -30 - " - fi - - if [ -z "$target" ]; then - echo "Building all configured targets..." - else - echo "Building target: ${target}" - fi - - docker exec "${name}" bash -c " - cd /workspace/build || exit 1 - ninja ${target} 2>&1 - " - - echo "Build complete" +# Configure (delegate to ck-configure in container) +cmd_configure() { + ensure_container_running "${CONTAINER_NAME}" "${SCRIPT_DIR}" + docker exec "${CONTAINER_NAME}" /workspace/script/tools/ck-configure "$@" } -# Run test +# Build (delegate to ck-build in container) +cmd_build() { + ensure_container_running "${CONTAINER_NAME}" "${SCRIPT_DIR}" + docker exec "${CONTAINER_NAME}" /workspace/script/tools/ck-build "$@" +} + +# Test (delegate to ck-test in container) cmd_test() { - local test_name="" - local name="${CONTAINER_NAME}" - local -a test_options=() + ensure_container_running "${CONTAINER_NAME}" "${SCRIPT_DIR}" + docker exec "${CONTAINER_NAME}" /workspace/script/tools/ck-test "$@" +} - while [[ $# -gt 0 ]]; do - case $1 in - --name) - name="$2" - shift 2 - ;; - --gtest_*|--help) - test_options+=("$1") - shift - ;; - *) - if [ -z 
"$test_name" ]; then - test_name="$1" - else - test_options+=("$1") - fi - shift - ;; - esac - done - - if [ -z "$test_name" ]; then - echo "Error: test_name required" - echo "Usage: ck-docker test [--name container_name] [gtest_options]" +# Execute arbitrary command in container +cmd_exec() { + if [ $# -eq 0 ]; then + error "command required" + echo "Usage: ck-docker exec " return 1 fi - # Check if container is running - if ! container_is_running "${name}"; then - echo "Error: Container '${name}' not running" - echo "Start it with: ck-docker start --name ${name}" - return 1 - fi + ensure_container_running "${CONTAINER_NAME}" "${SCRIPT_DIR}" - if ! docker exec "${name}" test -f "/workspace/build/bin/${test_name}" 2>/dev/null; then - echo "Test executable not found. Building ${test_name}..." - cmd_build "${test_name}" --name "${name}" - fi + local docker_flags=() + [ -t 0 ] && [ -t 1 ] && docker_flags+=("-it") - echo "Running: ${test_name} ${test_options[*]}" - echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" - # Build the command with proper quoting - local cmd="cd /workspace/build && ./bin/${test_name}" - for opt in "${test_options[@]}"; do - cmd="${cmd} $(printf '%q' "$opt")" - done - docker exec "${name}" bash -c "${cmd}" + docker exec "${docker_flags[@]}" "${CONTAINER_NAME}" "$@" } # Shell @@ -220,7 +136,7 @@ cmd_status() { if [ -z "$name" ]; then echo "Composable Kernel Docker Containers:" - echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + echo "---" docker ps -a --filter "ancestor=${docker_image}" \ --format "table {{.Names}}\t{{.Status}}\t{{.CreatedAt}}" || echo "No containers found" else @@ -262,6 +178,10 @@ case "${1:-}" in shift cmd_start "$@" ;; + configure) + shift + cmd_configure "$@" + ;; build) shift cmd_build "$@" @@ -270,6 +190,10 @@ case "${1:-}" in shift cmd_test "$@" ;; + exec) + shift + cmd_exec "$@" + ;; shell) shift cmd_shell "$@" diff --git a/script/tools/ck-rocprof b/script/tools/ck-rocprof new file mode 100755 index 
0000000000..2b41a7403c
--- /dev/null
+++ b/script/tools/ck-rocprof
@@ -0,0 +1,806 @@
+#!/bin/bash
+# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+# SPDX-License-Identifier: MIT
+
+# CK ROCProf Tool - Profile CK applications with rocprof-compute
+# Native-only tool. For Docker usage, run via: ck-docker exec ck-rocprof ...
+
+set -e
+set -o pipefail
+
+# Find script directory and load common utilities
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+source "${SCRIPT_DIR}/common.sh"
+
+# Initialize configuration
+PROJECT_ROOT=$(find_project_root "${SCRIPT_DIR}" || get_project_root "${SCRIPT_DIR}")
+
+# ============================================================================
+# rocprof-compute detection
+# ============================================================================
+
+# Common rocprof-compute binary locations
+# Order: user installs first, then system ROCm versions (newest first)
+ROCPROF_CANDIDATES=(
+    "${HOME}/.local/rocprofiler-compute/3.4.0/bin/rocprof-compute"
+    "/opt/rocm/bin/rocprof-compute"
+    "/opt/rocm-7.2.0/bin/rocprof-compute"
+    "/opt/rocm-7.0.1/bin/rocprof-compute"
+    "/opt/rocm-6.2.0/bin/rocprof-compute"
+    "/opt/rocm-6.1.0/bin/rocprof-compute"
+)
+
+# Find rocprof-compute binary.
+# Lookup order: CK_ROCPROF_BIN, then PATH, then ROCPROF_CANDIDATES.
+# Prints the path and returns 0 on success; prints nothing and returns 1 otherwise.
+find_rocprof_bin() {
+    # Check CK_ROCPROF_BIN first
+    if [ -n "${CK_ROCPROF_BIN:-}" ] && [ -f "${CK_ROCPROF_BIN}" ]; then
+        echo "${CK_ROCPROF_BIN}"
+        return 0
+    fi
+
+    # Check PATH
+    if command -v rocprof-compute &>/dev/null; then
+        command -v rocprof-compute
+        return 0
+    fi
+
+    # Check common ROCm locations and user installations
+    local bin
+    for bin in "${ROCPROF_CANDIDATES[@]}"; do
+        if [ -f "$bin" ]; then
+            echo "$bin"
+            return 0
+        fi
+    done
+
+    return 1
+}
+
+# Find ROCm requirements file.
+# $1: rocprof-compute binary path (optional; auto-detected when empty).
+# Prints the requirements.txt path and returns 0 if it exists, else returns 1.
+find_rocm_requirements() {
+    local rocprof_bin="${1:-$(find_rocprof_bin)}"
+    if [ -z "$rocprof_bin" ]; then
+        return 1
+    fi
+
+    # Requirements file is typically at ../libexec/rocprofiler-compute/requirements.txt
+    local rocm_dir
+    rocm_dir=$(dirname "$(dirname "$rocprof_bin")")
+    local req_file="${rocm_dir}/libexec/rocprofiler-compute/requirements.txt"
+
+    if [ -f "$req_file" ]; then
+        echo "$req_file"
+        return 0
+    fi
+
+    return 1
+}
+
+# ============================================================================
+# Configuration
+# ============================================================================
+
+# The "|| echo" fallbacks keep a failed lookup from tripping 'set -e' at load time.
+ROCPROF_BIN="${CK_ROCPROF_BIN:-$(find_rocprof_bin || echo "")}"
+VENV_PATH="${CK_PROFILE_VENV:-${PROJECT_ROOT}/.ck-rocprof-venv}"
+WORKLOAD_DIR="${CK_WORKLOAD_DIR:-$(get_build_dir "${PROJECT_ROOT}")/workloads}"
+ROCM_REQUIREMENTS="${CK_ROCM_REQUIREMENTS:-$(find_rocm_requirements "${ROCPROF_BIN}" || echo "")}"
+
+# ============================================================================
+# Helper functions
+# ============================================================================
+
+# Get file/directory size (human-readable, as printed by 'du -sh')
+get_size() {
+    local path="$1"
+    du -sh "$path" 2>/dev/null | cut -f1
+}
+
+# Get file modification date as YYYY-MM-DD (cross-platform: Linux and macOS)
+get_date() {
+    local path="$1"
+    # Try GNU stat first (Linux), fall back to BSD stat (macOS)
+    if stat --version &>/dev/null; then
+        stat -c %y "$path" 2>/dev/null | cut -d' ' -f1
+    else
+        stat -f %Sm -t %Y-%m-%d "$path" 2>/dev/null
+    fi
+}
+
+# Help message
+show_help() {
+    cat << EOF
+CK ROCProf Tool - Profile CK applications with rocprof-compute
+
+Usage: ck-rocprof <command> [options]
+
+Commands:
+  setup                            One-time setup: create Python venv and install dependencies
+  run <name> <executable> [args]   Profile executable and save results as <name>
+  analyze <name> [block]           Analyze profiling results (default: block 12 - LDS metrics)
+  compare <name1> <name2>          Compare two profiling runs
+  list                             List available profiling runs
+  clean <name>                     Remove a profiling run (use --all for all runs)
+  status                           Show current configuration and status
+  help                             Show this help message
+
+Examples:
+  ck-rocprof setup
+  ck-rocprof run baseline ./bin/tile_example_gemm_universal
+  ck-rocprof analyze baseline
+  ck-rocprof analyze baseline 12
+  ck-rocprof compare baseline optimized
+  ck-rocprof list
+  ck-rocprof clean baseline
+  ck-rocprof status
+
+Environment Variables:
+  CK_GPU_TARGET        - Override GPU detection (e.g., gfx950, MI300X)
+  CK_PROFILE_VENV      - Python venv path (default: \$PROJECT/.ck-rocprof-venv)
+  CK_ROCPROF_BIN       - rocprof-compute binary path
+  CK_ROCM_REQUIREMENTS - Path to rocprofiler-compute requirements.txt
+  CK_WORKLOAD_DIR      - Workload storage directory
+
+Profiling Blocks (use with 'analyze <name> <block>'):
+  Block 2:  System Speed-of-Light (SOL)
+  Block 6:  Shader Engine (SE) utilization
+  Block 7:  L2 Cache metrics
+  Block 11: Vector L1D Cache metrics
+  Block 12: LDS (Local Data Share) - DEFAULT
+  Block 16: Instruction mix statistics
+  Block 17: Compute Unit (CU) metrics
+
+LDS Metrics (Block 12):
+  - 12.1.3: Bank Conflict Rate (% of peak)
+  - 12.2.9: Bank Conflicts/Access (conflicts/access)
+  - 12.2.12: Bank Conflict (cycles per kernel)
+  - 12.2.17: LDS Data FIFO Full Rate (cycles)
+
+Notes:
+  - Workload names must be alphanumeric with hyphens/underscores only
+  - Profiling skips roofline analysis (--no-roof) for faster execution
+  - Results stored in workloads/<name>/
+  - For Docker usage, run via: ck-docker exec ck-rocprof ...
+EOF
+}
+
+# Get rocprof-compute wrapper path
+get_rocprof_wrapper() {
+    echo "${VENV_PATH}/bin/rocprof-compute"
+}
+
+# Validate workload name to prevent path traversal and shell injection
+# Allowed: alphanumeric, hyphens, underscores
+validate_workload_name() {
+    local name="$1"
+    if [[ ! "$name" =~ ^[a-zA-Z0-9_-]+$ ]]; then
+        error "Invalid workload name: '$name'"
+        echo "Names must contain only letters, numbers, hyphens, and underscores"
+        return 1
+    fi
+    # Prevent reserved names
+    if [[ "$name" == "." || "$name" == ".."
]]; then
+        error "Invalid workload name: '$name'"
+        return 1
+    fi
+    return 0
+}
+
+# Check if setup is complete (venv directory and wrapper script both exist)
+is_setup_complete() {
+    local wrapper
+    wrapper=$(get_rocprof_wrapper)
+    [ -d "${VENV_PATH}" ] && [ -f "${wrapper}" ]
+}
+
+# ============================================================================
+# Source installation
+# ============================================================================
+
+# rocprofiler-compute source installation location
+ROCPROF_SOURCE_VERSION="3.4.0"
+ROCPROF_SOURCE_DIR="${HOME}/.local/rocprofiler-compute/${ROCPROF_SOURCE_VERSION}"
+ROCPROF_SOURCE_BIN="${ROCPROF_SOURCE_DIR}/bin/rocprof-compute"
+ROCPROF_REPO_URL="https://github.com/ROCm/rocprofiler-compute.git"
+ROCPROF_REPO_BRANCH="release/rocprofiler-compute-v${ROCPROF_SOURCE_VERSION}"
+
+# Ensure the uv package manager is available, installing it via pip if needed.
+# Returns 1 (with an error message) when uv cannot be made available.
+ensure_uv() {
+    if command -v uv &>/dev/null; then
+        return 0
+    fi
+
+    info "Installing uv package manager via pip..."
+    if ! python3 -m pip install --user uv; then
+        error "Failed to install uv package manager"
+        return 1
+    fi
+    export PATH="${HOME}/.local/bin:${PATH}"
+    if ! command -v uv &>/dev/null; then
+        error "uv installed but not found in PATH"
+        echo "Try adding ~/.local/bin to your PATH"
+        return 1
+    fi
+}
+
+# Install rocprofiler-compute from source into ROCPROF_SOURCE_DIR.
+# This function is called as 'if ! install_from_source', a context in which
+# bash ignores 'set -e' for the whole function body, so every step that must
+# succeed checks its own status and returns 1 explicitly.
+install_from_source() {
+    local install_dir="${ROCPROF_SOURCE_DIR}"
+    local src_dir="${install_dir}/src"
+
+    info "Installing rocprofiler-compute ${ROCPROF_SOURCE_VERSION} from source..."
+    echo "Install location: ${install_dir}"
+    echo ""
+
+    # Ensure uv is available
+    ensure_uv || return 1
+
+    # Create installation directory
+    if ! mkdir -p "${install_dir}"; then
+        error "Cannot create ${install_dir}"
+        return 1
+    fi
+
+    # Clone repository (or refresh an existing checkout)
+    if [ -d "${src_dir}" ]; then
+        info "Source already exists, updating..."
+        # Updating is best-effort: an existing checkout remains usable offline
+        if ! git -C "${src_dir}" fetch --quiet; then
+            warn "Could not fetch updates; using existing checkout"
+        fi
+        if ! git -C "${src_dir}" checkout --quiet "${ROCPROF_REPO_BRANCH}" 2>/dev/null &&
+           ! git -C "${src_dir}" checkout --quiet "amd-mainline" 2>/dev/null; then
+            warn "Could not switch branch; using checkout as-is"
+        fi
+    else
+        info "Cloning rocprofiler-compute repository..."
+        if ! git clone --quiet --branch "${ROCPROF_REPO_BRANCH}" --depth 1 "${ROCPROF_REPO_URL}" "${src_dir}" 2>/dev/null; then
+            # Fall back to amd-mainline if release branch doesn't exist
+            info "Release branch not found, using amd-mainline..."
+            if ! git clone --quiet --branch "amd-mainline" --depth 1 "${ROCPROF_REPO_URL}" "${src_dir}"; then
+                error "Failed to clone ${ROCPROF_REPO_URL}"
+                return 1
+            fi
+        fi
+    fi
+
+    # Create venv for source installation
+    local venv_dir="${install_dir}/venv"
+    if [ ! -d "${venv_dir}" ]; then
+        info "Creating Python virtual environment..."
+        if ! uv venv "${venv_dir}"; then
+            error "Failed to create virtual environment at ${venv_dir}"
+            return 1
+        fi
+    fi
+
+    # Install dependencies from requirements.txt
+    info "Installing dependencies (this may take a minute)..."
+    if ! uv pip install --python "${venv_dir}/bin/python" -r "${src_dir}/requirements.txt" --quiet; then
+        error "Failed to install dependencies from ${src_dir}/requirements.txt"
+        return 1
+    fi
+    # Pin pandas to avoid CSV conversion bug
+    if ! uv pip install --python "${venv_dir}/bin/python" 'pandas<3.0' --quiet; then
+        error "Failed to install pinned pandas"
+        return 1
+    fi
+
+    # Create bin directory and wrapper script
+    mkdir -p "${install_dir}/bin" || return 1
+    # Quoted delimiter: the wrapper body is written literally, nothing expands here
+    cat > "${ROCPROF_SOURCE_BIN}" << 'WRAPPER_EOF' || return 1
+#!/bin/bash
+# rocprof-compute wrapper for source installation
+INSTALL_DIR="$(cd "$(dirname "$0")/.." && pwd)"
+SRC_DIR="${INSTALL_DIR}/src/src"
+VENV_DIR="${INSTALL_DIR}/venv"
+
+# Set PYTHONPATH to source directory for module imports
+# (append the caller's PYTHONPATH only when it is non-empty, so no empty entry is added)
+export PYTHONPATH="${SRC_DIR}${PYTHONPATH:+:${PYTHONPATH}}"
+
+# Execute rocprof-compute script with venv Python
+exec "${VENV_DIR}/bin/python3" "${SRC_DIR}/rocprof-compute" "$@"
+WRAPPER_EOF
+    chmod +x "${ROCPROF_SOURCE_BIN}" || return 1
+
+    info "rocprofiler-compute installed successfully!"
+    echo "  Binary: ${ROCPROF_SOURCE_BIN}"
+    echo ""
+}
+
+# ============================================================================
+# Commands
+# ============================================================================
+
+# Setup: Create Python venv and install rocprof-compute dependencies.
+# Installs rocprofiler-compute from source when no binary is found.
+cmd_setup() {
+    echo "Setting up rocprof-compute profiling environment..."
+    echo "==========================================="
+
+    # Check if rocprof-compute exists, install from source if not
+    if [ -z "${ROCPROF_BIN}" ] || [ ! -f "${ROCPROF_BIN}" ]; then
+        warn "rocprof-compute not found in standard locations"
+        echo ""
+        echo "Searched locations:"
+        local bin
+        for bin in "${ROCPROF_CANDIDATES[@]}"; do
+            echo "  - $bin"
+        done
+        echo ""
+
+        # Check if we can install from source
+        if ! command -v git &>/dev/null; then
+            error "git is required to install from source"
+            return 1
+        fi
+        if ! command -v python3 &>/dev/null; then
+            error "python3 is required to install from source"
+            return 1
+        fi
+
+        echo "Installing rocprofiler-compute from source..."
+        echo ""
+        if ! install_from_source; then
+            error "Failed to install rocprofiler-compute from source"
+            return 1
+        fi
+
+        # Update configuration with source installation
+        ROCPROF_BIN="${ROCPROF_SOURCE_BIN}"
+    fi
+    info "Using rocprof-compute: ${ROCPROF_BIN}"
+
+    # A source install carries its own venv and dependencies, so the separate
+    # requirements file is only needed for system installs.
+    local source_install=false
+    if [[ "${ROCPROF_BIN}" == "${ROCPROF_SOURCE_BIN}" ]]; then
+        source_install=true
+        # Same file install_from_source installs from
+        ROCM_REQUIREMENTS="${ROCPROF_SOURCE_DIR}/src/requirements.txt"
+    elif [ -z "${ROCM_REQUIREMENTS}" ] || [ ! -f "${ROCM_REQUIREMENTS}" ]; then
+        error "ROCm requirements file not found"
+        local expected_path
+        expected_path="$(dirname "$(dirname "${ROCPROF_BIN}")")/libexec/rocprofiler-compute/requirements.txt"
+        echo "Expected at: ${expected_path}"
+        echo "Set CK_ROCM_REQUIREMENTS to override"
+        return 1
+    fi
+
+    # Check GPU access
+    if [ ! -r /dev/kfd ]; then
+        warn "No read access to /dev/kfd - GPU profiling may fail"
+        warn "Add user to video/render group: sudo usermod -a -G video,render \$USER"
+    fi
+
+    local wrapper
+    wrapper=$(get_rocprof_wrapper)
+
+    if [ "${source_install}" = true ]; then
+        # Source install already has everything set up - wrapper just calls its binary
+        mkdir -p "$(dirname "${wrapper}")"
+        cat > "${wrapper}" << WRAPPER_EOF
+#!/bin/bash
+# rocprof-compute wrapper (using source installation)
+exec "${ROCPROF_BIN}" "\$@"
+WRAPPER_EOF
+        chmod +x "${wrapper}"
+        info "Wrapper created at ${wrapper}"
+
+        # Create marker file for venv directory
+        touch "${VENV_PATH}/.source-install"
+    else
+        # System install - need to set up venv with dependencies
+        if ! ensure_uv; then
+            return 1
+        fi
+
+        # Create venv
+        if [ -d "${VENV_PATH}" ]; then
+            info "Python venv already exists at ${VENV_PATH}"
+        else
+            info "Creating Python venv at ${VENV_PATH}..."
+            uv venv "${VENV_PATH}"
+        fi
+
+        # Install dependencies
+        info "Installing dependencies..."
+        uv pip install --python "${VENV_PATH}/bin/python" -r "${ROCM_REQUIREMENTS}"
+        uv pip install --python "${VENV_PATH}/bin/python" 'pandas<3.0'
+
+        # Create wrapper script (ROCPROF_BIN expands now; \$-escaped parts expand at run time)
+        mkdir -p "$(dirname "${wrapper}")"
+        cat > "${wrapper}" << WRAPPER_EOF
+#!/bin/bash
+# rocprof-compute wrapper using venv Python
+VENV_DIR="\$(cd "\$(dirname "\$0")/.." && pwd)"
+exec "\${VENV_DIR}/bin/python" "${ROCPROF_BIN}" "\$@"
+WRAPPER_EOF
+        chmod +x "${wrapper}"
+        info "Wrapper created at ${wrapper}"
+    fi
+
+    # Create workload directory
+    mkdir -p "${WORKLOAD_DIR}"
+    info "Workload directory: ${WORKLOAD_DIR}"
+
+    echo ""
+    info "Setup complete! You can now use:"
+    echo "  ck-rocprof run <name> <executable>"
+}
+
+# Detect GPU architecture.
+# Order: CK_GPU_TARGET, rocminfo marketing name, rocminfo gfx name,
+# an existing workload sub-directory, then the gfx950 default.
+detect_gpu_arch() {
+    # Allow override via environment variable
+    if [ -n "${CK_GPU_TARGET:-}" ]; then
+        echo "${CK_GPU_TARGET}"
+        return 0
+    fi
+
+    # "|| true" below: a pipeline with no match (or cut short by head) must
+    # fall through to the next probe instead of failing under pipefail/errexit.
+    if command -v rocminfo &>/dev/null; then
+        # Try marketing name first (MI350, MI300X)
+        local marketing_name
+        marketing_name=$(rocminfo 2>/dev/null | grep 'Marketing Name:' | grep -oE 'MI[0-9]+[A-Z]*' | head -1) || true
+        if [ -n "$marketing_name" ]; then
+            echo "$marketing_name"
+            return 0
+        fi
+
+        # Fallback to gfx name
+        local gfx_name
+        gfx_name=$(rocminfo 2>/dev/null | grep -oE 'gfx[0-9a-z]+' | head -1) || true
+        if [ -n "$gfx_name" ]; then
+            echo "$gfx_name"
+            return 0
+        fi
+    fi
+
+    # Try existing workload directories
+    if [ -d "${WORKLOAD_DIR}" ]; then
+        local first_dir
+        first_dir=$(find "${WORKLOAD_DIR}" -maxdepth 2 -type d \( -name 'gfx*' -o -name 'MI*' \) 2>/dev/null | head -1) || true
+        if [ -n "$first_dir" ]; then
+            basename "$first_dir"
+            return 0
+        fi
+    fi
+
+    # Final fallback - use gfx950 consistent with common.sh
+    echo "gfx950"
+}
+
+# Run profiling.
+# $1: run name, $2: executable, remaining args: passed to the executable.
+cmd_run() {
+    # Validate argument count before shifting
+    if [ $# -lt 2 ]; then
+        error "name and executable required"
+        echo "Usage: ck-rocprof run <name> <executable> [args]"
+        return 1
+    fi
+
+    local name="$1"
+    local executable="$2"
+    shift 2
+    local -a exe_args=("$@")
+
+    # Validate workload name (prevents path traversal)
+    if ! validate_workload_name "$name"; then
+        return 1
+    fi
+
+    # Check setup
+    if ! is_setup_complete; then
+        error "Profiling environment not set up"
+        echo "Run: ck-rocprof setup"
+        return 1
+    fi
+
+    # Check if executable exists
+    if [ ! -f "$executable" ]; then
+        error "Executable not found: $executable"
+        return 1
+    fi
+
+    local wrapper
+    wrapper=$(get_rocprof_wrapper)
+    local gpu_arch
+    gpu_arch=$(detect_gpu_arch)
+
+    echo "Profiling: $executable ${exe_args[*]}"
+    echo "Run name: $name"
+    echo "GPU arch: $gpu_arch"
+    echo "==========================================="
+
+    # Invoke the wrapper directly with an argument vector: no second shell
+    # parses the command, so paths and arguments with spaces or metacharacters
+    # are passed through unchanged.
+    # --no-roof skips roofline analysis to speed up profiling
+    "${wrapper}" profile --no-roof --path "${WORKLOAD_DIR}/${name}" --name "${name}" -- \
+        "${executable}" "${exe_args[@]}"
+
+    echo ""
+    info "Profiling complete"
+    echo "Results saved to: ${WORKLOAD_DIR}/${name}/"
+    echo ""
+    echo "Analyze with: ck-rocprof analyze ${name}"
+}
+
+# Find workload path for a given run name.
+# Prints the run directory and returns 0 when it contains pmc_perf.csv; else returns 1.
+find_workload_path() {
+    local name="$1"
+    local run_dir="${WORKLOAD_DIR}/${name}"
+
+    if [ ! -d "$run_dir" ]; then
+        return 1
+    fi
+
+    # Check if profiling data exists
+    if [ -f "${run_dir}/pmc_perf.csv" ]; then
+        echo "$run_dir"
+        return 0
+    fi
+
+    return 1
+}
+
+# Analyze profiling results.
+# $1: run name, $2: block number (default 12).
+cmd_analyze() {
+    local name="${1:-}"
+    local block="${2:-12}"  # Default to block 12 (LDS metrics)
+
+    if [ -z "$name" ]; then
+        error "name required"
+        echo "Usage: ck-rocprof analyze <name> [block]"
+        return 1
+    fi
+
+    # Validate workload name (prevents path traversal)
+    if ! validate_workload_name "$name"; then
+        return 1
+    fi
+
+    # Check setup
+    if ! is_setup_complete; then
+        error "Profiling environment not set up"
+        echo "Run: ck-rocprof setup"
+        return 1
+    fi
+
+    local wrapper
+    wrapper=$(get_rocprof_wrapper)
+    # A bare failing assignment would abort the script under 'set -e' before
+    # the error below is printed, so the failure is absorbed here.
+    local workload_path
+    workload_path=$(find_workload_path "${name}") || workload_path=""
+
+    if [ -z "$workload_path" ]; then
+        error "Profiling results not found for '${name}'"
+        echo ""
+        echo "Available runs:"
+        cmd_list
+        return 1
+    fi
+
+    echo "Analyzing: ${name} (Block ${block})"
+    echo "==========================================="
+    echo ""
+
+    "${wrapper}" analyze --path "${workload_path}" --block "${block}"
+}
+
+# Compare two profiling runs by printing the Block 12 summary of each.
+# $1: baseline run name, $2: optimized run name.
+cmd_compare() {
+    local name1="${1:-}"
+    local name2="${2:-}"
+
+    if [ -z "$name1" ] || [ -z "$name2" ]; then
+        error "two run names required"
+        echo "Usage: ck-rocprof compare <name1> <name2>"
+        return 1
+    fi
+
+    # Validate workload names (prevents path traversal)
+    if ! validate_workload_name "$name1"; then
+        return 1
+    fi
+    if ! validate_workload_name "$name2"; then
+        return 1
+    fi
+
+    # Check setup
+    if ! is_setup_complete; then
+        error "Profiling environment not set up"
+        echo "Run: ck-rocprof setup"
+        return 1
+    fi
+
+    # Verify both runs exist (failures absorbed so the messages below are reachable)
+    local path1
+    path1=$(find_workload_path "${name1}") || path1=""
+    local path2
+    path2=$(find_workload_path "${name2}") || path2=""
+
+    if [ -z "$path1" ]; then
+        error "Profiling results not found for '${name1}'"
+        return 1
+    fi
+
+    if [ -z "$path2" ]; then
+        error "Profiling results not found for '${name2}'"
+        return 1
+    fi
+
+    echo "Comparing profiling runs:"
+    echo "  Baseline:  ${name1}"
+    echo "  Optimized: ${name2}"
+    echo "==========================================="
+    echo ""
+
+    # "|| true": when head stops after 40 lines the producer receives SIGPIPE,
+    # which pipefail would otherwise turn into a script abort.
+    echo "=== ${name1} - Block 12 (LDS) ==="
+    cmd_analyze "${name1}" 12 2>/dev/null | head -40 || true
+
+    echo ""
+    echo "=== ${name2} - Block 12 (LDS) ==="
+    cmd_analyze "${name2}" 12 2>/dev/null | head -40 || true
+
+    echo ""
+    echo "==========================================="
+    echo "For detailed analysis, run:"
+    echo "  ck-rocprof analyze ${name1} 12"
+    echo "  ck-rocprof analyze ${name2} 12"
+}
+
+# List
available profiling runs
+cmd_list() {
+    if [ ! -d "${WORKLOAD_DIR}" ]; then
+        echo "No profiling runs found (workload directory doesn't exist)"
+        return 0
+    fi
+
+    local runs
+    runs=$(find "${WORKLOAD_DIR}" -maxdepth 1 -mindepth 1 -type d -exec basename {} \; 2>/dev/null | sort)
+
+    if [ -z "$runs" ]; then
+        echo "No profiling runs found in ${WORKLOAD_DIR}"
+        return 0
+    fi
+
+    echo "Available profiling runs:"
+    echo "==========================================="
+
+    local run path size date
+    while IFS= read -r run; do
+        # find_workload_path returns 1 for a directory without data; absorb it
+        # so 'set -e' does not abort the listing and "[no data]" is reachable.
+        path=$(find_workload_path "$run") || path=""
+
+        if [ -n "$path" ]; then
+            size=$(get_size "$path")
+            date=$(get_date "$path")
+            printf "  %-25s [%s, %s]\n" "$run" "$size" "$date"
+        else
+            printf "  %-25s [no data]\n" "$run"
+        fi
+    done <<< "$runs"
+
+    echo ""
+    echo "Analyze with: ck-rocprof analyze <name>"
+}
+
+# Clean (remove) profiling runs.
+# $1: run name, or --all to remove every run (asks for confirmation).
+cmd_clean() {
+    local name="${1:-}"
+
+    if [ -z "$name" ]; then
+        error "name required (or use --all to remove all runs)"
+        echo "Usage: ck-rocprof clean <name>"
+        echo "       ck-rocprof clean --all"
+        return 1
+    fi
+
+    if [ "$name" = "--all" ]; then
+        # Remove all profiling runs
+        if [ ! -d "${WORKLOAD_DIR}" ]; then
+            echo "No profiling runs to clean"
+            return 0
+        fi
+
+        echo "This will remove ALL profiling runs in ${WORKLOAD_DIR}"
+        local confirm
+        # EOF on stdin (non-interactive use) counts as "no" instead of aborting
+        read -r -p "Are you sure? [y/N] " confirm || confirm=""
+        if [[ ! "$confirm" =~ ^[Yy]$ ]]; then
+            echo "Cancelled"
+            return 0
+        fi
+
+        # ":?" makes the expansion fail rather than expand to "/*" if the variable is empty
+        rm -rf "${WORKLOAD_DIR:?}"/*
+        info "All profiling runs removed"
+    else
+        # Validate name
+        if ! validate_workload_name "$name"; then
+            return 1
+        fi
+
+        local run_dir="${WORKLOAD_DIR}/${name}"
+        if [ ! -d "$run_dir" ]; then
+            error "Profiling run not found: ${name}"
+            return 1
+        fi
+
+        rm -rf "${run_dir}"
+        info "Removed profiling run: ${name}"
+    fi
+}
+
+# Show status information: binary, paths, setup state, number of runs
+cmd_status() {
+    echo "CK ROCProf Status"
+    echo "==========================================="
+    echo ""
+
+    # rocprof-compute binary
+    if [ -n "${ROCPROF_BIN}" ] && [ -f "${ROCPROF_BIN}" ]; then
+        echo "rocprof-compute: ${ROCPROF_BIN}"
+    else
+        echo "rocprof-compute: not found"
+    fi
+    echo ""
+
+    # Paths
+    echo "Paths:"
+    echo "  Venv:      ${VENV_PATH}"
+    echo "  Workloads: ${WORKLOAD_DIR}"
+    echo ""
+
+    # Setup status
+    echo "Setup status:"
+    if is_setup_complete; then
+        echo "  Profiling environment: ready"
+    else
+        echo "  Profiling environment: not configured (run 'ck-rocprof setup')"
+    fi
+    echo ""
+
+    # Workload count
+    if [ -d "${WORKLOAD_DIR}" ]; then
+        local count
+        count=$(find "${WORKLOAD_DIR}" -maxdepth 1 -mindepth 1 -type d 2>/dev/null | wc -l)
+        echo "Profiling runs: ${count}"
+    else
+        echo "Profiling runs: 0"
+    fi
+}
+
+# ============================================================================
+# Main command dispatcher
+# ============================================================================
+
+case "${1:-}" in
+    setup)
+        cmd_setup
+        ;;
+    run)
+        shift
+        cmd_run "$@"
+        ;;
+    analyze)
+        shift
+        cmd_analyze "$@"
+        ;;
+    compare)
+        shift
+        cmd_compare "$@"
+        ;;
+    list)
+        cmd_list
+        ;;
+    clean)
+        shift
+        cmd_clean "$@"
+        ;;
+    status)
+        cmd_status
+        ;;
+    help|--help|-h)
+        show_help
+        ;;
+    *)
+        # No command at all prints help and succeeds; an unknown one is an error
+        if [ -z "${1:-}" ]; then
+            show_help
+        else
+            echo "Unknown command: ${1}"
+            echo ""
+            show_help
+            exit 1
+        fi
+        ;;
+esac
diff --git a/script/tools/ck-rocprof.md b/script/tools/ck-rocprof.md
new file mode 100644
index 0000000000..0588846097
--- /dev/null
+++ b/script/tools/ck-rocprof.md
@@ -0,0 +1,167 @@
+# CK ROCProf Tool
+
+GPU performance profiling for Composable Kernel applications using AMD rocprof-compute.
+
+**Note:** This is a native-only tool.
For Docker usage, run via `ck-docker exec ck-rocprof ...` + +## Quick Start + +```bash +# One-time setup (installs rocprofiler-compute from source if it is not found) +./script/tools/ck-rocprof setup + +# Profile executable +cd build +../script/tools/ck-rocprof run baseline ./bin/tile_example_gemm_universal + +# Analyze LDS metrics +../script/tools/ck-rocprof analyze baseline + +# Compare optimizations +../script/tools/ck-rocprof run optimized ./bin/tile_example_gemm_universal +../script/tools/ck-rocprof compare baseline optimized +``` + +## Commands + +### `setup` +One-time setup: creates Python venv, installs dependencies, configures rocprof-compute. + +### `run <name> <executable> [args]` +Profile executable and save results as `<name>`. + +```bash +# Basic profiling +ck-rocprof run baseline ./bin/gemm_example + +# With arguments +ck-rocprof run large_matrix ./bin/gemm_example -m 8192 -n 8192 -k 4096 + +# Test filtering +ck-rocprof run unit_test ./bin/test_gemm --gtest_filter="*Fp16*" +``` + +### `analyze <name> [block]` +Display profiling metrics (default: Block 12 - LDS). + +```bash +ck-rocprof analyze baseline # LDS metrics +ck-rocprof analyze baseline 2 # System Speed-of-Light +ck-rocprof analyze baseline 7 # L2 Cache +``` + +### `compare <name1> <name2>` +Prints the Block 12 (LDS) summary of both runs, one after the other (first 40 lines of each). + +### `list` +List all profiling runs with size and date. + +### `clean <name>` / `clean --all` +Remove profiling runs. Use `--all` to remove all runs. + +### `status` +Show current configuration: rocprof-compute binary, paths, setup status, number of profiling runs. + +## Key LDS Metrics (Block 12) + +**Target Values:** +- Bank Conflicts/Access: <0.01 (1% conflict rate) +- Bank Conflict Rate: >90% of peak bandwidth + +**Critical Metrics:** +- **12.2.9 Bank Conflicts/Access**: Direct conflict measure + - Baseline (naive): ~0.04 (4% conflicts) + - Optimized: <0.005 (<0.5% conflicts) +- **12.2.12 Bank Conflict Cycles**: Wasted cycles per kernel +- **12.2.17 LDS Data FIFO Full**: Memory system pressure + +## Optimization Workflow + +```bash +# 1. 
Baseline +ck-rocprof run baseline ./bin/my_kernel + +# 2. Check conflicts +ck-rocprof analyze baseline +# Look for Bank Conflicts/Access > 0.02 + +# 3. Optimize code (XOR transforms, padding, etc.) +# ... edit source ... + +# 4. Test optimization +ninja my_kernel +ck-rocprof run optimized ./bin/my_kernel + +# 5. Verify improvement +ck-rocprof compare baseline optimized +# Target: 8-10x reduction in conflicts +``` + +## Environment Variables + +- `CK_PROFILE_VENV`: Python venv path (default: `$PROJECT/.ck-rocprof-venv`) +- `CK_ROCPROF_BIN`: rocprof-compute binary path (auto-detected from PATH or /opt/rocm) +- `CK_ROCM_REQUIREMENTS`: Path to rocprofiler-compute requirements.txt (auto-detected) +- `CK_WORKLOAD_DIR`: Results directory (default: `workloads/` inside the build directory, i.e. `$PROJECT/build/workloads` unless `CK_BUILD_DIR` is set) +- `CK_GPU_TARGET`: Override GPU detection (e.g., `gfx950`, `MI300X`) + +## Interpreting Results + +**Good Performance:** +``` +Bank Conflicts/Access: <0.01 +Bank Conflict Rate: >90% of peak +LDS Data FIFO Full: Minimal cycles +``` + +**Needs Optimization:** +``` +Bank Conflicts/Access: >0.02 +Bank Conflict Cycles: High MAX values +LDS Data FIFO Full: High memory pressure +``` + +## Troubleshooting + +**"Profiling environment not set up"** +```bash +ck-rocprof setup +``` + +**"rocprof-compute not found"** +```bash +export CK_ROCPROF_BIN=/custom/path/rocprof-compute +ck-rocprof setup +``` + +**"Profiling results not found"** +```bash +ck-rocprof list # Check available runs +rocminfo | grep gfx # Verify GPU arch +export CK_GPU_TARGET=gfx950 # Override if needed +``` + +## Storage Layout + +Results stored in `workloads/<name>/`: +- `pmc_perf.csv`: Performance counters (primary data file) +- `perfmon/`: Input metric files +- `out/`: Raw output data from profiler runs +- `log.txt`: Profiling log + +## Technical Details + +- **Setup**: Creates isolated Python venv, installs dependencies +- **Profiling**: Runs `rocprof-compute profile --no-roof --path <workload_dir>/<name> --name <name> -- <executable> [args]` +- **Analysis**: Runs `rocprof-compute analyze --path <workload_dir>/<name> --block <block>` +- 
**GPU Support**: MI300/MI350 series, auto-detects architecture + +## Related Tools + +- `ck-docker`: Container management +- `rocprof-compute`: AMD GPU profiler v2 +- `rocm-smi`: System monitoring + +## License + +Copyright (c) Advanced Micro Devices, Inc. SPDX-License-Identifier: MIT diff --git a/script/tools/ck-test b/script/tools/ck-test index 712f904596..1ee8d0defd 100755 --- a/script/tools/ck-test +++ b/script/tools/ck-test @@ -2,7 +2,8 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# CK Test - Build and test Composable Kernel in Docker +# CK Test - Run Composable Kernel tests +# Environment-agnostic: works natively on ROCm hosts or inside containers set -e set -o pipefail @@ -12,155 +13,219 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" source "${SCRIPT_DIR}/common.sh" # Initialize configuration -PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") -CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") +PROJECT_ROOT=$(find_project_root "${SCRIPT_DIR}" || get_project_root "${SCRIPT_DIR}") +BUILD_DIR=$(get_build_dir "${PROJECT_ROOT}") # Help message show_help() { cat << EOF -CK Test - Build and test Composable Kernel in Docker +CK Test - Run Composable Kernel tests -Usage: ck-test [options] [test_options] +Usage: ck-test [options] [test_name] [-- gtest_options] Options: -h, --help Show this help message - --name Specify container name - --reconfigure Reconfigure CMake before building + --build-dir Build directory (default: ./build) --no-build Skip building, run test directly + --list List available tests + --smoke Run all smoke tests (via CTest -L SMOKE_TEST) + --regression Run all regression tests (via CTest -L REGRESSION_TEST) + --all Run all tests (via CTest) + --filter Shorthand for --gtest_filter= Arguments: - test_name Name of test executable (required) - test_options Additional options passed to test (e.g., --gtest_filter=*) + test_name Name of test executable (optional for 
--smoke/--regression/--all) + gtest_options Additional options passed to test (after --) Environment: - CK_CONTAINER_NAME - Override default container name - GPU_TARGET - Override GPU target detection (e.g., gfx950, gfx942) + CK_BUILD_DIR - Override build directory Examples: - ck-test test_amdgcn_mma - ck-test test_amdgcn_mma --gtest_filter=*Fp16* - ck-test --name my_container test_amdgcn_mma - ck-test --reconfigure test_amdgcn_mma + ck-test test_amdgcn_mma # Build and run specific test + ck-test test_amdgcn_mma --filter '*Fp16*' # Run with gtest filter + ck-test test_amdgcn_mma -- --gtest_filter=*Fp16* # Explicit gtest options + ck-test --no-build test_amdgcn_mma # Run without rebuilding + ck-test --list # List available tests + ck-test --smoke # Run all smoke tests + ck-test --regression # Run all regression tests + ck-test --all # Run all tests EOF } # Parse arguments test_name="" -reconfigure=false no_build=false -test_options=() +list_tests=false +run_smoke=false +run_regression=false +run_all=false +gtest_filter="" +gtest_options=() +parsing_gtest=false while [[ $# -gt 0 ]]; do + if [ "$parsing_gtest" = true ]; then + gtest_options+=("$1") + shift + continue + fi + case $1 in -h|--help) show_help exit 0 ;; - --name) - CONTAINER_NAME="$2" + --build-dir) + require_arg "$1" "${2:-}" + BUILD_DIR="$2" shift 2 ;; - --reconfigure) - reconfigure=true - shift - ;; --no-build) no_build=true shift ;; - --gtest_*|--help) - test_options+=("$1") + --list) + list_tests=true + shift + ;; + --smoke) + run_smoke=true + shift + ;; + --regression) + run_regression=true + shift + ;; + --all) + run_all=true + shift + ;; + --filter) + require_arg "$1" "${2:-}" + gtest_filter="$2" + shift 2 + ;; + --) + parsing_gtest=true + shift + ;; + --gtest_*) + gtest_options+=("$1") shift ;; *) if [ -z "$test_name" ]; then test_name="$1" else - test_options+=("$1") + gtest_options+=("$1") fi shift ;; esac done -# Validate test name +# Add filter to gtest options if specified +if [ -n 
"$gtest_filter" ]; then + gtest_options+=("--gtest_filter=${gtest_filter}") +fi + +# Validate mutual exclusivity of test suite options +suite_count=0 +[ "$run_smoke" = true ] && suite_count=$((suite_count + 1)) +[ "$run_regression" = true ] && suite_count=$((suite_count + 1)) +[ "$run_all" = true ] && suite_count=$((suite_count + 1)) + +if [ "$suite_count" -gt 1 ]; then + error "Options --smoke, --regression, and --all are mutually exclusive" + exit 1 +fi + +# Check build is configured +if ! is_build_configured "${BUILD_DIR}"; then + error "Build not configured. Run 'ck-configure' first" + exit 1 +fi + +# Handle --list +if [ "$list_tests" = true ]; then + info "Available tests:" + if [ -d "${BUILD_DIR}/bin" ]; then + ls -1 "${BUILD_DIR}/bin/" 2>/dev/null | grep -E '^test_' | sort || echo " (No test binaries found)" + else + echo " (No bin directory found)" + fi + echo "" + echo "CTest labels:" + cd "${BUILD_DIR}" + ctest -N 2>/dev/null | head -20 || echo " (Run 'ctest -N' for full list)" + exit 0 +fi + +# Handle CTest-based test suites +if [ "$run_smoke" = true ] || [ "$run_regression" = true ] || [ "$run_all" = true ]; then + cd "${BUILD_DIR}" + + ctest_cmd=(ctest --output-on-failure) + + if [ "$run_smoke" = true ]; then + ctest_cmd+=(-L SMOKE_TEST) + info "Running smoke tests..." + elif [ "$run_regression" = true ]; then + ctest_cmd+=(-L REGRESSION_TEST) + info "Running regression tests..." + else + info "Running all tests..." + fi + + "${ctest_cmd[@]}" + exit_code=$? + + echo "" + if [ $exit_code -eq 0 ]; then + info "Tests completed successfully" + else + error "Tests failed with exit code: ${exit_code}" + fi + exit $exit_code +fi + +# Validate test name for individual test runs if [ -z "$test_name" ]; then - echo "Error: test_name required" + error "test_name required (or use --smoke/--regression/--all for test suites)" echo "" show_help exit 1 fi -# Ensure container is running -if ! 
container_is_running "${CONTAINER_NAME}"; then - echo "Container '${CONTAINER_NAME}' not running. Starting..." - "${SCRIPT_DIR}/ck-start" "${CONTAINER_NAME}" - echo "" -fi - -# Configure CMake if needed or requested -if [ "$reconfigure" = true ] || ! docker exec "${CONTAINER_NAME}" test -f /workspace/build/build.ninja 2>/dev/null; then - echo "Detecting GPU target..." - GPU_TARGET_DETECTED=$(detect_gpu_target "${CONTAINER_NAME}") - - if [ "$reconfigure" = true ]; then - echo "Reconfiguring CMake from scratch for GPU target: ${GPU_TARGET_DETECTED}" - else - echo "Configuring build with CMake for GPU target: ${GPU_TARGET_DETECTED}" - fi - - docker exec "${CONTAINER_NAME}" bash -c " - cd /workspace || exit 1 - rm -rf /workspace/build - mkdir /workspace/build - cd /workspace/build || exit 1 - cmake .. -GNinja \ - -DGPU_TARGETS=${GPU_TARGET_DETECTED} \ - -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \ - -DBUILD_TESTING=ON 2>&1 | tail -30 - " - echo "" -fi - # Build test if needed (unless --no-build is specified) if [ "$no_build" = false ]; then - if ! docker exec "${CONTAINER_NAME}" test -f "/workspace/build/bin/${test_name}" 2>/dev/null; then - echo "Building ${test_name}..." - docker exec "${CONTAINER_NAME}" bash -c " - cd /workspace/build || exit 1 - ninja ${test_name} 2>&1 - " - echo "" - else - echo "Test executable found, rebuilding to ensure latest version..." - docker exec "${CONTAINER_NAME}" bash -c " - cd /workspace/build || exit 1 - ninja ${test_name} 2>&1 - " - echo "" - fi + info "Building ${test_name}..." + "${SCRIPT_DIR}/ck-build" --build-dir "${BUILD_DIR}" "${test_name}" + echo "" +fi + +# Verify test executable exists +test_binary="${BUILD_DIR}/bin/${test_name}" +if [ ! 
-f "$test_binary" ]; then + error "Test executable not found: ${test_binary}" + echo "Run 'ck-build ${test_name}' first" + exit 1 fi # Run test -echo "Running: ${test_name} ${test_options[*]}" -echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" +echo "Running: ${test_name} ${gtest_options[*]}" +echo "---" -# Build the command with proper quoting -cmd="cd /workspace/build && ./bin/${test_name}" -for opt in "${test_options[@]}"; do - cmd="${cmd} $(printf '%q' "$opt")" -done - -docker exec "${CONTAINER_NAME}" bash -c "${cmd}" +cd "${BUILD_DIR}" +"./bin/${test_name}" "${gtest_options[@]}" exit_code=$? -echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" +echo "---" if [ $exit_code -eq 0 ]; then - echo "Test completed successfully" + info "Test completed successfully" else - echo "Test failed with exit code: ${exit_code}" + error "Test failed with exit code: ${exit_code}" fi exit $exit_code diff --git a/script/tools/common.sh b/script/tools/common.sh index 6683572c0f..e5a39cea67 100644 --- a/script/tools/common.sh +++ b/script/tools/common.sh @@ -74,14 +74,14 @@ container_is_running() { detect_gpu_target() { local container="$1" - # Allow override via GPU_TARGET environment variable - if [ -n "${GPU_TARGET:-}" ]; then - echo "${GPU_TARGET}" + # Allow override via CK_GPU_TARGET environment variable + if [ -n "${CK_GPU_TARGET:-}" ]; then + echo "${CK_GPU_TARGET}" return 0 fi docker exec "${container}" bash -c " - rocminfo 2>/dev/null | grep -oP 'gfx[0-9a-z]+' | head -1 || echo 'gfx950' + rocminfo 2>/dev/null | grep -oE 'gfx[0-9a-z]+' | head -1 || echo 'gfx950' " | tr -d '\r\n' } @@ -95,3 +95,87 @@ ensure_container_running() { "${script_dir}/ck-docker" start "${container}" fi } + +# ============================================================================ +# Native (non-Docker) utilities +# ============================================================================ + +# Output utilities +info() { echo "[info] $*"; } +warn() { echo "[warn] $*" >&2; } +error() { echo 
"[error] $*" >&2; } + +# Require argument for option (validates $2 exists and is not another flag) +require_arg() { + local option="$1" + local value="$2" + if [ -z "$value" ] || [[ "$value" == -* ]]; then + error "Option $option requires an argument" + exit 1 + fi +} + +# Native GPU detection (no Docker required) +detect_gpu_native() { + # Allow override via CK_GPU_TARGET environment variable + if [ -n "${CK_GPU_TARGET:-}" ]; then + echo "${CK_GPU_TARGET}" + return 0 + fi + + # Try rocminfo if available + if command -v rocminfo &>/dev/null; then + local gpu + gpu=$(rocminfo 2>/dev/null | grep -oE 'gfx[0-9a-z]+' | head -1) + if [ -n "$gpu" ]; then + echo "$gpu" + return 0 + fi + fi + + # Fallback + echo "gfx950" +} + +# Get build directory (respects CK_BUILD_DIR env var) +get_build_dir() { + local project_root="${1:-$(get_project_root "$(dirname "${BASH_SOURCE[0]}")")}" + echo "${CK_BUILD_DIR:-${project_root}/build}" +} + +# Check if build is configured (build.ninja exists) +is_build_configured() { + local build_dir="${1:-$(get_build_dir)}" + [ -f "${build_dir}/build.ninja" ] +} + +# Find project root from any subdirectory (walks up to find .git) +find_project_root() { + local dir="${1:-$(pwd)}" + while [ "$dir" != "/" ]; do + if [ -d "$dir/.git" ]; then + echo "$dir" + return 0 + fi + dir=$(dirname "$dir") + done + return 1 +} + +# List available CMake presets +list_cmake_presets() { + local project_root="${1:-$(find_project_root)}" + local presets_file="${project_root}/CMakePresets.json" + + if [ ! 
-f "$presets_file" ]; then + return 1 + fi + + # Extract non-hidden preset names + if command -v jq &>/dev/null; then + jq -r '.configurePresets[] | select(.hidden != true) | .name' "$presets_file" 2>/dev/null + else + # Fallback: sed-based extraction (more portable than grep -P) + sed -n 's/.*"name"[[:space:]]*:[[:space:]]*"\([^"]*\)".*/\1/p' "$presets_file" | grep -v '^use-' + fi +} diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index e1c246a8b0..2080fc185f 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -76,6 +76,22 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") ) target_compile_options(test_tile_gemm_quant_abquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_abquant_a4w4_base + test_gemm_quant_abquant_a4w4_base.cpp + ) + target_compile_options(test_tile_gemm_quant_abquant_a4w4_base PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_abquant_a4w4_padding + test_gemm_quant_abquant_a4w4_padding.cpp + ) + target_compile_options(test_tile_gemm_quant_abquant_a4w4_padding PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_abquant_a4w4_preshuffle + test_gemm_quant_abquant_a4w4_preshuffle.cpp + ) + target_compile_options(test_tile_gemm_quant_abquant_a4w4_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_tile_gemm_quant_abquant_preshuffleQuant test_gemm_quant_abquant_preshuffleQuant.cpp ) diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_base.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_base.cpp new file mode 100644 index 0000000000..5e2403f7d1 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_base.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using Half = ck_tile::half_t; +using PkFP4 = ck_tile::pk_fp4_t; +using ABQuantGrouped = + std::integral_constant; + +// 1d block sizes for AQuant +using GroupSize1D = ck_tile::QuantGroupShape>; + +// 2d block sizes for BQuant +using GroupSize2D = ck_tile::QuantGroupShape>; + +// Type combinations for ABQuant tests +// Tuple format: +// clang-format off +using ABQuantTypes = ::testing::Types< + // PreshuffleQuant = false && TransposeC = false + // RCR layout with RowMajor AQ, ColumnMajor BQ + std::tuple +>; +// clang-format on + +// Test suite for ABQuant +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_padding.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_padding.cpp new file mode 100644 index 0000000000..1e496d5b64 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_padding.cpp @@ -0,0 +1,65 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using Half = ck_tile::half_t; +using PkFP4 = ck_tile::pk_fp4_t; +using ABQuantGrouped = + std::integral_constant; + +// 1d block sizes for AQuant +using GroupSize1D = ck_tile::QuantGroupShape>; + +// 2d block sizes for BQuant +using GroupSize2D = ck_tile::QuantGroupShape>; + +// Type combinations for ABQuant tests +// Tuple format: +// clang-format off +using ABQuantTypes = ::testing::Types< + // PreshuffleQuant = false && TransposeC = false + // RCR layout with RowMajor AQ, ColumnMajor BQ + std::tuple +>; +// clang-format on + +// Test suite for ABQuant +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantTypes); + +// AQuant tests + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadK) +{ + this->run_test_with_validation(1024, 1024, 832); +} + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadN) +{ + this->run_test_with_validation(1024, 832, 1024); +} + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadM) +{ + this->run_test_with_validation(832, 1024, 1024); +} + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadMNK) +{ + this->run_test_with_validation(832, 832, 832); +} + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadNK) +{ + this->run_test_with_validation(1024, 832, 832); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_preshuffle.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_preshuffle.cpp new file mode 100644 index 0000000000..43051c8d08 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_preshuffle.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using Half = ck_tile::half_t; +using PkFP4 = ck_tile::pk_fp4_t; +using ABQuantGrouped = + std::integral_constant; + +// 1d block sizes for AQuant +using GroupSize1D = ck_tile::QuantGroupShape>; + +// 2d block sizes for BQuant +using GroupSize2D = ck_tile::QuantGroupShape>; + +// Type combinations for ABQuant tests +// Tuple format: +// clang-format off +using ABQuantTypes = ::testing::Types< + // RCR layout with RowMajor AQ, ColumnMajor BQ + // PreshuffleB = true && TransposeC = false + std::tuple +>; +// clang-format on + +// Test suite for ABQuant +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp index 7be4131db4..5937b44229 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -209,7 +209,7 @@ template <> struct QuantTypeTraits { template - using ComputeDataType = BDataType; // For AQuant, compute type is BDataType + using ComputeDataType = void; // Use automatically determined compute type static constexpr const char* name = "abquant"; }; diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 9683fa98aa..0033bb42a8 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -1174,8 +1174,8 @@ class TestCkTileGemmABQuant : public 
TestCkTileGemmQuantBase>; using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, AccDataType, CDataType, diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp index bce6da4b68..5aa0b13c07 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp @@ -184,5 +184,5 @@ TYPED_TEST(TestGroupedConvndBwdWeightDefault, SingleStageAutoDeduce) this->conv_param = {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; this->split_k_ = -1; bool is_supported = this->template Run<2>(); - EXPECT_FALSE(is_supported); + EXPECT_TRUE(is_supported); }