diff --git a/example/ck_tile/12_smoothquant/CMakeLists.txt b/example/ck_tile/12_smoothquant/CMakeLists.txt index 09a56c6dab..3849833aca 100644 --- a/example/ck_tile/12_smoothquant/CMakeLists.txt +++ b/example/ck_tile/12_smoothquant/CMakeLists.txt @@ -18,7 +18,7 @@ function (add_smoothquant_example TARGET_NAME MAIN_SRC) target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS}) endfunction(add_smoothquant_example TARGET_NAME MAIN_SRC) -file(GLOB INSTANCE_SRCS instances/*.cpp) -add_smoothquant_example(tile_smoothquant smoothquant.cpp ${INSTANCE_SRCS}) add_smoothquant_example(tile_example_smoothquant example_smoothquant.cpp) +file(GLOB INSTANCE_SRCS instances/*.cpp) +add_smoothquant_example(tile_smoothquant smoothquant.cpp ${INSTANCE_SRCS}) diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp index 12a15938ae..24f35d3636 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp @@ -28,8 +28,9 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass static constexpr bool kSaveX = Problem::kSaveX; static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; - static constexpr bool kPadM = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM - static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadM = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool UseMax3 = true; // TODO - Move to trait static constexpr const char* name = []() { if constexpr(kNeedCrossWarpSync) @@ -69,9 +70,16 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass auto reduce_square_sum_func = ReduceOp::SquareAdd{}; auto reduce_sum_func = ReduceOp::Add{}; auto reduce_absmax_func = ReduceOp::AbsMax{}; - auto reduce_max_func = ReduceOp::Max{}; - auto block_reduce2d = Policy::template GetBlockReduce2d(); - auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); + auto reduce_absmax3_func = [](auto acc_, auto v_0_, auto v_1_) { + float rtn; + asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)" + : "=v"(rtn) + : "v"(acc_), "v"(v_0_), "v"(v_1_)); + return rtn; + }; + auto reduce_max_func = ReduceOp::Max{}; + auto block_reduce2d = Policy::template GetBlockReduce2d(); + auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); auto block_reduce2d_cross_warp_sync = Policy::template GetBlockReduce2dCrossWarpSync(); @@ -116,8 +124,23 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass }); // compute absmax, each-thread->cross-lane->cross-warp - auto absmax = block_reduce2d( - y, reduce_absmax_func.GetIdentityValue(), reduce_absmax_func); + auto absmax = [&]() { + constexpr auto x_size_per_row = + x.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(number<1>{}); + if constexpr(UseMax3 && std::is_same_v && + x_size_per_row % 2 == 0) + { + return block_reduce2d(y, + reduce_absmax_func.GetIdentityValue(), + reduce_absmax3_func, + sequence<1, 2>{}); + } + else + { + return block_reduce2d( + y, reduce_absmax_func.GetIdentityValue(), reduce_absmax_func); + } + }(); block_reduce2d_sync(absmax, reduce_max_func); block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func); diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp index 0dbb20645a..aec7368e27 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp @@ -28,8 +28,9 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass static constexpr bool kSaveX = Problem::kSaveX; static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; - static constexpr bool kPadM = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM - static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadM = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool UseMax3 = true; // TODO - Move to trait static constexpr const char* name = []() { if constexpr(kNeedCrossWarpSync) @@ -76,9 +77,16 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass auto reduce_square_sum_func = ReduceOp::SquareAdd{}; auto reduce_sum_func = ReduceOp::Add{}; auto reduce_absmax_func = ReduceOp::AbsMax{}; - auto reduce_max_func = ReduceOp::Max{}; - auto block_reduce2d = Policy::template GetBlockReduce2d(); - auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); + auto reduce_absmax3_func = [](auto acc_, auto v_0_, auto v_1_) { + float rtn; + asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)" + : "=v"(rtn) + : "v"(acc_), "v"(v_0_), "v"(v_1_)); + return rtn; + }; + auto reduce_max_func = ReduceOp::Max{}; + auto block_reduce2d = Policy::template GetBlockReduce2d(); + auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); auto block_reduce2d_cross_warp_sync = Policy::template GetBlockReduce2dCrossWarpSync(); @@ -177,7 +185,13 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass y(idx) = type_convert(y_); }); - block_reduce2d(y, absmax, reduce_absmax_func); + constexpr auto x_size_per_row = + x.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(number<1>{}); + if constexpr(UseMax3 && std::is_same_v && + x_size_per_row % 2 == 0) + block_reduce2d(y, absmax, reduce_absmax3_func, sequence<1, 2>{}); + else + block_reduce2d(y, absmax, reduce_absmax_func); if constexpr(kSaveX) move_tile_window(x_window, {0, -Block_N}); diff --git a/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp b/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp index d5b3780dea..b2fc240c1d 100644 --- a/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp +++ b/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp @@ -25,6 +25,7 @@ struct SmoothquantPipelineOnePass static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; static constexpr bool kPadM = false; // TODO - BlockSmoothquantProblem::kPadM static constexpr bool kPadN = Problem::kPadN; + static constexpr bool UseMax3 = true; // TODO - Move to trait static constexpr const char* name = []() { if constexpr(kNeedCrossWarpSync) @@ -52,7 +53,15 @@ struct SmoothquantPipelineOnePass xscale_window_, Policy::template MakeXScaleBlockTileDistribution()); auto reduce_absmax_func = ReduceOp::AbsMax{}; - auto reduce_max_func = ReduceOp::Max{}; + auto reduce_absmax3_func = [](auto acc_, auto v_0_, auto v_1_) { + float rtn; + asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)" + : "=v"(rtn) + : "v"(acc_), "v"(v_0_), "v"(v_1_)); + return rtn; + }; + auto reduce_max_func = ReduceOp::Max{}; + auto block_reduce2d = Policy::template GetBlockReduce2d(); auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); auto block_reduce2d_cross_warp_sync = @@ -68,8 +77,23 @@ struct SmoothquantPipelineOnePass xscale); // compute absmax, cross-lane->cross-warp - auto absmax = block_reduce2d( - y, reduce_absmax_func.GetIdentityValue(), reduce_absmax_func); + auto absmax = [&]() { + constexpr auto x_size_per_row = + x.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(number<1>{}); + if constexpr(UseMax3 && std::is_same_v && + x_size_per_row % 2 == 0) + { + return block_reduce2d(y, + reduce_absmax_func.GetIdentityValue(), + reduce_absmax3_func, + sequence<1, 2>{}); + } + else + { + return block_reduce2d( + y, reduce_absmax_func.GetIdentityValue(), reduce_absmax_func); + } + }(); block_reduce2d_sync(absmax, reduce_max_func); block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func); diff --git a/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp b/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp index 7878ef1d34..9e9df663b9 100644 --- a/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp +++ b/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp @@ -25,6 +25,7 @@ struct SmoothquantPipelineTwoPass static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; static constexpr bool kPadM = false; // TODO - BlockSmoothquantProblem::kPadM static constexpr bool kPadN = Problem::kPadN; + static constexpr bool UseMax3 = true; // TODO - Move to trait static constexpr const char* name = []() { if constexpr(kNeedCrossWarpSync) @@ -56,6 +57,13 @@ struct SmoothquantPipelineTwoPass __builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N)); auto reduce_absmax_func = ReduceOp::AbsMax{}; + auto reduce_absmax3_func = [](auto acc_, auto v_0_, auto v_1_) { + float rtn; + asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)" + : "=v"(rtn) + : "v"(acc_), "v"(v_0_), "v"(v_1_)); + return rtn; + }; auto reduce_max_func = ReduceOp::Max{}; auto block_reduce2d = Policy::template GetBlockReduce2d(); auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); @@ -77,7 +85,13 @@ struct SmoothquantPipelineTwoPass x, xscale); - block_reduce2d(y, absmax, reduce_absmax_func); + constexpr auto x_size_per_row = + x.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(number<1>{}); + if constexpr(UseMax3 && std::is_same_v && + x_size_per_row % 2 == 0) + block_reduce2d(y, absmax, reduce_absmax3_func, sequence<1, 2>{}); + else + block_reduce2d(y, absmax, reduce_absmax_func); move_tile_window(x_window, {0, Block_N}); move_tile_window(xscale_window, {Block_N});