diff --git a/example/12_reduce/reduce_blockwise_two_call.cpp b/example/12_reduce/reduce_blockwise_two_call.cpp index 0fdc0378b7..88556acaf7 100644 --- a/example/12_reduce/reduce_blockwise_two_call.cpp +++ b/example/12_reduce/reduce_blockwise_two_call.cpp @@ -88,7 +88,7 @@ using DeviceReduceInstance_2 = DeviceReduceMultiBlock in_1(inLengths_1); @@ -174,22 +174,22 @@ int main(int argc, char* argv[]) case 0: break; case 1: in_1.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); - if(beta != 0.0f) + if(beta_ != 0.0f) out_ref.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); break; case 2: in_1.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - if(beta != 0.0f) + if(beta_ != 0.0f) out_ref.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); break; default: in_1.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, num_thread); - if(beta != 0.0f) + if(beta_ != 0.0f) out_ref.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, num_thread); } - if(beta != 0.0f) + if(beta_ != 0.0f) for(size_t i = 0; i < out_ref.mDesc.GetElementSpaceSize(); i++) out.mData[i] = out_ref.mData[i]; }; @@ -200,7 +200,7 @@ int main(int argc, char* argv[]) in_1_dev.ToDevice(in_1.mData.data()); - if(beta != 0.0f) + if(beta_ != 0.0f) out_dev.ToDevice(out.mData.data()); InElementwiseOperation in_elementwise_op; @@ -246,7 +246,7 @@ int main(int argc, char* argv[]) arrOutStrides, reduceDims, static_cast(alpha), - static_cast(beta), + static_cast(beta_), in_1.mData.data(), nullptr, out_ref.mData.data(), @@ -298,7 +298,7 @@ int main(int argc, char* argv[]) arrOutStrides, reduceDims_2, static_cast(alpha), - static_cast(beta), + static_cast(beta_), in_2_dev.GetDeviceBuffer(), nullptr, out_dev.GetDeviceBuffer(), diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index af46884a90..19593a0f04 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include "ck_tile/host.hpp" #include "flatmm_basic.hpp" diff --git a/example/ck_tile/18_flatmm/grouped_flatmm.cpp b/example/ck_tile/18_flatmm/grouped_flatmm.cpp index 780a21ba14..6b0135d370 100644 --- a/example/ck_tile/18_flatmm/grouped_flatmm.cpp +++ b/example/ck_tile/18_flatmm/grouped_flatmm.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include "flatmm_basic.hpp" diff --git a/example/ck_tile/18_flatmm/run_grouped_flatmm_example.inc b/example/ck_tile/18_flatmm/run_grouped_flatmm_example.inc index 2027544709..f1889168ae 100644 --- a/example/ck_tile/18_flatmm/run_grouped_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_grouped_flatmm_example.inc @@ -166,7 +166,7 @@ int run_contiguous_grouped_flatmm_example_with_layouts( } ck_tile::index_t M = - std::reduce(Ms.begin(), Ms.begin() + group_count, 0, [](auto acc, auto group_m) { + std::accumulate(Ms.begin(), Ms.begin() + group_count, 0, [](auto acc, auto group_m) { // round up to the multiple of BlockM return acc + (group_m + BlockM - 1) / BlockM * BlockM; }); diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index dd7e1abb02..c2954f3bf5 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -35,16 +35,19 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str { static_assert(std::is_same_v); constexpr bool IS_FP8BLOCKSCALE = - QuantMode == ck_tile::QuantType::ABQuantGrouped && BQuantGroupSize::kN == 128 && + QuantMode == ck_tile::QuantType::ABQuantGrouped && (std::is_same_v || std::is_same_v) && (std::is_same_v || std::is_same_v); constexpr bool transpose_c = GemmConfig::TransposeC; constexpr bool eight_warps = - IS_FP8BLOCKSCALE && BQuantGroupSize::kN == 128 && - (GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp == 8) && +#ifdef CK_GFX950_SUPPORT + IS_FP8BLOCKSCALE && (GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp == 8) && GemmConfig::K_Warp_Tile == 128; +#else + false; +#endif using ComputeDataType = std::conditional_t; diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index fed9209bad..ed102c86a8 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -73,7 +73,7 @@ #define CK_TILE_FLOAT_TO_BFLOAT16_RTA_ASM 4 #ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT -#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE +#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_STANDARD #endif #define CK_TILE_FLOAT_TO_FP8_STANDARD 0 diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp index a2f6728316..96ec7bec4a 100644 --- a/include/ck_tile/host/check_err.hpp +++ b/include/ck_tile/host/check_err.hpp @@ -137,7 +137,7 @@ CK_TILE_HOST double get_absolute_threshold(const double max_possible_num, int>::value, "Warning: Unhandled ComputeDataType for setting up the absolute threshold!"); - auto expo = std::log2(std::abs(max_possible_num)); + auto expo = std::floor(std::log2(std::abs(max_possible_num))); double compute_error = 0; if constexpr(is_any_of::value) { @@ -158,7 +158,7 @@ CK_TILE_HOST double get_absolute_threshold(const double max_possible_num, } else { - output_error = std::pow(2, expo - numeric_traits::mant) * 0.5; + output_error = std::pow(2, expo - numeric_traits::mant) * 1.0; } double midway_error = std::max(compute_error, output_error); diff --git a/include/ck_tile/host/device_prop.hpp b/include/ck_tile/host/device_prop.hpp index f28d7df00d..5f021d7bc5 100644 --- a/include/ck_tile/host/device_prop.hpp +++ b/include/ck_tile/host/device_prop.hpp @@ -65,11 +65,7 @@ inline bool is_gfx12_supported() return get_device_name() == "gfx1200" || get_device_name() == "gfx1201"; } -inline bool is_gfx95_supported() -{ - // Check if load transpose is supported. - return get_device_name() == "gfx950"; -} +inline bool is_gfx95_supported() { return get_device_name() == "gfx950"; } inline size_t get_num_cus() { diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index e9a11909c7..b31f8ba02a 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -116,13 +116,13 @@ struct CShuffleEpilogue static constexpr index_t isCTransposed = Problem::isCTransposed; static constexpr bool FixedVectorSize = Problem::FixedVectorSize; static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN; -#ifdef __gfx9__ - static constexpr bool AsyncPipeline = (MWave * NWave == 8); +#ifdef __gfx95__ + static constexpr bool EightWave = (MWave * NWave == 8); #else - static constexpr bool AsyncPipeline = false; + static constexpr bool EightWave = false; #endif static constexpr index_t BlockedXDLN_PerWarp = - AsyncPipeline ? kNPerBlock / NWave / NPerXdl : Problem::BlockedXDLN_PerWarp; + EightWave ? kNPerBlock / NWave / NPerXdl : Problem::BlockedXDLN_PerWarp; static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; static constexpr index_t VectorSizeC = Problem::VectorSizeC; static constexpr index_t MPerIteration = MPerXdl * MWave; @@ -447,7 +447,7 @@ struct CShuffleEpilogue if constexpr(is_950 || is_any_of::value || is_any_of::value) { - if constexpr(AsyncPipeline) + if constexpr(EightWave) { return tile_distribution_encoding< sequence<>, diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp index 42dab68e91..b2b36adb1e 100644 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -780,29 +780,31 @@ struct FlatmmKernel const SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m) { - constexpr int ScaleGranularityM = decltype(kargs.scale_m_ptr)::GranularityMN; - constexpr int ScaleGranularityKA = decltype(kargs.scale_m_ptr)::GranularityK; + constexpr int GM = decltype(kargs.scale_m_ptr)::GranularityMN; + constexpr int GK = decltype(kargs.scale_m_ptr)::GranularityK; - auto scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale - : 1; // per-token scale + static_assert(GM != -1, + "MakeScaleMWindow should only be instantiated when scale is enabled"); + + // per-tensor (GM==0) -> Mdim = 1, stride 0 + const index_t m_dim = (GM == 0) ? 1 : (kargs.M / GM); + const index_t m_stride = (GM == 0) ? 0 : 1; + + const index_t k_dim = (GK == 0) ? 1 : (splitk_batch_offset.splitted_k / GK); + const index_t k_stride = 0; // your original code keeps K stride 0 - // Step 1: Create tensor view const auto scale_m_view = make_naive_tensor_view( kargs.scale_m_ptr.ptr, - make_tuple(kargs.M / ScaleGranularityM, - ScaleGranularityKA == 0 - ? 1 - : (splitk_batch_offset.splitted_k / ScaleGranularityKA)), - make_tuple(scale_stride_m, 0), - number < ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1 > {}, + make_tuple(m_dim, k_dim), + make_tuple(m_stride, k_stride), + number < (GM == 1) ? FlatmmPipeline::GetVectorSizeA() : 1 > {}, number<1>{}); - // Step 2: Create tile window + // Window extents: if GM==0, we still just broadcast from [0,*] return make_tile_window(scale_m_view, make_tuple(number{}, - number < ScaleGranularityKA == 0 - ? TilePartitioner::NPerBlock - : TilePartitioner::KPerBlock > {}), + number < (GK == 0) ? TilePartitioner::NPerBlock + : TilePartitioner::KPerBlock > {}), {block_idx_m, 0}); } @@ -811,27 +813,29 @@ struct FlatmmKernel const SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_n) { - constexpr int ScaleGranularityN = decltype(kargs.scale_n_ptr)::GranularityMN; - constexpr int ScaleGranularityKB = decltype(kargs.scale_n_ptr)::GranularityK; + constexpr int GN = decltype(kargs.scale_n_ptr)::GranularityMN; + constexpr int GK = decltype(kargs.scale_n_ptr)::GranularityK; - auto scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale - : 1; // per-channel scale + static_assert(GN != -1, + "MakeScaleNWindow should only be instantiated when scale is enabled"); + + // per-tensor (GN==0) -> Ndim = 1, stride 0 + const index_t n_dim = (GN == 0) ? 1 : (kargs.N / GN); + const index_t n_stride = (GN == 0) ? 0 : 1; + + const index_t k_dim = (GK == 0) ? 1 : (splitk_batch_offset.splitted_k / GK); + const index_t k_stride = 0; - // Step 1: Create tensor view const auto scale_n_view = make_naive_tensor_view( kargs.scale_n_ptr.ptr, - make_tuple( - ScaleGranularityKB == 0 ? 1 : (splitk_batch_offset.splitted_k / ScaleGranularityKB), - kargs.N / ScaleGranularityN), - make_tuple(0, scale_stride_n), - number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {}, + make_tuple(k_dim, n_dim), + make_tuple(k_stride, n_stride), + number < (GN == 1) ? FlatmmPipeline::GetVectorSizeB() : 1 > {}, number<1>{}); - // Step 2: Create tile window return make_tile_window(scale_n_view, - make_tuple(number < ScaleGranularityKB == 0 - ? TilePartitioner::MPerBlock - : TilePartitioner::KPerBlock > {}, + make_tuple(number < (GK == 0) ? TilePartitioner::MPerBlock + : TilePartitioner::KPerBlock > {}, number{}), {0, block_idx_n}); } @@ -854,8 +858,6 @@ struct FlatmmKernel MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); const auto& b_flat_block_window = MakeBFlatBlockWindow(b_flat_ptr, kargs, block_idx_n); const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); - const auto& scale_m_window = MakeScaleMWindow(kargs, splitk_batch_offset, block_idx_m); - const auto& scale_n_window = MakeScaleNWindow(kargs, splitk_batch_offset, block_idx_n); const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k); @@ -866,6 +868,8 @@ struct FlatmmKernel // Run Epilogue Pipeline with k_batch dispatching if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1) { + const auto& scale_m_window = MakeScaleMWindow(kargs, splitk_batch_offset, block_idx_m); + const auto& scale_n_window = MakeScaleNWindow(kargs, splitk_batch_offset, block_idx_n); if(kargs.k_batch == 1) { auto e_block_window = MakeEBlockWindow( diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr_async.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr_async.hpp index 04d177d4d6..36c911f060 100755 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr_async.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr_async.hpp @@ -76,7 +76,8 @@ struct ABQuantBlockUniversalGemmAsBsCrAsync : public BlockGemmQuantBase static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); static constexpr index_t KIterPerWarp = KPerBlock / (KWarp * WarpGemm::kK); - static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; + static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant; + static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant; static constexpr index_t QScalesPerBlockRow = integer_divide_ceil(KPerBlock / KWarp, BQuantGroupSize::kK); @@ -158,7 +159,8 @@ struct ABQuantBlockUniversalGemmAsBsCrAsync : public BlockGemmQuantBase using BWarpTensor = typename WarpGemm::BWarpTensor; using CWarpTensor = typename WarpGemm::CWarpTensor; - static constexpr bool PreshuffleQuant = Traits::PreshuffleQuant; + static constexpr bool APreshuffleQuant = Traits::APreshuffleQuant; + static constexpr bool BPreshuffleQuant = Traits::BPreshuffleQuant; static_assert(std::is_same_v); @@ -364,7 +366,7 @@ struct ABQuantBlockUniversalGemmAsBsCrAsync : public BlockGemmQuantBase AQPickerCommon aq_picker( aq_block_tensor); - if constexpr(PreshuffleQuant) + if constexpr(BPreshuffleQuant) { constexpr index_t reg_offset = nIter; auto pull_from_lane = diff --git a/test/ck_tile/grouped_gemm/test_grouped_gemm.cpp b/test/ck_tile/grouped_gemm/test_grouped_gemm.cpp index 2f37ef1bc0..0f8e25e03b 100644 --- a/test/ck_tile/grouped_gemm/test_grouped_gemm.cpp +++ b/test/ck_tile/grouped_gemm/test_grouped_gemm.cpp @@ -31,11 +31,6 @@ using KernelTypes = ::testing::Types< std::tuple< Col, Row, Row, F16, F16, F32, F16, True>, std::tuple< Col, Row, Row, F16, F16, F32, F16, False>, - std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, True>, - std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, False>, - std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, True>, - std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, False>, - std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, True>, std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, False>, std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, True>,