diff --git a/.gitignore b/.gitignore index bcc5888b7f..6641e5bc58 100644 --- a/.gitignore +++ b/.gitignore @@ -36,6 +36,9 @@ tags # Editors .vscode +# Cline +.cline* + # build-in-source directory (see exceptions below) build* diff --git a/Dockerfile.aiter b/Dockerfile.aiter index b61c1e41a5..dab3f9588d 100644 --- a/Dockerfile.aiter +++ b/Dockerfile.aiter @@ -1,9 +1,8 @@ -ARG BASE_DOCKER="rocm/composable_kernel-private:ck_aiter_base" +ARG BASE_DOCKER="rocm/pytorch:latest" FROM $BASE_DOCKER ARG AITER_BRANCH="main" ARG CK_AITER_BRANCH="develop" -RUN groupadd irc && \ - pip install pandas zmq einops && \ +RUN pip install pandas zmq einops ninja && \ pip install numpy==1.26.2 && \ sudo mkdir /home/jenkins && \ sudo mkdir /home/jenkins/workspace && \ @@ -14,6 +13,8 @@ RUN groupadd irc && \ rm -rf 3rdparty/composable_kernel/ && \ git clone -b "$CK_AITER_BRANCH" https://github.com/ROCm/composable_kernel.git 3rdparty/composable_kernel/ && \ python3 setup.py develop && \ - chown -R jenkins:jenkins /home/jenkins/workspace && \ - chmod -R a+rwx /home/jenkins/workspace && \ + groupadd -g 1001 jenkins && \ + useradd -u 1001 -g 1001 -m -s /bin/bash jenkins && \ + chown -R jenkins:jenkins /home/jenkins && \ + chmod -R a+rwx /home/jenkins && \ sudo usermod -aG irc jenkins diff --git a/Jenkinsfile b/Jenkinsfile index 9acbbeeca2..625b6d0d09 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -194,6 +194,33 @@ def check_arch(){ return arch_type } +def check_arch_name(){ + def arch_name = "" + sh 'rocminfo | tee rocminfo.log' + if ( runShell('grep -n "gfx90a" rocminfo.log') ){ + arch_name = "gfx90a" + } + else if ( runShell('grep -n "gfx942" rocminfo.log') ) { + arch_name = "gfx942" + } + else if ( runShell('grep -n "gfx10" rocminfo.log') ) { + arch_name = "gfx10" + } + else if ( runShell('grep -n "gfx11" rocminfo.log') ) { + arch_name = "gfx11" + } + else if ( runShell('grep -n "gfx12" rocminfo.log') ) { + arch_name = "gfx12" + } + else if ( runShell('grep -n "gfx908" rocminfo.log') ) { + arch_name = "gfx908" + } + else if ( runShell('grep -n "gfx950" rocminfo.log') ) { + arch_name = "gfx950" + } + return arch_name +} + def getDockerImage(Map conf=[:]){ env.DOCKER_BUILDKIT=1 def prefixpath = conf.get("prefixpath", "/opt/rocm") @@ -302,12 +329,6 @@ def cmake_build(Map conf=[:]){ //cmake_env can overwrite default CXX variables. def cmake_envs = "CXX=${compiler} CXXFLAGS='-Werror' " + conf.get("cmake_ex_env","") - def package_build = (conf.get("package_build","") == "true") - - if (package_build == true) { - config_targets = "package" - } - if(conf.get("build_install","") == "true") { config_targets = 'install ' + config_targets @@ -455,15 +476,20 @@ def cmake_build(Map conf=[:]){ else{ sh "ninja check" } + if(params.BUILD_PACKAGES){ + echo "Build ckProfiler packages" + sh 'ninja -j64 package' + def arch_name = check_arch_name() + sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${arch_name}.deb" + stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${arch_name}" + } } if(params.BUILD_INSTANCES_ONLY){ // build deb packages - echo "Build packages" + echo "Build library package" sh 'ninja -j64 package' - archiveArtifacts artifacts: 'composablekernel-dev*.deb' sh 'mv composablekernel-dev_*.deb composablekernel-dev_all_targets_1.2.0_amd64.deb' - sh 'mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64.deb' - stash includes: "composablekernel-**.deb", name: "packages" + stash includes: "composablekernel-dev**.deb", name: "lib_package" } } else{ @@ -475,15 +501,18 @@ def cmake_build(Map conf=[:]){ else{ sh "ninja check" } + if(params.BUILD_PACKAGES){ + echo "Build ckProfiler packages" + sh 'ninja -j64 package' + def arch_name = check_arch_name() + sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${arch_name}.deb" + stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${arch_name}" + } } } } } - // Only archive from develop - if (package_build == true && env.BRANCH_NAME == "develop") { - archiveArtifacts artifacts: "build/*.deb", allowEmptyArchive: true, fingerprint: true - } //check the node gpu architecture def arch = check_arch() if (params.RUN_CK_TILE_FMHA_TESTS){ @@ -823,9 +852,42 @@ def process_results(Map conf=[:]){ } if (params.BUILD_INSTANCES_ONLY){ // unstash deb packages - unstash "packages" + try{ + unstash "lib_package" + } + catch(Exception err){ + echo "could not locate lib_package." + } sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no composablekernel-*.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/" } + if (params.BUILD_PACKAGES){ + // unstash deb packages + try{ + unstash "profiler_package_gfx90a" + } + catch(Exception err){ + echo "could not locate profiler_package_gfx90a." + } + try{ + unstash "profiler_package_gfx942" + } + catch(Exception err){ + echo "could not locate profiler_package_gfx942." + } + try{ + unstash "profiler_package_gfx950" + } + catch(Exception err){ + echo "could not locate profiler_package_gfx950." + } + try{ + unstash "profiler_package_gfx12" + } + catch(Exception err){ + echo "could not locate profiler_package_gfx12." + } + sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no composablekernel-ckprofiler*.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/" + } else{ // unstash perf files to master try{ @@ -993,7 +1055,7 @@ def run_pytorch_tests(Map conf=[:]){ //launch develop branch daily jobs CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_PERFORMANCE_TESTS=true;FORCE_CI=true 0 22 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true - 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true + 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true;BUILD_PACKAGES=true 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true 0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;FORCE_CI=true @@ -1085,6 +1147,10 @@ pipeline { name: "BUILD_INSTANCES_ONLY", defaultValue: false, description: "Test building instances for various architectures simultaneously (default: OFF)") + booleanParam( + name: "BUILD_PACKAGES", + defaultValue: false, + description: "Build packages for the libraries and/or ckProfiler (default: OFF)") booleanParam( name: "BUILD_GFX908", defaultValue: false, @@ -1574,7 +1640,6 @@ pipeline { -D GPU_TARGETS="gfx1201" \ -D GEMM_DATATYPE="fp16" \ -D GEMM_LAYOUT="rcr;rrr;crr;ccr" \ - -DGEMM_CONFIG_FILE=gfx120x_config.json \ -DCMAKE_CXX_FLAGS=" -O3 " .. && \ ninja -j64 benchmark_gemm_all && \ python3 ../tile_engine/ops/gemm/gemm_benchmark.py . --problem-sizes "1024,1024,1024" \ @@ -1830,7 +1895,7 @@ pipeline { stage("Process results"){ when { beforeAgent true - expression { (params.RUN_PERFORMANCE_TESTS.toBoolean() || params.BUILD_INSTANCES_ONLY.toBoolean() || params.RUN_CK_TILE_FMHA_TESTS.toBoolean()) && !params.BUILD_LEGACY_OS.toBoolean() } + expression { (params.RUN_PERFORMANCE_TESTS.toBoolean() || params.BUILD_INSTANCES_ONLY.toBoolean() || params.RUN_CK_TILE_FMHA_TESTS.toBoolean()|| params.BUILD_PACKAGES.toBoolean()) && !params.BUILD_LEGACY_OS.toBoolean() } } agent { label 'mici' } steps{ diff --git a/codegen/test/rtc/include/rtc/kernel.hpp b/codegen/test/rtc/include/rtc/kernel.hpp index b1ee729f77..96337fe2c1 100644 --- a/codegen/test/rtc/include/rtc/kernel.hpp +++ b/codegen/test/rtc/include/rtc/kernel.hpp @@ -52,7 +52,7 @@ struct kernel template auto launch(hipStream_t stream, std::size_t global, std::size_t local, Ts... zs) const { - return [=](auto&&... xs) { + return [=, this](auto&&... xs) { launch(stream, global, local, std::vector{xs...}, zs...); }; } diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp index fb047ae364..236e5e4fa2 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp @@ -59,4 +59,11 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl #include "run_grouped_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } +int main(int argc, char* argv[]) +{ + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + return 0; + } + return !run_grouped_gemm_example(argc, argv); +} diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc index 4ef6074f4a..87ccebc3c4 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -278,19 +278,20 @@ bool run_grouped_gemm_example(int argc, char* argv[]) problem_size.group_count = 16; - if(argc == 4) + if(argc == 1) + { + // use default cases + } + else if(argc == 4 || argc == 6) { config.do_verification = std::stoi(argv[1]); config.init_method = std::stoi(argv[2]); config.time_kernel = std::stoi(argv[3]); - } - else if(argc == 6) - { - config.do_verification = std::stoi(argv[1]); - config.init_method = std::stoi(argv[2]); - config.time_kernel = std::stoi(argv[3]); - config.async_hargs = std::stoi(argv[4]); - problem_size.group_count = std::stoi(argv[5]); + if(argc == 6) + { + config.async_hargs = std::stoi(argv[4]); + problem_size.group_count = std::stoi(argv[5]); + } } else { @@ -299,18 +300,33 @@ bool run_grouped_gemm_example(int argc, char* argv[]) printf("arg3: time kernel (0=n0, 1=yes)\n"); printf("arg4: async hargs (0=n0, 1=yes)\n"); printf("arg5: group count (default=16)"); - exit(0); + exit(1); } + // Lambda to get stride based on layout + auto get_stride = [](auto layout, auto row_dim, auto col_dim) { + if constexpr(std::is_same_v) + { + return col_dim; + } + else + { + return row_dim; + } + }; + for(int i = 0; i < problem_size.group_count; i++) { problem_size.Ms.push_back(256 + 256 * i); problem_size.Ns.push_back(128 + 128 * i); problem_size.Ks.push_back(128 + 64 * i); - problem_size.stride_As.push_back(problem_size.Ks[i]); - problem_size.stride_Bs.push_back(problem_size.Ks[i]); - problem_size.stride_Cs.push_back(problem_size.Ns[i]); + problem_size.stride_As.push_back( + get_stride(ALayout{}, problem_size.Ms[i], problem_size.Ks[i])); + problem_size.stride_Bs.push_back( + get_stride(BLayout{}, problem_size.Ks[i], problem_size.Ns[i])); + problem_size.stride_Cs.push_back( + get_stride(ELayout{}, problem_size.Ms[i], problem_size.Ns[i])); } return run_grouped_gemm(problem_size, config); diff --git a/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp b/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp index 8064809123..21c5ff8d5a 100644 --- a/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp +++ b/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp @@ -82,37 +82,29 @@ int main(int argc, char* argv[]) bool do_verification = true; bool time_kernel = true; + ck::index_t M = 48 * 256; + ck::index_t N = 1024; + if(argc == 1) { // use default } - else if(argc == 3) + else if(argc == 3 || argc == 5) { do_verification = std::stoi(argv[1]); time_kernel = std::stoi(argv[2]); + if(argc == 5) + { + M = std::stoi(argv[3]); + N = std::stoi(argv[4]); + } } else { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: time kernel (0=no, 1=yes)\n"); - exit(0); - } - - ck::index_t M = 48 * 256; - ck::index_t N = 1024; - if(argc == 1) - { - // use default case - } - else if(argc == 3) - { - M = std::stoi(argv[1]); - N = std::stoi(argv[2]); - } - else - { - std::cerr << "arg1 to 2: M, N" << std::endl; - return 1; + printf("arg3-4: M, N\n"); + exit(1); } ck::index_t Stride = N; diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 10d7befc06..57d3f224d8 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -182,6 +182,7 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2; static constexpr bool Preshuffle = true; + static constexpr bool Persistent = true; static constexpr bool DoubleSmemBuffer = true; }; diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp index b9d6a4a1bc..52b84737cc 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp @@ -167,6 +167,113 @@ float grouped_gemm(const std::vector& gemm_descs, return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } +template +float grouped_gemm_tileloop(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr, + bool splitk) +{ + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence>; + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + + float ave_time{0}; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + using GemmEpilogue = ck_tile::CShuffleEpilogue, // DsDataType (empty for no D tensors) + AccDataType, + CDataType, + ck_tile::tuple<>, // DsLayout (empty for no D tensors) + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + UniversalGemmProblem::TransposeC, + memory_operation>>; + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); + + return ave_time; + }; + + if(splitk) + { + Run(ck_tile::integral_constant{}); + } + else + { + Run(ck_tile::integral_constant{}); + } + + return ave_time; +} + #include "run_grouped_gemm_example.inc" template diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp index 3b4258d8b1..1a913fcfc1 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp @@ -29,7 +29,7 @@ template + ck_tile::QuantType QuantMode = ck_tile::QuantType::BQuantGrouped> float grouped_gemm_tileloop(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, void* kargs_ptr) @@ -48,8 +48,8 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + using QuantGemmProblem = typename std::conditional< + QuantMode == ck_tile::QuantType::BQuantGrouped, + ck_tile::GemmBQuantPipelineProblem, // QuantGroupSize + ck_tile::GemmRowColTensorQuantPipelineProblem>::type; - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; + using GemmPipeline = + typename std::conditional, + ck_tile::GemmPipelineAgBgCrCompV3>::type; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem constexpr ck_tile::index_t get_k_warp_tile() @@ -41,6 +42,14 @@ struct GemmTypeConfig using AccDataType = float; using CDataType = ck_tile::half_t; }; +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; struct GemmConfigBase { @@ -77,24 +86,11 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; static constexpr int kBlockPerCu = 1; }; -template -struct PipelineTypeTraits; - -template <> -struct PipelineTypeTraits -{ - template - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - template - using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; -}; - using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs; auto create_args(int argc, char* argv[]) @@ -122,8 +118,7 @@ auto create_args(int argc, char* argv[]) .insert("repeat", "100", "number of iterations to benchmark the kernel.") .insert("group_count", "8", "group count.") .insert("kbatch", "1", "kbatch for SplitK") - .insert("quant_mode", "tensor", "Choose tensor (default), or rowcol"); - ; + .insert("quant_mode", "bquant", "Choose bquant (default), tensor, or rowcol"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); diff --git a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc index 19211ed494..152df38bff 100644 --- a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc @@ -43,8 +43,8 @@ template + ck_tile::QuantType QuantMode = ck_tile::QuantType::BQuantGrouped, + typename CDEElementWise = ck_tile::element_wise::PassThrough> float invoke_gemm(int n_warmup, int n_repeat, int group_count, @@ -159,11 +159,12 @@ int run_grouped_gemm_example_with_layouts(int argc, return group_count != 0 && ((args.size() == static_cast(group_count)) && ...); }; - const int group_count = arg_parser.get_int("group_count"); - const int repeat = arg_parser.get_int("repeat"); - const int warmup = arg_parser.get_int("warmup"); - const int kbatch = arg_parser.get_int("kbatch"); - bool validate = arg_parser.get_bool("validate"); + const int group_count = arg_parser.get_int("group_count"); + const int repeat = arg_parser.get_int("repeat"); + const int warmup = arg_parser.get_int("warmup"); + const int kbatch = arg_parser.get_int("kbatch"); + bool validate = arg_parser.get_bool("validate"); + const ck_tile::index_t QuantGroupSize = 128; if(kbatch > 1 && validate && warmup + repeat > 1) { @@ -172,9 +173,11 @@ int run_grouped_gemm_example_with_layouts(int argc, validate = false; } - std::vector Ms = arg_parser.get_int_vec("Ms"); - std::vector Ns = arg_parser.get_int_vec("Ns"); - std::vector Ks = arg_parser.get_int_vec("Ks"); + std::vector Ms = arg_parser.get_int_vec("Ms"); + std::vector Ns = arg_parser.get_int_vec("Ns"); + std::vector Ks = arg_parser.get_int_vec("Ks"); + std::vector AQs; // dimension of AQ tensor is calculated from A tensor + std::vector BQs; // dimension of BQ tensor is calculated from B tensor std::vector stride_As = arg_parser.get_int_vec("stride_As"); std::vector stride_Bs = arg_parser.get_int_vec("stride_Bs"); std::vector stride_Cs = arg_parser.get_int_vec("stride_Cs"); @@ -252,6 +255,15 @@ int run_grouped_gemm_example_with_layouts(int argc, AQK = 1; // Row quantization: tensor shape [M, 1] or [1] BQK = 1; // Column quantization: tensor shape [1, N] or [1] } + else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + { + AQK = 0; // No A quantization + BQK = K / QuantGroupSize; // Group quantization: BQK = K / GroupSize + if(K % QuantGroupSize != 0) + { + throw std::runtime_error("K must be divisible by 128 for BQuantGrouped mode"); + } + } stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout)); stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout)); @@ -289,6 +301,13 @@ int run_grouped_gemm_example_with_layouts(int argc, bq_tensors.push_back(ck_tile::HostTensor( ck_tile::host_tensor_descriptor(1, 1, stride_BQs[i], is_row_major(bq_layout)))); } + else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + { + aq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(0, AQK, stride_AQs[i], is_row_major(aq_layout)))); + bq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(bq_layout)))); + } std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc @@ -394,6 +413,17 @@ int run_grouped_gemm_example_with_layouts(int argc, bq_tensors[i], c_m_n_host_ref); } + else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + { + ck_tile::reference_gemm_quant( + a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); + } const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); @@ -441,42 +471,6 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a QuantMode>( argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}); } - else if(a_layout == "R" && b_layout == "R") - { - return run_grouped_gemm_example_with_layouts( - argc, argv, Row{}, Row{}, Row{}, Col{}, Row{}); - } - else if(a_layout == "C" && b_layout == "R") - { - return run_grouped_gemm_example_with_layouts( - argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}); - } - else if(a_layout == "C" && b_layout == "C") - { - return run_grouped_gemm_example_with_layouts( - argc, argv, Col{}, Col{}, Col{}, Col{}, Row{}); - } else { throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); @@ -513,6 +507,41 @@ int run_grouped_gemm_example(int argc, char* argv[]) ck_tile::QuantType::RowColQuant>( a_layout, b_layout, argc, argv); } + else if(quant_mode == "bquant") + { + return run_gemm_example_prec_type, + ck_tile::fp8_t, + ck_tile::QuantType::BQuantGrouped>( + a_layout, b_layout, argc, argv); + } + else + { + throw std::runtime_error("Unsupported quantization mode!"); + } + } + if(data_type == "bf8") + { + if(quant_mode == "tensor") + { + return run_gemm_example_prec_type, + ck_tile::bf8_t, + ck_tile::QuantType::TensorQuant>( + a_layout, b_layout, argc, argv); + } + else if(quant_mode == "rowcol") + { + return run_gemm_example_prec_type, + ck_tile::bf8_t, + ck_tile::QuantType::RowColQuant>( + a_layout, b_layout, argc, argv); + } + else if(quant_mode == "bquant") + { + return run_gemm_example_prec_type, + ck_tile::bf8_t, + ck_tile::QuantType::BQuantGrouped>( + a_layout, b_layout, argc, argv); + } else { throw std::runtime_error("Unsupported quantization mode!"); diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index dbdbe80c5d..4eee165d66 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -70,23 +70,13 @@ float invoke_gemm(int n_warmup, } else { - if(GemmConfig::Preshuffle) - { - // not supported yet - throw std::runtime_error( - "Persistent grouped gemm with preshuffle is not supported yet"); - } - - // NOTE: With the persistent TileLoop kernel, we do not necessarily need to haveCollapse - // commentComment on line L74tenpercent commented on Sep 5, 2025 tenpercenton Sep 5, - // 2025ContributorMore actionsdid you intend to remove the comment?Write a replyResolve - // commentCode has comments. Press enter to view. the gemm problems known on the host. - // Instead, we can just pass the pointer to the kernel and let the workgroups figure out - // which tiles to work on. This is useful when the gemm problems are generated dynamically. - // In this example however, we generate the `kargs` using the known gemm_descs, - // and copy the gemm descriptions to the device memory. - // The contents of the memory pointed to by `kargs_ptr` pointer could be - // written by e.g. another kernel from earlier stage. + // NOTE: With the persistent TileLoop kernel, we do not necessarily need to have the gemm + // problems known on the host. Instead, we can just pass the pointer to the kernel and let + // the workgroups figure out which tiles to work on. This is useful when the gemm problems + // are generated dynamically. In this example however, we generate the `kargs` using the + // known gemm_descs, and copy the gemm descriptions to the device memory. The contents of + // the memory pointed to by `kargs_ptr` pointer could be written by e.g. another kernel from + // earlier stage. std::vector> kargs; void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); diff --git a/example/ck_tile/20_grouped_convolution/CMakeLists.txt b/example/ck_tile/20_grouped_convolution/CMakeLists.txt index 10332137e2..e9614061e1 100644 --- a/example/ck_tile/20_grouped_convolution/CMakeLists.txt +++ b/example/ck_tile/20_grouped_convolution/CMakeLists.txt @@ -4,6 +4,9 @@ list(APPEND EXAMPLE_CONV_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion add_executable(tile_example_grouped_conv_fwd EXCLUDE_FROM_ALL grouped_convolution_forward.cpp) target_compile_options(tile_example_grouped_conv_fwd PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +add_executable(tile_example_grouped_conv_fwd_bias_clamp EXCLUDE_FROM_ALL grouped_convolution_forward_bias_clamp.cpp) +target_compile_options(tile_example_grouped_conv_fwd_bias_clamp PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_executable(tile_example_grouped_conv_bwd_weight EXCLUDE_FROM_ALL grouped_convolution_backward_weight.cpp) target_compile_options(tile_example_grouped_conv_bwd_weight PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_bias_clamp.cpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_bias_clamp.cpp new file mode 100644 index 0000000000..ed215cb178 --- /dev/null +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_bias_clamp.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "grouped_convolution_utils.hpp" +#include "grouped_convolution_forward_invoker.hpp" +#include "run_grouped_convolution_fwd_bias_clamp_example.inc" + +template