diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index e7a218152d..97e719177f 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -20,7 +20,9 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") gemm_aquant_quantgrouped_preshufflequant.cpp gemm_bquant_quantgrouped_bf8i4.cpp gemm_bquant_quantgrouped_fp8i4.cpp - gemm_bquant_quantgrouped_bf16mxfp4.cpp + gemm_bquant_quantgrouped_mx_bf16fp4.cpp + gemm_bquant_quantgrouped_mx_bf16bf8.cpp + gemm_bquant_quantgrouped_mx_bf16bf16.cpp gemm_bquant_quantgrouped_bf8.cpp gemm_bquant_quantgrouped_fp8.cpp gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp diff --git a/example/ck_tile/38_block_scale_gemm/README.md b/example/ck_tile/38_block_scale_gemm/README.md index eb36ae5800..accac6f083 100644 --- a/example/ck_tile/38_block_scale_gemm/README.md +++ b/example/ck_tile/38_block_scale_gemm/README.md @@ -53,7 +53,7 @@ args: -stride_b Tensor B stride (default:0) -stride_c Tensor C stride (default:0) -v 0: No validation, 1: Validation on CPU, 2: Validation on GPU (default:1) - -prec Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, bf8i4, or bf16fp4 (default for both AQuant and Bquant: fp8) + -prec Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, bf8i4, mxbf16bf16, mxbf16bf8 or mxbf16fp4 (default for both AQuant and Bquant: fp8) -warmup Number of iterations before benchmarking the kernel (default:50) -repeat Number of iterations to benchmark the kernel (default:1000) -timer gpu:gpu timer, cpu:cpu timer (default:gpu) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf16.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf16.cpp new file mode 100644 index 0000000000..e1a64c8656 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf16.cpp @@ -0,0 +1,35 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +static auto _ = []() { + auto& lut = get_kernel_lut(); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + + lut[hash_multiple_strings( + {"mxbf16bf16", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x32"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"mxbf16bf16", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf8.cpp new file mode 100644 index 0000000000..0eb2a0ce34 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf8.cpp @@ -0,0 +1,34 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +using GemmConfig = GemmConfigMixedPrecision; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type(arg_parser); + +static auto _ = []() { + auto& lut = get_kernel_lut(); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + + lut[hash_multiple_strings( + {"mxbf16bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"mxbf16bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16fp4.cpp similarity index 67% rename from example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp rename to example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16fp4.cpp index b8eb670135..1f48609a1f 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16fp4.cpp @@ -6,33 +6,33 @@ template using GemmConfig = GemmConfigQuantPrefill; -#define RUN_GEMM_EXAMPLE_PREC_TYPE \ - run_gemm_example_prec_type, \ - TypeConfig, \ - QuantGroupSize, \ +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ ck_tile::QuantType::BQuantGrouped>(arg_parser); static auto _ = []() { auto& lut = get_kernel_lut(); using TypeConfig = decltype(GemmQuantTypeConfig{}); + ck_tile::e8m0_t>{}); lut[hash_multiple_strings( - {"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x32"})] = + {"mxbf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x32"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; lut[hash_multiple_strings( - {"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] = + {"mxbf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; lut[hash_multiple_strings( - {"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] = + {"mxbf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp index cc4302a992..dc4d1ad814 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -32,7 +32,7 @@ auto create_args(int argc, char* argv[]) .insert("prec", "fp8", "Data type. 
For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, " - "or bf8i4; for ABQuant: fp8, bf8, fp4") + " mxbf16bf16, mxbf16bf8, mxbf16fp4 or bf8i4; for ABQuant: fp8, bf8, fp4") .insert("warmup", "50", "Number of iterations before benchmarking the kernel") .insert("repeat", "1000", "Number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 9a51c786b6..db3f4c6e17 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -45,7 +45,7 @@ auto calculate_rtol_atol(const ck_tile::index_t K, const float max_accumulated_value) { using ComputeType = std::conditional_t< - std::is_same_v, + std::is_same_v, ADataType, std::conditional_t>; // Calculate thresholds @@ -278,6 +278,24 @@ struct GemmConfigABQuantPrefill : public GemmConfigQuantPrefill static constexpr bool TransposeC = true; }; +// Used for A=16bit and B=8bit. The warp tile has KPack=16 +// Matrix A: Vectorsize = 8, KPack=16 -> LDS read/write vectorsize = 8 (128bit) +// Matrix B: Vectorsize = 16, KPack=16 -> LDS read/write vectorsize = 16 (128bit) +struct GemmConfigMixedPrecision : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 64; +}; + template struct GemmConfigEightWarps : public GemmConfigABQuantPrefill { diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index c2954f3bf5..da14f85c2c 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -108,6 +108,10 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { constexpr bool has_hot_loop_v = has_hot_loop_.value; constexpr auto tail_number_v = tail_number_.value; + constexpr auto b_cast_policy = + std::is_same_v + ? ck_tile::CastPolicy::BeforeLDSWrite + : ck_tile::CastPolicy::AfterLDSRead; // row-col and tensor quants use the regular pipeline, A/B/AB quants use their own using PipelineProblem = std::conditional_t< @@ -150,7 +154,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str ComputeDataType, GemmConfig::Scheduler, has_hot_loop_v, - tail_number_v>, + tail_number_v, + b_cast_policy>, ck_tile::GemmABQuantPipelineProblem, - std::conditional_t< - std::is_same_v, - ck_tile::MxFp4GemmPipelineAgBgCrCompV3, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>; + std::conditional_t, + ck_tile::MicroscaleGemmPipelineAgBgCrCompV3, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>>; using ABQuantPipeline = std::conditional_t< eight_warps, @@ -257,11 +261,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( args.M, args.K, args.stride_A, is_row_major(ALayout{}))); ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - std::is_same_v ? 
args.K / 2 - : args.K, - args.N, - args.stride_B, - is_row_major(BLayout{}))); + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); auto size_a_buffer = a_m.get_element_space_size_in_bytes(); auto size_b_buffer = b_n.get_element_space_size_in_bytes(); @@ -495,11 +495,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, int rotating_count = arg_parser.get_int("rotating_count"); stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); - stride_B = ck_tile::get_default_stride( - (std::is_same_v) ? (K / 2) : K, - N, - stride_B, - is_row_major(b_layout)); + stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); // Conditional stride calculation based on QuantMode @@ -531,11 +527,8 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, ck_tile::HostTensor a_m_k( ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); - ck_tile::HostTensor b_k_n(ck_tile::host_tensor_descriptor( - (std::is_same_v) ? (K / 2) : K, - N, - stride_B, - is_row_major(b_layout))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); ck_tile::HostTensor c_m_n_dev_result( ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); @@ -575,18 +568,31 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, if constexpr(std::is_same_v) { ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); } - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) { ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{125.f, 130.f, fill_seed(gen)}( - *bq_tensor_ptr); } else { ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); + } + + if constexpr(std::is_same_v) + { + auto gen_scales = [&](auto& scales, float range_min, float range_max) { + // e8m0_t is basically an exponent of float32 + ck_tile::HostTensor pow2(scales.get_lengths()); + ck_tile::FillUniformDistributionIntegerValue{ + range_min, range_max, fill_seed(gen)}(pow2); + scales.ForEach([&](auto& self, const auto& i) { + self(i) = static_cast(std::exp2(pow2(i))); + }); + }; + gen_scales(*bq_tensor_ptr, -2, 2); + } + else + { ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( *bq_tensor_ptr); } @@ -850,18 +856,19 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { - if constexpr(std::is_same_v) - ck_tile::reference_mxfp4gemm_quant( + if constexpr(std::is_same_v) + ck_tile::reference_mx_gemm_bquant( a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref); else ck_tile::reference_gemm_quant) && + std::is_same_v) && GemmConfig::PreshuffleB) { throw std::runtime_error( diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index c165cacba2..c74068c03c 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -2865,6 +2865,12 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr) auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_i32x2_t*>(in_ptr_); return bit_cast>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr)); } + else if constexpr(std::is_same_v, 
ck_tile::pk_fp4_t>) + { + typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_i32x2_t; + auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_i32x2_t*>(in_ptr_); + return bit_cast>(__builtin_amdgcn_ds_read_tr4_b64_v2i32(lds_ptr)); + } else { static_assert(false, "not implemented"); diff --git a/include/ck_tile/core/tensor/load_tile_transpose.hpp b/include/ck_tile/core/tensor/load_tile_transpose.hpp index 2d71a9cfab..462a9cf4ab 100644 --- a/include/ck_tile/core/tensor/load_tile_transpose.hpp +++ b/include/ck_tile/core/tensor/load_tile_transpose.hpp @@ -50,60 +50,61 @@ constexpr bool is_sequence_suffix_v = is_sequence_suffix::valu template struct DefaultTranspose { - template - struct Quad16 + template + struct Quad { static_assert(LaneGroupSize == 64 || LaneGroupSize == 32 || LaneGroupSize == 16, "LaneGroupSize must be 64, 32, or 16"); - using InputEncoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<2>>; - using OutputEncoding = - tile_distribution_encoding, - tuple, sequence<4>>, - tuple>, - tuple>, - sequence<2>, - sequence<0>>; + // The tile is defined by the LaneGroupSize, which defines the number of lanes in the M/N + // dimensions for the MMA instruction defined by warp gemm. + // The LaneGroupSize is subdivided into groups of 16 (the finer granularity of MMA + // instructions); we define these as major subtiles. Each of these major subtiles is divided + // into minor subtiles, which group the lanes exchanging data during the transpose. + // Example: LaneGroupSize = 16, 16-bit type: + // - There is 1 group of 16 lanes (1 major subtile) + // - Each major subtile is divided into 4 minor subtiles of (4x4) -> 4 lanes transpose + // each minor subtile and each lane holds 4 elements + + // All load-transpose instructions currently use 64 bits + static constexpr index_t InstructionBits = 64; + // The major subtile dimension is fixed + static constexpr index_t SubtileMajorDimension = 16; + // Number of major subtiles + static constexpr index_t NumSubtilesMajor = LaneGroupSize / 16; + // Number of elements loaded by each lane with a single instruction, but also the number + // of consecutive lanes in a subtile. The subtile is square (NLanes x NElementsPerLane)
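+ // Illustrative arithmetic (editor's note, not part of the original patch): for an
+ // 8-bit element type, SubtileMinorDimension = 64 / 8 = 8, so one transpose-load gives
+ // each lane 8 consecutive elements and 8 lanes form one 8x8 minor subtile
+ // (NumSubtilesMinor = 16 / 8 = 2 per major subtile); for the 16-bit example above the
+ // same arithmetic yields 4x4 minor subtiles (NumSubtilesMinor = 4).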
+ static constexpr index_t SubtileMinorDimension = InstructionBits / NumBitType; + // Number of minor subtiles inside each major subtile + static constexpr index_t NumSubtilesMinor = 16 / SubtileMinorDimension; + + using InputEncoding = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<2>>; + + using OutputEncoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<0>>; }; - template - struct Quad8 - { - static_assert(LaneGroupSize == 64 || LaneGroupSize == 32 || LaneGroupSize == 16, - "LaneGroupSize must be 64, 32, or 16"); - using InputEncoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<2>>; - - using OutputEncoding = - tile_distribution_encoding, - tuple, sequence<8>>, - tuple>, - tuple>, - sequence<2>, - sequence<0>>; - }; + static constexpr index_t PackedSize = numeric_traits>::PackedSize; + static constexpr index_t NumBitsDataType = (sizeof(DataType) * 8) / PackedSize; // Select based on data size template - using QuadInputEncoding = std::conditional_t::InputEncoding, - typename Quad8::InputEncoding>; + using QuadInputEncoding = typename Quad::InputEncoding; template - using QuadOutputEncoding = std::conditional_t::OutputEncoding, - typename Quad8::OutputEncoding>; + using QuadOutputEncoding = typename Quad::OutputEncoding; // Always swap last two dimensions static constexpr auto transpose_dims = sequence<1, 0>{}; diff --git a/include/ck_tile/core/tensor/static_distributed_tensor.hpp b/include/ck_tile/core/tensor/static_distributed_tensor.hpp index bdd81dae07..787c17e1be 100644 --- a/include/ck_tile/core/tensor/static_distributed_tensor.hpp +++ b/include/ck_tile/core/tensor/static_distributed_tensor.hpp @@ -78,7 +78,7 @@ struct static_distributed_tensor constexpr auto sliced_thread_tensor_desc = make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...)); - thread_buffer + thread_buffer sliced_thread_data; static_ford>{}([&](auto idx) { diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp index 95d66b66ed..1d008b495b 100644 --- a/include/ck_tile/core/tensor/tensor_view.hpp +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -287,8 +287,8 @@ struct tensor_view get_transpose_vectorized_elements(const TensorCoord& coord, index_t linear_offset) const { return buf_.template transpose_get( - coord.get_offset(), - linear_offset, + coord.get_offset() / PackedSize, + linear_offset / PackedSize, coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord)); } @@ -303,7 +303,8 @@ struct tensor_view bool is_valid_element // flag ) const { - return buf_.template transpose_get(coord.get_offset(), linear_offset, is_valid_element); + return buf_.template transpose_get( + coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element); } // X is vector of DataType. // "coord" is coordinate of DataType, not X.
"coord" should be aligned to X diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index ba7eeb1936..2f2fe12f42 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -736,7 +736,7 @@ struct tile_window_with_static_distribution .template get_transpose_vectorized_elements( bottom_tensor_thread_coord, offset); // write into distributed tensor - static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { + static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { constexpr auto orig_idx_ys = generate_tuple( [&](auto jj) { return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) @@ -747,10 +747,12 @@ struct tile_window_with_static_distribution constexpr auto grouped_idx_ys = group_func(orig_idx_ys); constexpr index_t linear_distributed_index = - tile_dstr.get_ys_to_d_descriptor().calculate_offset(grouped_idx_ys); + tile_dstr.get_ys_to_d_descriptor().calculate_offset(grouped_idx_ys) / + Traits::PackedSize; dst_tensor.get_thread_buffer().template at() = - vec_value.template get_as()[j]; + vec_value + .template get_as()[j / Traits::PackedSize]; }); // move thread coordinate if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index da6b074b98..b6d7fbf521 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -388,49 +388,56 @@ template -CK_TILE_HOST void reference_mxfp4gemm_quant(const HostTensor& a_m_k, - const HostTensor& q, - const HostTensor& b_k_n, - HostTensor& c_m_n, - const AElementOp& a_element_op = {}, - const BElementOp& b_element_op = {}, - const ACCElementOp& acc_element_op = {}) +CK_TILE_HOST void reference_mx_gemm_bquant(const HostTensor& a_m_k, + const HostTensor& q, + const HostTensor& b_k_n, + HostTensor& c_m_n, + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const ACCElementOp& acc_element_op = {}) { const std::size_t M = a_m_k.get_length(0); const std::size_t N = b_k_n.get_length(1); const std::size_t K = a_m_k.get_length(1); auto f_mn = [&](auto m, auto n) { - AccDataType v_acc = 0; - AccDataType pasual = 0; - for(std::size_t k = 0; k < (K / 2); k++) - { - using ComputeType = float; - auto b_scale = type_convert(q((2 * k) / QuantGroupSize::kK, n)) - 127; - ComputeType v_a_0, v_a_1; - ComputeType v_b_0, v_b_1; + AccDataType v_acc = 0; + using ComputeType = float; + ComputeType v_a; + ComputeType v_b; - v_a_0 = ck_tile::type_convert((a_element_op(a_m_k(m, 2 * k)))); - v_a_1 = ck_tile::type_convert((a_element_op(a_m_k(m, 2 * k + 1)))); - - if constexpr(std::is_same_v) + auto load_b = [&](std::size_t k) -> AccDataType { + if constexpr(std::is_same_v) { - auto b_pack = type_convert(b_element_op(b_k_n(k, n))); - auto b_scale_fp4 = type_convert(std::pow(2.0f, b_scale)); - - auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); - auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); - - v_b_0 = type_convert(b_f4_lo) * b_scale_fp4; - v_b_1 = type_convert(b_f4_hi) * b_scale_fp4; + const auto b_pack = type_convert(b_element_op(b_k_n(k, n))); + if constexpr(std::is_same_v) + { + return (n & 1) ? type_convert(b_pack.unpack(number<1>{})) + : type_convert(b_pack.unpack(number<0>{})); + } + else + { + return (k & 1) ? 
type_convert(b_pack.unpack(number<1>{})) + : type_convert(b_pack.unpack(number<0>{})); + } } + else + { + return ck_tile::type_convert(b_element_op(b_k_n(k, n))); + } + }; - pasual = v_a_0 * v_b_0 + v_a_1 * v_b_1; - v_acc += pasual; + for(std::size_t k = 0; k < K; k++) + { + const auto b_scale = type_convert(q(k / QuantGroupSize::kK, n)); + v_a = ck_tile::type_convert(a_element_op(a_m_k(m, k))); + v_b = load_b(k) * b_scale; + v_acc += v_a * v_b; } c_m_n(m, n) = ck_tile::type_convert(acc_element_op(v_acc)); }; diff --git a/include/ck_tile/ops/common/utils.hpp b/include/ck_tile/ops/common/utils.hpp index 4a30e3af16..6c1287486f 100644 --- a/include/ck_tile/ops/common/utils.hpp +++ b/include/ck_tile/ops/common/utils.hpp @@ -24,6 +24,7 @@ template <> struct DataTypeTraits { static constexpr const char * nam template <> struct DataTypeTraits { static constexpr const char * name = "pk_fp4"; }; template <> struct DataTypeTraits { static constexpr const char * name = "pk_fp6x16"; }; template <> struct DataTypeTraits { static constexpr const char * name = "pk_fp4_raw"; }; +template <> struct DataTypeTraits { static constexpr const char * name = "e8m0"; }; template struct memOpToStr; template <> struct memOpToStr { static constexpr const char * name = "set"; }; diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index 3f58eceb33..4ad699629c 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -359,6 +359,260 @@ CK_TILE_DEVICE bf8x4_t i4_to_bf8x4(int q) } #endif +CK_TILE_HOST_DEVICE bf16x8_t bf8x8_to_bf16x8_scale(const bf8x8_t& src, const float& scale) +{ + bf16x8_t y; +#if defined(__gfx950__) + constexpr index_t USE_BOTTOM = 0; + constexpr index_t USE_TOP = 1; + + auto convert_quartet = [&](index_t src_offset, index_t dst_offset) { + union + { + uint32_t packed; + bf8_t elements[4]; + } input; + + union + { + bf16x2_t vec; + bf16_t elements[2]; + } output; + + input.elements[0] = src[src_offset]; + input.elements[1] = src[src_offset + 1]; + input.elements[2] = src[src_offset + 2]; + input.elements[3] = src[src_offset + 3]; + + output.vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(input.packed, scale, USE_BOTTOM); + y[dst_offset] = output.elements[0]; + y[dst_offset + 1] = output.elements[1]; + + output.vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(input.packed, scale, USE_TOP); + y[dst_offset + 2] = output.elements[0]; + y[dst_offset + 3] = output.elements[1]; + }; + + convert_quartet(0, 0); + convert_quartet(4, 4); +#else + static_for<0, 8, 1>{}([&](auto i) { + y[i.value] = type_convert(type_convert(src[i.value]) * scale); + }); +#endif + return y; +} + +CK_TILE_HOST_DEVICE bf16x8_t fp8x8_to_bf16x8_scale(const fp8x8_t& src, const float& scale) +{ + bf16x8_t y; +#if defined(__gfx950__) + constexpr index_t USE_BOTTOM = 0; + constexpr index_t USE_TOP = 1; + + auto convert_quartet = [&](index_t src_offset, index_t dst_offset) { + union + { + uint32_t packed; + fp8_t elements[4]; + } input; + + union + { + bf16x2_t vec; + bf16_t elements[2]; + } output; + + input.elements[0] = src[src_offset]; + input.elements[1] = src[src_offset + 1]; + input.elements[2] = src[src_offset + 2]; + input.elements[3] = src[src_offset + 3]; + + output.vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp8(input.packed, scale, USE_BOTTOM); + y[dst_offset] = output.elements[0]; + y[dst_offset + 1] = output.elements[1]; + + output.vec = 
__builtin_amdgcn_cvt_scalef32_pk_bf16_fp8(input.packed, scale, USE_TOP); + y[dst_offset + 2] = output.elements[0]; + y[dst_offset + 3] = output.elements[1]; + }; + + convert_quartet(0, 0); + convert_quartet(4, 4); +#else + static_for<0, 8, 1>{}([&](auto i) { + y[i.value] = type_convert(type_convert(src[i.value]) * scale); + }); +#endif + return y; +} + +CK_TILE_HOST_DEVICE fp16x8_t fp8x8_to_fp16x8_scale(const fp8x8_t& src, const float& scale) +{ + fp16x8_t y; +#if defined(__gfx950__) + constexpr index_t USE_BOTTOM = 0; + constexpr index_t USE_TOP = 1; + + auto convert_quartet = [&](index_t src_offset, index_t dst_offset) { + union + { + uint32_t packed; + fp8_t elements[4]; + } input; + + union + { + fp16x2_t vec; + fp16_t elements[2]; + } output; + + input.elements[0] = src[src_offset]; + input.elements[1] = src[src_offset + 1]; + input.elements[2] = src[src_offset + 2]; + input.elements[3] = src[src_offset + 3]; + + output.vec = __builtin_amdgcn_cvt_scalef32_pk_f16_fp8(input.packed, scale, USE_BOTTOM); + y[dst_offset] = output.elements[0]; + y[dst_offset + 1] = output.elements[1]; + + output.vec = __builtin_amdgcn_cvt_scalef32_pk_f16_fp8(input.packed, scale, USE_TOP); + y[dst_offset + 2] = output.elements[0]; + y[dst_offset + 3] = output.elements[1]; + }; + + convert_quartet(0, 0); + convert_quartet(4, 4); +#else + static_for<0, 8, 1>{}([&](auto i) { + y[i.value] = type_convert(type_convert(src[i.value]) * scale); + }); +#endif + return y; +} + +CK_TILE_HOST_DEVICE fp16x8_t bf8x8_to_fp16x8_scale(const bf8x8_t& src, const float& scale) +{ + fp16x8_t y; +#if defined(__gfx950__) + constexpr index_t USE_BOTTOM = 0; + constexpr index_t USE_TOP = 1; + + auto convert_quartet = [&](index_t src_offset, index_t dst_offset) { + union + { + uint32_t packed; + bf8_t elements[4]; + } input; + + union + { + fp16x2_t vec; + fp16_t elements[2]; + } output; + + input.elements[0] = src[src_offset]; + input.elements[1] = src[src_offset + 1]; + input.elements[2] = src[src_offset + 2]; + input.elements[3] = src[src_offset + 3]; + + output.vec = __builtin_amdgcn_cvt_scalef32_pk_f16_bf8(input.packed, scale, USE_BOTTOM); + y[dst_offset] = output.elements[0]; + y[dst_offset + 1] = output.elements[1]; + + output.vec = __builtin_amdgcn_cvt_scalef32_pk_f16_bf8(input.packed, scale, USE_TOP); + y[dst_offset + 2] = output.elements[0]; + y[dst_offset + 3] = output.elements[1]; + }; + + convert_quartet(0, 0); + convert_quartet(4, 4); +#else + static_for<0, 8, 1>{}([&](auto i) { + y[i.value] = type_convert(type_convert(src[i.value]) * scale); + }); +#endif + return y; +} + +CK_TILE_HOST_DEVICE bf16x8_t fp4x4_to_bf16x8_scale(const pk_fp4x4_t& src, const float& scale) +{ + bf16x8_t y; +#if defined(__gfx950__) + union + { + uint32_t u32; + pk_fp4x4_t pf4; + } cvt; + + constexpr index_t USE_BYTE_0 = 0; + constexpr index_t USE_BYTE_1 = 1; + constexpr index_t USE_BYTE_2 = 2; + constexpr index_t USE_BYTE_3 = 3; + + cvt.pf4 = src; + bf16x2_t y0 = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(cvt.u32, scale, USE_BYTE_0); + bf16x2_t y1 = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(cvt.u32, scale, USE_BYTE_1); + bf16x2_t y2 = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(cvt.u32, scale, USE_BYTE_2); + bf16x2_t y3 = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(cvt.u32, scale, USE_BYTE_3); + + y[0] = y0[0]; + y[1] = y0[1]; + y[2] = y1[0]; + y[3] = y1[1]; + y[4] = y2[0]; + y[5] = y2[1]; + y[6] = y3[0]; + y[7] = y3[1]; +#else + static_for<0, 4, 1>{}([&](auto i) { + auto yi = pk_fp4_to_bf16x2(src[i.value], scale); + y[2 * i.value] = yi[0]; + y[2 * 
i.value + 1] = yi[1]; + }); +#endif + return y; +} + +CK_TILE_HOST_DEVICE fp16x8_t fp4x4_to_fp16x8_scale(const pk_fp4x4_t& src, const float& scale) +{ + fp16x8_t y; +#if defined(__gfx950__) + union + { + uint32_t u32; + pk_fp4x4_t pf4; + } cvt; + + constexpr index_t USE_BYTE_0 = 0; + constexpr index_t USE_BYTE_1 = 1; + constexpr index_t USE_BYTE_2 = 2; + constexpr index_t USE_BYTE_3 = 3; + + cvt.pf4 = src; + fp16x2_t y0 = __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(cvt.u32, scale, USE_BYTE_0); + fp16x2_t y1 = __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(cvt.u32, scale, USE_BYTE_1); + fp16x2_t y2 = __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(cvt.u32, scale, USE_BYTE_2); + fp16x2_t y3 = __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(cvt.u32, scale, USE_BYTE_3); + + y[0] = y0[0]; + y[1] = y0[1]; + y[2] = y1[0]; + y[3] = y1[1]; + y[4] = y2[0]; + y[5] = y2[1]; + y[6] = y3[0]; + y[7] = y3[1]; +#else + static_for<0, 4, 1>{}([&](auto i) { + auto yi = pk_fp4_to_fp16x2(src[i.value], scale); + y[2 * i.value] = yi[0]; + y[2 * i.value + 1] = yi[1]; + }); +#endif + return y; +} + struct PassThroughPack8 { static constexpr const char* name = "PassThroughPack8"; @@ -437,6 +691,50 @@ struct DequantPack8 y.hi = i4_to_half4_scale(bit_cast(x) >> 8, z); } + CK_TILE_HOST_DEVICE constexpr void + operator()(bf16x8_t& y, const pk_fp4x4_t& x, const float& z) const + { + y = fp4x4_to_bf16x8_scale(x, z); + } + + CK_TILE_HOST_DEVICE constexpr void + operator()(fp16x8_t& y, const pk_fp4x4_t& x, const float& z) const + { + y = fp4x4_to_fp16x8_scale(x, z); + } + + CK_TILE_HOST_DEVICE constexpr void + operator()(bf16x8_t& y, const bf8x8_t& x, const float& z) const + { + y = bf8x8_to_bf16x8_scale(x, z); + } + + CK_TILE_HOST_DEVICE constexpr void + operator()(bf16x8_t& y, const fp8x8_t& x, const float& z) const + { + y = fp8x8_to_bf16x8_scale(x, z); + } + + CK_TILE_HOST_DEVICE constexpr void + operator()(fp16x8_t& y, const fp8x8_t& x, const float& z) const + { + y = fp8x8_to_fp16x8_scale(x, z); + } + + CK_TILE_HOST_DEVICE constexpr void + operator()(fp16x8_t& y, const bf8x8_t& x, const float& z) const + { + y = bf8x8_to_fp16x8_scale(x, z); + } + + CK_TILE_HOST_DEVICE constexpr void + operator()(bf16x8_t& y, const bf16x8_t& x, const float& z) const + { + static_for<0, 8, 1>{}([&](auto i) { + y[i.value] = type_convert(type_convert(x[i.value]) * z); + }); + } + constexpr const static bool is_pack8_invocable = true; }; diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index b31f8ba02a..7ebfa412f7 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -99,7 +99,7 @@ struct CShuffleEpilogue // Used for weight-only quantization kernel, B would be dequantized to the same data type as A using BTypeToUse = std::conditional_t || std::is_same_v || - std::is_same_v, + sizeof(BDataType) < sizeof(ADataType), ADataType, BDataType>; diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index 79030fcd51..7f34ae24bb 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -97,7 +97,8 @@ struct BlockUniversalGemmAsBsCr using ATypeToUse = std::conditional_t, BDataType, ADataType>; using BTypeToUse = std::conditional_t || - std::is_same_v, + std::is_same_v || + sizeof(BDataType) < sizeof(ADataType), ADataType, BDataType>; diff --git 
a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index 4973d9c941..7cc14ecc39 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -20,8 +20,23 @@ struct GemmPipelineAgBgCrImplBase using ADataType = remove_cvref_t{}, AsDataType>>; using ALayout = remove_cvref_t{}, AsLayout>>; using BInDataType = remove_cvref_t{}, BsDataType>>; - using BDataType = - std::conditional_t, ADataType, BInDataType>; + + template + using has_bcastpolicy_type = decltype(T::BCastPolicy); + + static constexpr bool IsBCastPolicyBeforeLDSWrite = [] { + if constexpr(is_detected{}) + { + return Problem::BCastPolicy == CastPolicy::BeforeLDSWrite; + } + else + { + return false; + } + }(); + + using BDataType = std::conditional_t; + using BLayout = remove_cvref_t{}, BsLayout>>; static constexpr index_t MPerBlock = BlockGemmShape::kM; @@ -226,6 +241,12 @@ struct GemmPipelineAgBgCrImplBase CK_TILE_DEVICE constexpr auto MakeALdsWindows(const ALdsTensorView& a_lds_block_view, const ALdsLoadTileDistr&) const { + // with pk_int4_t load transpose the LDS type is always BDataType + using ADataTypeLDS = + std::conditional_t, + typename Problem::BDataType, + typename Problem::ADataType>; + auto a_lds_shape = []() { if constexpr(is_a_load_tr) return make_tuple(number{}, number{}); @@ -238,9 +259,8 @@ struct GemmPipelineAgBgCrImplBase auto a_lds_load_tile_distr = []() { if constexpr(is_a_load_tr) return make_static_tile_distribution( - typename InputTileDistributionTraits< - typename ALdsLoadTileDistr::DstrEncode, - typename Problem::ADataType>::TransposedDstrEncode{}); + typename InputTileDistributionTraits::TransposedDstrEncode{}); else return ALdsLoadTileDistr{}; }(); @@ -313,10 +333,9 @@ struct GemmPipelineAgBgCrImplBase auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0}); - using BLdsDataType = - std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>; + using BLdsDataType = std::conditional_t; auto b_lds_load_tile_distr = []() { if constexpr(is_b_load_tr) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp index 987704e433..f9d82f8eb4 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp @@ -10,6 +10,12 @@ namespace ck_tile { +enum struct CastPolicy +{ + BeforeLDSWrite, + AfterLDSRead, +}; + enum struct GemmPipelineScheduler { Default, diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 8074994fdd..cb112a11a7 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -80,6 +80,21 @@ struct UniversalGemmBasePolicy static constexpr bool is_b_load_tr = false; #endif + template + using has_bcastpolicy_type = decltype(T::BCastPolicy); + + template + static constexpr bool IsBCastPolicyBeforeLDSWrite_v = [] { + if constexpr(is_detected{}) + { + return Problem::BCastPolicy == CastPolicy::BeforeLDSWrite; + } + else + { + return false; + } + }(); + static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; static constexpr auto I2 = 
number<2>{}; @@ -305,11 +320,11 @@ struct UniversalGemmBasePolicy template CK_TILE_DEVICE static constexpr auto MakeBLdsBlockDescriptor() { - using BLayout = remove_cvref_t; - using BDataType = - std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>; + using BLayout = remove_cvref_t; + constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v; + using BDataType = std::conditional_t; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; @@ -589,15 +604,14 @@ struct UniversalGemmBasePolicy CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeB() { using BsLayout = remove_cvref_t; - using BsDataType = remove_cvref_t; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; using BLayout = remove_cvref_t{}, BsLayout>>; - using BInDataType = remove_cvref_t{}, BsDataType>>; - using BDataType = std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>; + constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v; + using BDataType = std::conditional_t; if constexpr(Problem::FixedVectorSize) { @@ -739,13 +753,13 @@ struct UniversalGemmBasePolicy { constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - using BDataType = remove_cvref_t; - constexpr index_t KPerBlock = std::is_same_v - ? Problem::BlockGemmShape::kK / 2 - : Problem::BlockGemmShape::kK; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + // If we cast before writing to LDS, the vectorsize is defined by the A type + // since the assumption is that A type is going to be the B LDS type + constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v; constexpr index_t VecLoadSize = - std::is_same_v - ? 4 + IsBCastPolicyBeforeLDSWrite + ? (Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA()) : (Problem::FixedVectorSize ? 
Problem::VectorSizeB : GetVectorSizeB()); constexpr index_t NumWaveGroups = Problem::NumWaveGroups; using BLayout = remove_cvref_t< @@ -855,10 +869,10 @@ struct UniversalGemmBasePolicy template CK_TILE_DEVICE static constexpr index_t GetSmemSizeB() { - using BDataType = - std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>; + constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v; + using BDataType = std::conditional_t; constexpr auto b_lds_block_desc = Derived::template MakeBLdsBlockDescriptor(); constexpr index_t smem_size_b = integer_least_multiple( b_lds_block_desc.get_element_space_size() * sizeof(BDataType), 16); @@ -900,7 +914,8 @@ struct UniversalGemmPipelineAgBgCrPolicy using ATypeToUse = std::conditional_t, BDataType, ADataType>; using BTypeToUse = std::conditional_t || - std::is_same_v, + std::is_same_v || + sizeof(BDataType) < sizeof(ADataType), ADataType, BDataType>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 0051242475..f3fa99304c 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -185,16 +185,35 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl +template using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl< WarpGemmAttributeMfma, - AttrNumAccess>>; + AttrNumAccessA, + AttrNumAccessB>>; + +template +using WarpGemmMfmaBf16Bf16F32M16N16K64 = WarpGemmImpl, + 2, + AttrNumAccessA, + AttrNumAccessB>>; #else -template +template using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl, 2, - AttrNumAccess>>; + AttrNumAccessA>>; + +template +using WarpGemmMfmaBf16Bf16F32M16N16K64 = WarpGemmImpl, + 4, + AttrNumAccessA, + AttrNumAccessB>>; #endif using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl +struct get_wgattr_num_access +{ + private: + static constexpr index_t getAccesses() + { + if constexpr(AttrNumAccess == WGAttrNumAccessEnum::Single) + { + return 1; + } + else if constexpr(AttrNumAccess == WGAttrNumAccessEnum::Double) + { + return 2; + } + else if constexpr(AttrNumAccess == WGAttrNumAccessEnum::Quad) + { + return 4; + } + else + { + static_assert(false, "unsupported AttrNumAccess"); + return 0; + } + } + + public: + static constexpr auto value = getAccesses(); +}; + template + WGAttrNumAccessEnum AttrNumAccessA_ = WGAttrNumAccessEnum::Single, + WGAttrNumAccessEnum AttrNumAccessB_ = AttrNumAccessA_> struct WarpGemmAttributeMfma { - using Impl = remove_cvref_t; - static constexpr auto AttrNumAccess = AttrNumAccess_; - static constexpr auto AttrNumAccessV = static_cast(AttrNumAccess); + using Impl = remove_cvref_t; + static constexpr auto AttrNumAccessA = AttrNumAccessA_; + static constexpr auto AttrNumAccessAV = get_wgattr_num_access::value; + static constexpr auto AttrNumAccessB = AttrNumAccessB_; + static constexpr auto AttrNumAccessBV = get_wgattr_num_access::value; + + static constexpr bool UsePackNumAccess = AttrNumAccessA != AttrNumAccessB; using ADataType = typename Impl::ADataType; using BDataType = typename Impl::BDataType; @@ -44,12 +78,13 @@ struct WarpGemmAttributeMfma static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, "Multi-block WarpGemmAttributeMfmaImpl is not supported"); - template + template static constexpr auto get_warp_dstr_encoding() { - static_assert(kKPerThread % AttrNumAccessV == 0, + static_assert(kKPerThread % AttrNumAccessV_ == 0, "kKPerThread must be divisible by NumAccess"); - if constexpr(AttrNumAccessV == 1) + if constexpr(AttrNumAccessV_ == 1) + { 
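+ // Single access: the per-lane K elements stay in one contiguous dimension of the
+ // encoding below, so no split across accesses is needed (editor's note).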
return tile_distribution_encoding< sequence<>, tuple, sequence>, @@ -57,18 +92,48 @@ struct WarpGemmAttributeMfma tuple>, sequence<2>, sequence<1>>{}; + } else - return tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2, 2>, - sequence<0, 2>>{}; + { + // AttrNumAccess splits the kABKPerLane + // We can split them but still have them contiguous (packed) or have them interleaved. + // The reason to split the dimension but still have it packed is to match load transpose + // encoding when A and B use different AttrNumAccess (they have different types in LDS) + // Example + // A: 16bit, B: 8bit + // Load transpose B: lane0 -> K=0..7 (only 1 instruction) + // Load transpose A: lane0 -> K=0..3 first instruction, K=4..7 second instruction + // In this way the data in register are consistent between A and B + if constexpr(UsePackNumAccess) + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<1, 2>>{}; + } + else + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>{}; + } + } } - using AWarpDstrEncoding = decltype(get_warp_dstr_encoding()); - using BWarpDstrEncoding = decltype(get_warp_dstr_encoding()); + using AWarpDstrEncoding = decltype(get_warp_dstr_encoding()); + using BWarpDstrEncoding = decltype(get_warp_dstr_encoding()); using CWarpDstrEncoding = tile_distribution_encoding< sequence<>, @@ -121,14 +186,19 @@ struct WarpGemmAttributeMfma template + WGAttrNumAccessEnum AttrNumAccessA_ = WGAttrNumAccessEnum::Single, + WGAttrNumAccessEnum AttrNumAccessB_ = AttrNumAccessA_> struct WarpGemmAttributeMfmaIterateK { static_assert(kKIter > 0, "wrong!"); - using Impl = remove_cvref_t; - static constexpr auto AttrNumAccess = AttrNumAccess_; - static constexpr auto AttrNumAccessV = static_cast(AttrNumAccess); + using Impl = remove_cvref_t; + static constexpr auto AttrNumAccessA = AttrNumAccessA_; + static constexpr auto AttrNumAccessAV = get_wgattr_num_access::value; + static constexpr auto AttrNumAccessB = AttrNumAccessB_; + static constexpr auto AttrNumAccessBV = get_wgattr_num_access::value; + + static constexpr bool UsePackNumAccess = AttrNumAccessA != AttrNumAccessB; using ADataType = typename Impl::ADataType; using BDataType = typename Impl::BDataType; @@ -151,14 +221,15 @@ struct WarpGemmAttributeMfmaIterateK static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1, "Multi-block on both M & N directions is not supported"); - template + template CK_TILE_DEVICE static constexpr auto get_warp_dstr_encoding() { if constexpr(kMNBlock == 1 && kNMBlock == 1) { - static_assert(kKPerThread % AttrNumAccessV == 0, + static_assert(kKPerThread % AttrNumAccessV_ == 0, "kKPerThread must be divisible by NumAccess"); - if constexpr(AttrNumAccessV == 1) + if constexpr(AttrNumAccessV_ == 1) + { return tile_distribution_encoding< sequence<>, tuple, sequence>, @@ -166,21 +237,40 @@ struct WarpGemmAttributeMfmaIterateK tuple>, sequence<2>, sequence<1>>{}; + } else - return tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2, 2>, - sequence<0, 2>>{}; + { + if constexpr(UsePackNumAccess) + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<1, 2>>{}; + } + else + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>{}; + } + } } else 
if constexpr(kMNBlock == 1 && 1 < kNMBlock) { - static_assert(AttrNumAccessV == 1, + static_assert(AttrNumAccessV_ == 1, "Multiple access is not supported when using multi-block"); // each M/N blocks share the same data return tile_distribution_encoding< @@ -193,7 +283,7 @@ struct WarpGemmAttributeMfmaIterateK } else if constexpr(1 < kMNBlock && kNMBlock == 1) { - static_assert(AttrNumAccessV == 1, + static_assert(AttrNumAccessV_ == 1, "Multiple access is not supported when using multi-block"); // single block to multi-block thread mapping return tile_distribution_encoding< @@ -245,10 +335,14 @@ struct WarpGemmAttributeMfmaIterateK } } - using AWarpDstrEncoding = - decltype(get_warp_dstr_encoding()); - using BWarpDstrEncoding = - decltype(get_warp_dstr_encoding()); + using AWarpDstrEncoding = decltype(get_warp_dstr_encoding()); + using BWarpDstrEncoding = decltype(get_warp_dstr_encoding()); using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding()); // c_vec += a_vec * b_vec diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index f9a988a923..21360874fb 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -24,9 +24,10 @@ template + bool SwizzleA = false, + bool UseStructuredSparsity = false, + WGAttrNumAccessEnum AttrNumAccessA = ESingle, + WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA> struct Dispatcher; // clang-format off @@ -78,6 +79,10 @@ template<> struct Dispatcher { using template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<>; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K64; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K64; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K64<>; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<>; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; }; @@ -166,9 +171,10 @@ template + bool SwizzleA = false, + bool UseStructuredSparsity = false, + WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single, + WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA> using WarpGemmDispatcher = typename impl::warp_gemm_dispatcher::Dispatcher< // AType, BType, @@ -179,6 +185,7 @@ using WarpGemmDispatcher = typename impl::warp_gemm_dispatcher::Dispatcher< // TransposeC, SwizzleA, UseStructuredSparsity, - AttrNumAccess>::Type; + AttrNumAccessA, + AttrNumAccessB>::Type; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant.hpp b/include/ck_tile/ops/gemm_quant.hpp index 91a9521c4f..c2fe66ea5d 100644 --- a/include/ck_tile/ops/gemm_quant.hpp +++ b/include/ck_tile/ops/gemm_quant.hpp @@ -24,9 +24,9 @@ #include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp" -#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp" -#include 
"ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp" -#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_microscale_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_microscale_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_microscale_pipeline_ag_bg_cr_v3.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp" diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 9d711c4862..3af7177365 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/elementwise.hpp" @@ -101,20 +102,33 @@ struct BQuantBlockUniversalGemmAsBsCr // 2. bf8, bf8, fp32 -> f32 // 3. i4, fp8, (fp8/fp32) -> f32 // 4. i4, bf8, (fp8/fp32) -> f32 - static_assert((std::is_same_v || std::is_same_v) && - (std::is_same_v || std::is_same_v || - std::is_same_v) && - (std::is_same_v || - std::is_same_v || - std::is_same_v) && - (std::is_same_v || - std::is_same_v) && - std::is_same_v); + // 5. bf16, (bf16/bf8/fp8/fp4), e8m0 -> f32 + // 6. fp16, (fp16/fp8/bf8/fp4), e8m0 -> f32 + static_assert( + is_any_of::value && + is_any_of::value && + is_any_of::value && + is_any_of::value && + std::is_same_v); static constexpr index_t InterWaveSchedulingMacClusters = 1; static constexpr index_t KPack = WarpGemm::kKPerThread; static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread; + + template + using has_bcastpolicy_type = decltype(T::BCastPolicy); + + static constexpr bool IsBCastPolicyBeforeLDSWrite = [] { + if constexpr(is_detected{}) + { + return Problem::BCastPolicy == CastPolicy::BeforeLDSWrite; + } + else + { + return false; + } + }(); }; public: @@ -127,9 +141,12 @@ struct BQuantBlockUniversalGemmAsBsCr using CDataType = remove_cvref_t; // BDataType gets converted from PkInt4 during loading + // OverrideBDataType is only used when BCastPolicy is CastBeforeLDSWrite for microscale. 
+ // In that case we use ADataType using OverrideBDataType = std::conditional_t< - std::is_same_v && - std::is_same_v, + (std::is_same_v && + std::is_same_v) || + Traits::IsBCastPolicyBeforeLDSWrite, ADataType, BDataType>; @@ -176,57 +193,17 @@ struct BQuantBlockUniversalGemmAsBsCr using I0 = number<0>; using I1 = number<1>; + // Use gemm universal block distribution encoding instead of duplicating it + using BlockGemmBase = BlockUniversalGemmAsBsCr; + CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() { - constexpr index_t KPerThread = Traits::KPerThread; - constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters; - - constexpr index_t KPerInnerLoop = - ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread); - - constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread; - - using KIterSeq = std::conditional_t, - sequence>; - - constexpr auto a_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, KIterSeq>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); - - return a_block_dstr_encode; + return BlockGemmBase::MakeABlockDistributionEncode(); } CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() { - constexpr index_t KPerThread = Traits::KPerThread; - constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters; - constexpr index_t KPerInnerLoop = - ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread); - constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread; - - using KIterSeq = std::conditional_t, - sequence>; - - constexpr auto b_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, KIterSeq>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - - return b_block_dstr_encode; + return BlockGemmBase::MakeBBlockDistributionEncode(); } private: @@ -235,20 +212,24 @@ struct BQuantBlockUniversalGemmAsBsCr { }; + using BlockGemmImplBase = typename BlockUniversalGemmAsBsCr:: + template BlockGemmImpl; + template - struct BlockGemmImpl + struct BlockGemmImpl : public BlockGemmImplBase { - static constexpr auto ALdsTileDistr = - decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){}; - static constexpr auto BLdsTileDistr = - decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; + using BlockGemmImplBase::a_warp_tile_; + using BlockGemmImplBase::b_warp_tile_; + using BlockGemmImplBase::BLdsTileDistr; + // If we apply scale while reading from LDS, then we can use the operator() from + // BlockUniversalGemmAsBsCr + using BlockGemmImplBase::operator(); - using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); - using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); - - ALdsTile a_warp_tile_; - BLdsTile b_warp_tile_; + // static distributed tensor with LDS type + using BTypeTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + BTypeTile b_warp_tile_lds_; + // Load from LDS (assumption is that the scale will be applied in the block gemm) template + CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window, + const BQRegBlockTile& bq_block_tensor, + bool_constant = {}, + bool_constant = {}) 
+ { + // Load tile from LDS + + // Do not use load_int4_tile here because it will have support to cast from fp4 to + // compute type, while here we want to only load from LDS and then apply the scale + // and cast later + if constexpr(ALoadTranspose) + { + a_warp_tile_ = load_tile_transpose(a_block_window); + } + else + { + load_tile(a_warp_tile_, a_block_window); + } + + if constexpr(BLoadTranspose) + { + b_warp_tile_lds_ = load_tile_transpose(b_block_window); + } + else + { + load_tile(b_warp_tile_lds_, b_block_window); + } + + // Apply scale and cast + using BDataTypeRaw = + std::conditional_t, pk_fp4_t::type, BDataType>; + + constexpr index_t warp_size = get_warp_size(); + constexpr index_t nelements = WarpGemm::kK * WarpGemm::kN / warp_size; + constexpr index_t thread_buffer_size = nelements / UnaryOpSize_; + const element_wise::DequantPack8 elementwise_op{}; + using SrcVectorRawType = ext_vector_t; + using DstVectorType = ext_vector_t; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { + // B scale register offset + constexpr index_t reg_offset = [&]() { + if constexpr(GemmTraits::BQuantGroupSize::kN >= (NWarp * WarpGemm::kN)) + return ((nIter * NWarp * WarpGemm::kN) / + GemmTraits::BQuantGroupSize::kN) * + Traits::KQPerBlock + + kQScale; + else + { + return nIter * Traits::KQPerBlock + kQScale; + } + }(); + + // Get B scale from thread buffer + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; + float b_scale_f = float(scale_reg); + + static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { + constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale; + // Thread buffers + using BWarpThreadBuffer = decltype(b_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths))); + using BLDSThreadBuffer = decltype(b_warp_tile_lds_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths))); + + BWarpThreadBuffer b_warp_thread_buffer; + BLDSThreadBuffer b_lds_thread_buffer; + + // Load thread buffer from tile (LDS type) + b_lds_thread_buffer = b_warp_tile_lds_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + // Apply scale to B thread buffer and cast + static_for<0, thread_buffer_size, 1>{}([&](auto i) { + elementwise_op( + b_warp_thread_buffer.template get_as()(i), + b_lds_thread_buffer.template get_as()[i], + b_scale_f); + }); + + // Store B thread buffer to tile (MMA type) + b_warp_tile_.set_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths), + b_warp_thread_buffer); + }); + }); + }); + } + // C += A * B template + CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window, + BQRegBlockTile bq_block_tile, + bool_constant a_load_tr = {}, + bool_constant b_load_tr = {}) + { + block_gemm_impl_.LocalPrefetch( + a_block_window, b_block_window, bq_block_tile, a_load_tr, b_load_tr); + } + // C += A * B + // Apply scale after MMA template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window) + { + block_gemm_impl_(c_block_tensor, a_block_window, b_block_window); + } + private: BlockGemmImpl block_gemm_impl_{}; }; 
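Editor's note: the LocalPrefetch path added above multiplies the B tile by its per-group e8m0 scale in registers after the LDS read, then casts to the MMA type. Below is a minimal host-side sketch of that dequantization step. It assumes only that e8m0 encodes a power-of-two scale with bias 127 (consistent with the "- 127" adjustment in the removed reference code and the std::exp2 fill in the example); e8m0_decode and dequant_group are illustrative names, not ck_tile APIs.

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Stand-in for ck_tile::e8m0_t: an 8-bit biased exponent (bias 127), so a raw
// value e decodes to the scale 2^(e - 127); raw 127 decodes to 1.0f.
inline float e8m0_decode(std::uint8_t raw)
{
    return std::exp2(static_cast<int>(raw) - 127);
}

// Dequantize one quant group of B: all elements in the group share one scale,
// mirroring the per-group b_scale_f multiply in LocalPrefetch and the
// v_b = load_b(k) * b_scale line in the host reference above.
std::vector<float> dequant_group(const std::vector<float>& b_group, std::uint8_t scale_raw)
{
    const float scale = e8m0_decode(scale_raw);
    std::vector<float> out(b_group.size());
    for(std::size_t i = 0; i < b_group.size(); ++i)
    {
        out[i] = b_group[i] * scale; // the scaled value is then cast to bf16/fp16
    }
    return out;
}

On gfx950 this multiply is fused into the __builtin_amdgcn_cvt_scalef32_* conversions introduced earlier in this patch; on other targets the scalar fallback loops perform the same per-element multiply-and-convert.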
diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 05e8aa62a9..62ac2115cc 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -787,20 +787,12 @@ struct QuantGemmKernel } else { - if constexpr(std::is_same_v) - return make_naive_tensor_view( - b_ptr, - make_tuple(kargs.N, k_size / 2), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); - else - return make_naive_tensor_view( - b_ptr, - make_tuple(kargs.N, k_size), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); + return make_naive_tensor_view( + b_ptr, + make_tuple(kargs.N, k_size), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); } } } @@ -814,16 +806,10 @@ struct QuantGemmKernel } else if constexpr(std::is_same_v) { - if constexpr(std::is_same_v) - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - else - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); } else { @@ -848,17 +834,10 @@ struct QuantGemmKernel { if constexpr(std::is_same_v) { - if constexpr(std::is_same_v) - return make_tile_window( - b_pad_view, - make_tuple(number{}, - number{}), - {i_n, 0}); - else - return make_tile_window(b_pad_view, - make_tuple(number{}, - number{}), - {i_n, 0}); + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {i_n, 0}); } else { diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_microscale_pipeline_ag_bg_cr_base.hpp similarity index 80% rename from include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp rename to include/ck_tile/ops/gemm_quant/pipeline/gemm_microscale_pipeline_ag_bg_cr_base.hpp index facec252a3..06ca9854b9 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_microscale_pipeline_ag_bg_cr_base.hpp @@ -10,7 +10,7 @@ namespace ck_tile { template -struct GemmMxFp4PipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase +struct GemmMicroscalePipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase { using Base = GemmPipelineAgBgCrImplBase; using ADataType = typename Base::ADataType; @@ -42,10 +42,14 @@ struct GemmMxFp4PipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase); - - using YPerTile = number; - using XPerTile = number; + using YPerTile = + std::conditional_t, + number, + number>; + using XPerTile = + std::conditional_t, + number, + number>; auto bq_copy_dram_window = make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(), diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_microscale_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_microscale_pipeline_ag_bg_cr_policy.hpp new file mode 100644 index 0000000000..a026694769 --- /dev/null +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_microscale_pipeline_ag_bg_cr_policy.hpp @@ -0,0 +1,296 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
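
Before the policy file body: the ScaleTile fast path further below relies on two fp4 (e2m1) values being packed per byte and, when both halves map to the same scale entry, converting them as a pair. A standalone sketch of the per-pair dequant, assuming the OCP e2m1 value table and low-nibble-first packing (both assumptions, not taken from this diff):

    #include <cstdint>

    // Decode one e2m1 nibble to float; codes 0..7 give the eight
    // magnitudes of fp4, the top bit is the sign.
    inline float fp4_e2m1_to_float(std::uint8_t nib)
    {
        static const float mag[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
        const float v = mag[nib & 0x7];
        return (nib & 0x8) ? -v : v;
    }

    // Dequantize a packed fp4 pair with a shared scale (the case where
    // ScaleTile can use the packed convert-and-scale instructions).
    inline void dequant_pk_fp4(std::uint8_t packed, float scale, float& lo, float& hi)
    {
        lo = fp4_e2m1_to_float(packed & 0xF) * scale;           // assumed low nibble first
        hi = fp4_e2m1_to_float(static_cast<std::uint8_t>(packed >> 4)) * scale;
    }

When the two halves straddle a quant-group boundary, ScaleTile instead falls back to scaling and casting element by element, as the comment in the pipeline notes.
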
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" +#include "gemm_group_quant_utils.hpp" + +namespace ck_tile { + +struct GemmMicroscalePipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy +{ + using Base = UniversalGemmPipelineAgBgCrPolicy; + using Base::I0; + using Base::I1; + using Base::I2; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeBQ() + { + using BQLayout = remove_cvref_t; + using BQDataType = remove_cvref_t; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK; + + // Support both RowMajor and ColumnMajor layouts for BQ + if constexpr(std::is_same_v) + { + return GetABQGlobalVectorLoadSize(); + } + else + { + return GetABQGlobalVectorLoadSize(); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBRegTileDistribution() + { + using BLayout = remove_cvref_t; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = + Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB(); + constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + // Tile: KPerBlock X NPerBlock + if constexpr(std::is_same_v) + { + using TileEncodingPattern = + tile_distribution_encoding_pattern_2d; + return TileEncodingPattern::make_2d_static_tile_distribution(); + } + // Tile: NPerBlock X KPerBlock + else + { + using TileEncodingPattern = + tile_distribution_encoding_pattern_2d; + return TileEncodingPattern::make_2d_static_tile_distribution(); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBQDramTileDistribution() + { + using BQLayout = remove_cvref_t; + using BLayout = remove_cvref_t; + // If we apply scale before writing to LDS, we need a tile distribution for + // BQuant consistent with global memory reading of matrix B, while + // if we apply scale after reading from LDS, we need a tile distribution for + // BQuant consistent with the MMA instructions layout + if constexpr(Problem::BCastPolicy == CastPolicy::AfterLDSRead) + { + using BlockGemmShape = typename Problem::BlockGemmShape; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK; + + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using WarpGemm = WarpGemmDispatcher; + + using TileEncodingPattern = + tile_distribution_encoding_pattern_bq; + + return TileEncodingPattern::make_2d_static_tile_distribution(); + } + else + { + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t VecLoadSize = + Problem::FixedVectorSize ? 
Problem::VectorSizeB : GetVectorSizeB(); + constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + + constexpr index_t warp_size = get_warp_size(); + constexpr index_t num_warps = BlockSize / get_warp_size(); + constexpr index_t LargestVec = (KPerBlock * NPerBlock) / (num_warps * warp_size); + constexpr index_t b_vec = VecLoadSize > LargestVec ? LargestVec : VecLoadSize; + + constexpr index_t KScale = KPerBlock / Problem::BQuantGroupSize::kK; + + // For each BQ layout we need different encodings whether B has the same layout or not + // TODO: generalize encodings for different BQuantGroupSize granularity + if constexpr(std::is_same_v) + { + if constexpr(std::is_same_v) + { + constexpr index_t K0 = KPerBlock / b_vec; + constexpr index_t K1 = K0 / KScale; + constexpr index_t K3 = KScale; + constexpr index_t K2 = 1; + + constexpr index_t N0 = num_warps / NumWaveGroups; + constexpr index_t N1 = warp_size / K0; + constexpr index_t N2 = NPerBlock / (N0 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2, 0>>, + tuple, sequence<1, 0, 0>>, + sequence<1, 2>, + sequence<2, 1>>{}); + } + else + { + constexpr index_t N1 = NPerBlock / b_vec; + constexpr index_t N2 = b_vec; + + constexpr index_t KRepeatInWave = warp_size / N1; + constexpr index_t KRepeatAcrossWave = num_warps / KScale; + + constexpr index_t K2 = num_warps / KRepeatAcrossWave; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 1, 2>>, + tuple, sequence<1, 1, 1>>, + sequence<1, 2>, + sequence<2, 2>>{}); + } + } + else + { + if constexpr(std::is_same_v) + { + constexpr index_t NScale = NPerBlock / Problem::BQuantGroupSize::kN; + constexpr index_t N0 = NScale / b_vec; + constexpr index_t N1 = b_vec; + + constexpr index_t KLanes = warp_size / N0; + constexpr index_t KVec = KPerBlock / KLanes / num_warps; + constexpr index_t KRepeat = KPerBlock / KScale / KVec; + + constexpr index_t KRepeatInWave = KRepeat > KLanes ? KLanes : 1; + constexpr index_t KRepeatAcrossWave = KRepeat > KLanes ? KRepeat / KLanes : 1; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 0, 2>>, + tuple, sequence<1, 1, 0>>, + sequence<1, 2>, + sequence<2, 1>>{}); + } + else + { + constexpr index_t KRepeatInWave = Problem::BQuantGroupSize::kK / b_vec; + constexpr index_t K1 = KScale; + + constexpr index_t N0 = num_warps / NumWaveGroups; + constexpr index_t N1 = warp_size / (KRepeatInWave * K1); + + // Number of contiguous elements in N dimension when reading B matrix + // becomes the vector size of BQ + constexpr index_t N2 = NPerBlock / (BlockSize / (KPerBlock / b_vec)); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 0, 1, 0>>, + tuple, sequence<1, 1, 1, 2>>, + sequence<1, 2>, + sequence<2, 2>>{}); + } + } + } + } + + // Return AttrNumAccess for a given warp tile (defined by ThreadElements) and data type + template + static constexpr auto GetAttrNumAccess(bool_constant, number) + { + constexpr index_t PackedSize = numeric_traits>::PackedSize; + constexpr index_t vector_size = DS_READ_TR_SIZE() / sizeof(DataType) * PackedSize; + + return !UseLoadTranspose ? WGAttrNumAccessEnum::Single + : vector_size == ThreadElements ? WGAttrNumAccessEnum::Single + : vector_size * 2 == ThreadElements ? WGAttrNumAccessEnum::Double + : vector_size * 4 == ThreadElements ? 
WGAttrNumAccessEnum::Quad + : WGAttrNumAccessEnum::Invalid; + }; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using ComputeDataType = typename Problem::ComputeDataType; + using LDSADataType = typename Problem::ADataType; + using LDSBDataType = std::conditional_t; + + static_assert(Problem::BQuantGroupSize::kK % WarpTile::at(I2) == 0, + "KPerWarpGemm must be a multiple of QuantGroupSize!"); + + constexpr auto thread_elements = + number{}; + + constexpr auto is_a_load_tr_v = bool_constant>{}; + constexpr auto is_b_load_tr_v = bool_constant>{}; + constexpr auto is_any_load_tr = is_a_load_tr_v || is_b_load_tr_v; + + constexpr auto wg_attr_num_access_compute = + GetAttrNumAccess(is_any_load_tr, thread_elements); + constexpr auto wg_attr_num_accessA = + std::is_same_v + ? wg_attr_num_access_compute + : GetAttrNumAccess(is_a_load_tr_v, thread_elements); + constexpr auto wg_attr_num_accessB = + std::is_same_v + ? wg_attr_num_access_compute + : GetAttrNumAccess(is_b_load_tr_v, thread_elements); + + using WarpGemm = WarpGemmDispatcher; + static_assert(is_any_of::value); + static_assert(std::is_same_v); + + using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy< + typename Problem::ADataType, + std::conditional_t, + typename Problem::ADataType, + typename Problem::BDataType>, + typename Problem::CDataType, + BlockWarps, + WarpGemm>; + + return BQuantBlockUniversalGemmAsBsCr{}; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_microscale_pipeline_ag_bg_cr_v3.hpp similarity index 60% rename from include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp rename to include/ck_tile/ops/gemm_quant/pipeline/gemm_microscale_pipeline_ag_bg_cr_v3.hpp index 7c448599ed..5a03057c64 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_microscale_pipeline_ag_bg_cr_v3.hpp @@ -9,7 +9,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" -#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_microscale_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/host/concat.hpp" namespace ck_tile { @@ -18,15 +18,21 @@ namespace ck_tile { // B Tile Window: global memory // C Distributed tensor: register -template -struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 +template +struct MicroscaleGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { using Base = BaseGemmPipelineAgBgCrCompV3; - using PipelineImplBase = GemmMxFp4PipelineAgBgCrImplBase; + using PipelineImplBase = GemmMicroscalePipelineAgBgCrImplBase; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + + using BDqDataType = remove_cvref_t; + + static constexpr bool IsCastBeforeLDS = Problem::BCastPolicy == CastPolicy::BeforeLDSWrite; + + using BLDSType = std::conditional_t; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using BDqDataType = remove_cvref_t; using BQDataType = remove_cvref_t; using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; @@ -40,12 +46,16 @@ struct 
MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3>::PackedSize; + static constexpr index_t BPackedSize = - ck_tile::numeric_traits>::PackedSize; + ck_tile::numeric_traits>::PackedSize; static constexpr index_t BQPackedSize = ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BLDSPackedSize = + ck_tile::numeric_traits>::PackedSize; + using ALayout = remove_cvref_t; using BQLayout = remove_cvref_t; using BLayout = remove_cvref_t; @@ -82,6 +92,9 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}; + static constexpr auto is_b_load_tr_v = bool_constant{}; + using Base::PrefetchStages; [[nodiscard]] CK_TILE_HOST static const std::string GetName() @@ -165,6 +178,11 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3; + static constexpr bool is_b_row_major = + std::is_same_v; + CK_TILE_DEVICE static constexpr auto HotLoopScheduler() { constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; @@ -207,7 +225,7 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 + CK_TILE_DEVICE static void ScaleTile(const TileType& block_tile, + CastTileType& block_tile_cast, + const ScaleTileType& scale_tile) + { + if constexpr(IsCastBeforeLDS) + { + constexpr auto b_block = TileType::get_distributed_spans(); + + // Internally this is using V_CVT_SCALEF32_PK_BF16_FP4 or V_CVT_SCALEF32_PK_FP16_FP4 + // on gfx950 + auto pk_mxfp4_to_compute_v2 = [](auto pk_mxfp4, float fscale) { + if constexpr(std::is_same_v) + { + return pk_fp4_to_fp16x2(pk_mxfp4, fscale); + } + else if constexpr(std::is_same_v) + { + return pk_fp4_to_bf16x2(pk_mxfp4, fscale); + } + else + { + static_assert(false, "unsupported compute type"); + } + }; + + constexpr index_t BQuantGroupSizeIdx0 = + std::is_same_v + ? BQuantGroupSize::kN + : BQuantGroupSize::kK; + constexpr index_t BQuantGroupSizeIdx1 = + std::is_same_v + ? BQuantGroupSize::kK + : BQuantGroupSize::kN; + + // The input indices are with respect to B block tile. 
If B and Bq have different + // layouts, the indices must be swapped + auto make_bq_index = [](auto idx0, auto idx1) { + if constexpr(std::is_same_v) + { + return make_tuple( + tile_distributed_index{}, + tile_distributed_index{}); + } + else + { + return make_tuple( + tile_distributed_index{}, + tile_distributed_index{}); + } + }; + + sweep_tile_span(b_block[number<0>{}], [&](auto idx0) { + sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { + if constexpr(std::is_same_v) + { + if constexpr(idx1.impl_.at(0) % BPackedSize == 0) + { + constexpr auto idx1_lo = tile_distributed_index{}; + constexpr auto idx1_hi = + tile_distributed_index{}; + + constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); + constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); + + constexpr auto i_j_idx = make_tuple(idx0, idx1); + auto b_pack = block_tile[i_j_idx]; + + constexpr auto i_j_idx_scale_lo = make_bq_index(idx0, idx1_lo); + constexpr auto i_j_idx_scale_hi = make_bq_index(idx0, idx1_hi); + + // If the scale is the same for packed values, use pk cvt scale + // instructions, otherwise scale and cast element by element + if constexpr(i_j_idx_scale_lo[I0{}].impl_.at(0) == + i_j_idx_scale_hi[I0{}].impl_.at(0) && + i_j_idx_scale_lo[I1{}].impl_.at(0) == + i_j_idx_scale_hi[I1{}].impl_.at(0)) + { + float scale = float(scale_tile[i_j_idx_scale_lo]); + auto cvt = pk_mxfp4_to_compute_v2(b_pack, scale); + + block_tile_cast(i_j_idx_lo) = cvt.x; + block_tile_cast(i_j_idx_hi) = cvt.y; + } + else + { + float scale_lo = float(scale_tile[i_j_idx_scale_lo]); + auto b_f4_lo = + type_convert(b_pack.unpack(number<0>{})); + block_tile_cast(i_j_idx_lo) = type_convert( + type_convert(b_f4_lo) * scale_lo); + + float scale_hi = float(scale_tile[i_j_idx_scale_hi]); + auto b_f4_hi = + type_convert(b_pack.unpack(number<1>{})); + block_tile_cast(i_j_idx_hi) = type_convert( + type_convert(b_f4_hi) * scale_hi); + } + } + } + else + { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + constexpr auto i_j_idx_scale = make_bq_index(idx0, idx1); + float scale = float(scale_tile[i_j_idx_scale]); + + auto b_pack = block_tile[i_j_idx]; + block_tile_cast(i_j_idx) = + type_convert(type_convert(b_pack) * scale); + } + }); + }); + } + } + + template + CK_TILE_DEVICE void ALocalPrefill(WindowType& lds_window, + const TileType& block_tile, + const ElementwiseFunc& element_func) const + { + if constexpr(is_a_col_major && !is_a_load_tr_v()) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, block_tile); + Base::LocalPrefill(lds_window, a_shuffle_tmp, element_func); + } + else + { + Base::LocalPrefill(lds_window, block_tile, element_func); + } + } + + template + CK_TILE_DEVICE void BLocalPrefill(WindowType& lds_window, + const TileType& block_tile, + const TileTypeCast& block_tile_cast, + const ElementwiseFunc& element_func) const + { + // Fill LDS and apply the scale if IsCastBeforeLDS + auto get_b_block_tile = [](auto& b_block_tile_orig, auto& b_block_tile_cast) { + if constexpr(IsCastBeforeLDS) + { + return b_block_tile_cast; + } + else + { + return b_block_tile_orig; + } + }; + + if constexpr(is_b_row_major && !is_b_load_tr_v()) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, get_b_block_tile(block_tile, block_tile_cast)); + Base::LocalPrefill(lds_window, b_shuffle_tmp, element_func); + } + else + { + Base::LocalPrefill( + lds_window, 
get_b_block_tile(block_tile, block_tile_cast), element_func); + } + } + + template + CK_TILE_DEVICE void LocalPrefetch(BlockGemmType& block_gemm, + const AWindowType& a_lds_window, + const BWindowType& b_lds_window, + const QTileType& q_block_tile) const + { + // Load from LDS + // It can apply the scale and cast if we scale after reading from LDS + if constexpr(IsCastBeforeLDS) + { + block_gemm.LocalPrefetch( + a_lds_window, b_lds_window, is_a_load_tr_v, is_b_load_tr_v); + } + else + { + block_gemm.LocalPrefetch( + a_lds_window, b_lds_window, q_block_tile, is_a_load_tr_v, is_b_load_tr_v); + } + } + template > && std::is_same_v; constexpr bool is_bq_col_major = std::is_same_v; - constexpr bool is_b_row_major = std::is_same_v; - static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)"); - static_assert(NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] && - KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}], + static_assert(is_bq_col_major + ? (NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}]), "Bq block window has incorrect lengths for defined BqLayout!"); static_assert(is_a_col_major @@ -347,13 +557,12 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3()); auto bq_copy_dram_window = Base::GetBQDramLoadWindow(bq_dram_block_window_tmp); auto bq_block_tile = decltype(load_tile(bq_copy_dram_window)){}; + // This defines the scaled and casted block tile for B matrix. + // Effectively, it is used only if we scale and cast before writing to LDS. + auto bdq_block_tile = make_static_distributed_tensor( + Policy::template MakeBRegTileDistribution()); + // Block GEMM auto block_gemm = BlockGemm(); auto c_block_tile = block_gemm.MakeCBlockTile(); using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); - // using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); using ABlockTile = @@ -402,114 +610,61 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(BBlockTileDistr{})); ABlockTile a_block_tile; - BBlockTile b_fp4_block_tile; + BBlockTile b_block_tile; - using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; - using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + using BQDramTileWindowStep = typename BQDramBlockWindowTmp::BottomTensorIndex; constexpr ADramTileWindowStep a_dram_tile_window_step = is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); constexpr BDramTileWindowStep b_dram_tile_window_step = - is_b_row_major ? make_array(KPerBlock / 2, 0) : make_array(0, KPerBlock / 2); + is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); - constexpr index_t b_scale_dram_tile_window_step = KPerBlock / BQuantGroupSize::kK; + constexpr BQDramTileWindowStep b_scale_dram_tile_window_step = + std::is_same_v + ? 
make_array(0, KPerBlock / BQuantGroupSize::kK) + : make_array(KPerBlock / BQuantGroupSize::kK, 0); // ----------------------------------------------------------------------------------------- // Gemm pipeline start - // prefetch - // global read 0 - // auto a_scale_block_tile = decltype(load_tile(a_scale_copy_dram_window)){}; + // prefetch stages + + // Vmem -> Vgpr 0 Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); - // BDataType - auto b_block_tile = make_static_distributed_tensor( - Policy::template MakeBRegTileDistribution()); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + // Vmem -> Vgpr 0 (Q matrix) + // Scale and cast tile before writing to LDS (if IsCastBeforeLDS) bq_block_tile = load_tile(bq_copy_dram_window); - move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); + move_tile_window(bq_copy_dram_window, b_scale_dram_tile_window_step); + ScaleTile(b_block_tile, bdq_block_tile, bq_block_tile); - constexpr auto idx1_js = tile_distributed_index<0>{}; - constexpr auto b_block = decltype(b_fp4_block_tile)::get_distributed_spans(); - sweep_tile_span(b_block[number<0>{}], [&](auto idx0) { - sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js); - auto b_scale_uint = type_convert(bq_block_tile(i_j_idx_scale)) - 127; - auto b_scale = type_convert(std::pow(2.0f, b_scale_uint)); - constexpr auto idx1_lo = tile_distributed_index{}; - constexpr auto idx1_hi = tile_distributed_index{}; - constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); - constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); - - auto b_pack = type_convert(b_fp4_block_tile(i_j_idx)); - auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); - auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); - b_block_tile(i_j_idx_lo) = - type_convert(type_convert(b_f4_lo) * b_scale); - b_block_tile(i_j_idx_hi) = - type_convert(type_convert(b_f4_hi) * b_scale); - }); - }); - - // initialize C + // initialize C tile to zero tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); block_sync_lds(); - // LDS write 0 - if constexpr(is_a_col_major) - { - auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_block_tile); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); - } - else - { - Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); - } - - if constexpr(is_b_row_major) - { - auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_block_tile); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); - } - else - { - Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); - } + // Vgpr -> LDS 0 + ALocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + BLocalPrefill(b_copy_lds_window, b_block_tile, bdq_block_tile, b_element_func); + // Vmem -> Vgpr 1 Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); - bq_block_tile = load_tile(bq_copy_dram_window); - move_tile_window(bq_copy_dram_window, {0, 
b_scale_dram_tile_window_step}); - - sweep_tile_span(b_block[number<0>{}], [&](auto idx0) { - sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js); - - auto b_scale_uint = type_convert(bq_block_tile(i_j_idx_scale)) - 127; - auto b_scale = type_convert(std::pow(2.0f, b_scale_uint)); - constexpr auto idx1_lo = tile_distributed_index{}; - constexpr auto idx1_hi = tile_distributed_index{}; - constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); - constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); - - auto b_pack = type_convert(b_fp4_block_tile(i_j_idx)); - auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); - auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); - b_block_tile(i_j_idx_lo) = - type_convert(type_convert(b_f4_lo) * b_scale); - b_block_tile(i_j_idx_hi) = - type_convert(type_convert(b_f4_hi) * b_scale); - }); - }); + // If we scale and cast before writing to LDS, + // we need to read another tile of Q matrix from Vmem, then scale and cast tile + if constexpr(IsCastBeforeLDS) + { + bq_block_tile = load_tile(bq_copy_dram_window); + move_tile_window(bq_copy_dram_window, b_scale_dram_tile_window_step); + } + ScaleTile(b_block_tile, bdq_block_tile, bq_block_tile); block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + + // LDS -> Vgpr 0 + LocalPrefetch(block_gemm, a_lds_gemm_window, b_lds_gemm_window, bq_block_tile); __builtin_amdgcn_sched_barrier(0); @@ -521,72 +676,34 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( - Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_block_tile); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); - } - else - { - Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); - } - if constexpr(is_b_row_major) - { - auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_block_tile); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); - } - else - { - Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); - } + // Vgpr -> LDS + ALocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + BLocalPrefill(b_copy_lds_window, b_block_tile, bdq_block_tile, b_element_func); + // Vmem -> Vgpr Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch( - b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + // Vmem -> Vgpr (Q matrix) + // Scale and cast tile before writing to LDS (if IsCastBeforeLDS) bq_block_tile = load_tile(bq_copy_dram_window); - move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); - - sweep_tile_span(b_block[number<0>{}], [&](auto idx0) { - sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js); - - auto b_scale_uint = - type_convert(bq_block_tile(i_j_idx_scale)) - 127; - auto b_scale = type_convert(std::pow(2.0f, b_scale_uint)); - constexpr auto idx1_lo = tile_distributed_index{}; - constexpr auto idx1_hi = - tile_distributed_index{}; - constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); - constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); - - auto b_pack = 
type_convert(b_fp4_block_tile(i_j_idx)); - auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); - auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); - b_block_tile(i_j_idx_lo) = - type_convert(type_convert(b_f4_lo) * b_scale); - b_block_tile(i_j_idx_hi) = - type_convert(type_convert(b_f4_hi) * b_scale); - }); - }); + move_tile_window(bq_copy_dram_window, b_scale_dram_tile_window_step); + ScaleTile(b_block_tile, bdq_block_tile, bq_block_tile); + // Consume tile block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + // LDS -> Vgpr + LocalPrefetch(block_gemm, a_lds_gemm_window, b_lds_gemm_window, bq_block_tile); + HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); i += 1; - // b_block_stride +=1; } while(i < (num_loop - 1)); } - // tile_elementwise_inout([](auto& c) { c = 0; }, acc_block_tile); + // tail if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd)) { @@ -596,35 +713,31 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( - Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_block_tile); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); - } - else - { - Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); - } - if constexpr(is_b_row_major) - { - auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_block_tile); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); - } - else - { - Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); - } + // Vgpr -> LDS last tile + ALocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + BLocalPrefill(b_copy_lds_window, b_block_tile, bdq_block_tile, b_element_func); block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + // LDS -> Vgpr last tile + LocalPrefetch(block_gemm, a_lds_gemm_window, b_lds_gemm_window, bq_block_tile); + + // Consume last tile block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); } @@ -653,9 +766,9 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}.template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + identity{}, b_dram_block_window_tmp, - [](const BDqDataType& b) { return b; }, + identity{}, bq_dram_block_window_tmp, num_loop, p_smem); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp deleted file mode 100644 index 6cf9e22f41..0000000000 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" -#include "gemm_group_quant_utils.hpp" - -namespace ck_tile { - -struct GemmMxFp4PipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy -{ - using Base = UniversalGemmPipelineAgBgCrPolicy; - using Base::I0; - using Base::I1; - using Base::I2; - - template - CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeBQ() - { - using BQLayout = remove_cvref_t; - using BQDataType = remove_cvref_t; - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK; - - static_assert(std::is_same_v); - return GetABQGlobalVectorLoadSize(); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBRegTileDistribution() - { - using BLayout = remove_cvref_t; - - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t VecLoadSize = - Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB(); - constexpr index_t NumWaveGroups = Problem::NumWaveGroups; - // Tile: KPerBlock X NPerBlock - if constexpr(std::is_same_v) - { - using TileEncodingPattern = - tile_distribution_encoding_pattern_2d; - return TileEncodingPattern::make_2d_static_tile_distribution(); - } - // Tile: NPerBlock X KPerBlock - else - { - using TileEncodingPattern = - tile_distribution_encoding_pattern_2d; - return TileEncodingPattern::make_2d_static_tile_distribution(); - } - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBQDramTileDistribution() - { - // using BLayout = remove_cvref_t; - - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - - constexpr index_t KScale = KPerBlock / Problem::BQuantGroupSize::kK; // k_scale num //2 - constexpr index_t VecLoadSize = - Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB(); - constexpr index_t NumWaveGroups = Problem::NumWaveGroups; - - constexpr index_t warp_size = get_warp_size(); - constexpr index_t num_warps = BlockSize / get_warp_size(); - constexpr index_t LargestVec = (KPerBlock * NPerBlock) / (num_warps * warp_size); - constexpr index_t b_vec = VecLoadSize > LargestVec ? 
LargestVec : VecLoadSize; - constexpr index_t K0 = KPerBlock / b_vec; - constexpr index_t K1 = K0 / KScale; - constexpr index_t K3 = K0 / K1; - constexpr index_t K2 = 1; - - constexpr index_t N0 = num_warps / NumWaveGroups; - constexpr index_t N1 = warp_size / K0; - constexpr index_t N2 = NPerBlock / (N0 * N1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2, 0>>, - tuple, sequence<1, 0, 0>>, - sequence<1, 2>, - sequence<2, 1>>{}); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() - { - using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; - using WarpTile = typename Problem::BlockGemmShape::WarpTile; - - static_assert(Problem::BQuantGroupSize::kK % WarpTile::at(I2) == 0, - "KPerWarpGemm must be a multiple of QuantGroupSize!"); - - using WarpGemm = WarpGemmDispatcher; - static_assert(std::is_same_v || - std::is_same_v || - std::is_same_v); - static_assert(std::is_same_v); - - using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy< - typename Problem::ADataType, - std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>, - typename Problem::CDataType, - BlockWarps, - WarpGemm>; - - return BlockUniversalGemmAsBsCr{}; - } -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp index 9b02585e69..fdaebe8010 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp @@ -24,7 +24,8 @@ template + TailNumber TailNum_ = TailNumber::Full, + CastPolicy BCastPolicy_ = CastPolicy::AfterLDSRead> struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase< ADataType_, @@ -82,6 +83,20 @@ struct GemmQuantPipelineProblemBase static constexpr auto HasHotLoop = HasHotLoop_; static constexpr auto TailNum = TailNum_; + // gfx950 supports load with transpose for 4bit types, so we can transpose + // pk_fp4_t from LDS in registers. But without this instruction, + // the transpose is done in register between Vmem read and LDS write and + // the implementation does not support 4 bit types +#ifdef __gfx950__ + static constexpr auto BCastPolicy = BCastPolicy_; +#else + static constexpr auto BCastPolicy = + std::is_same_v && + std::is_same_v + ? 
CastPolicy::BeforeLDSWrite + : BCastPolicy_; +#endif + static_assert(BlockGemmShape::kM % AQuantGroupSize::kM == 0); static_assert(BlockGemmShape::kK % AQuantGroupSize::kK == 0); static_assert(BlockGemmShape::kM % BQuantGroupSize::kM == 0); @@ -155,7 +170,8 @@ template + TailNumber TailNum_ = TailNumber::Full, + CastPolicy BCastPolicy_ = CastPolicy::AfterLDSRead> using GemmBQuantPipelineProblem = GemmQuantPipelineProblemBase; + TailNum_, + BCastPolicy_>; template , + std::is_same_v, ADataType_, std::conditional_t>; // Calculate thresholds diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp index d491d89ef4..0e6e40b788 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp @@ -25,9 +25,9 @@ using GroupSize = ck_tile::QuantGroupShape>; // clang-format off using BQuant1D128Types = ::testing::Types< // 1d cases with grouping only on k axis - std::tuple, - std::tuple, - std::tuple + std::tuple, + std::tuple, + std::tuple >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_ccr_1d_128.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_ccr_1d_128.cpp new file mode 100644 index 0000000000..94572a80dc --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_ccr_1d_128.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using BF8 = ck_tile::bf8_t; +using BF16 = ck_tile::bf16_t; +using PkFP4 = ck_tile::pk_fp4_t; +using E8M0 = ck_tile::e8m0_t; +using BQuantGrouped = std::integral_constant; +using GroupSize128 = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant tests - 1D GroupSize 128 +// Tuple format: +// clang-format off +using BQuant1D128Types = ::testing::Types< + // CCR BQ: C + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant 1D 128 +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D128Types); + +// BQuant tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_ccr_1d_64.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_ccr_1d_64.cpp new file mode 100644 index 0000000000..c6d1f0c341 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_ccr_1d_64.cpp @@ -0,0 +1,45 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using FP16 = ck_tile::fp16_t; +using BF16 = ck_tile::bf16_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using PkFP4 = ck_tile::pk_fp4_t; +using E8M0 = ck_tile::e8m0_t; +using BQuantGrouped = std::integral_constant; +using GroupSize64 = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant tests - 1D GroupSize 64 +// Tuple format: +// clang-format off +using BQuant1D64Types = ::testing::Types< + // CCR BQ: C + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant 1D 64 +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D64Types); + +// BQuant tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_crr_1d_128.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_crr_1d_128.cpp new file mode 100644 index 0000000000..e8744eb35a --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_crr_1d_128.cpp @@ -0,0 +1,42 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using BF8 = ck_tile::bf8_t; +using BF16 = ck_tile::bf16_t; +using PkFP4 = ck_tile::pk_fp4_t; +using E8M0 = ck_tile::e8m0_t; +using BQuantGrouped = std::integral_constant; +using GroupSize128 = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant tests - 1D GroupSize 128 +// Tuple format: +// clang-format off +using BQuant1D128Types = ::testing::Types< + // CRR BQ: C + std::tuple, + // CRR BQ: R + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant 1D 128 +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D128Types); + +// BQuant tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_crr_1d_64.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_crr_1d_64.cpp new file mode 100644 index 0000000000..dbc1ae7f2a --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_crr_1d_64.cpp @@ -0,0 +1,42 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using BF8 = ck_tile::bf8_t; +using BF16 = ck_tile::bf16_t; +using PkFP4 = ck_tile::pk_fp4_t; +using E8M0 = ck_tile::e8m0_t; +using BQuantGrouped = std::integral_constant; +using GroupSize64 = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant tests - 1D GroupSize 64 +// Tuple format: +// clang-format off +using BQuant1D64Types = ::testing::Types< + // CRR BQ: C + std::tuple, + // CRR BQ: R + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant 1D 64 +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D64Types); + +// BQuant tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_rcr_1d_128.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_rcr_1d_128.cpp new file mode 100644 index 0000000000..7637b8a12a --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_rcr_1d_128.cpp @@ -0,0 +1,51 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using FP16 = ck_tile::fp16_t; +using BF16 = ck_tile::bf16_t; +using PkFP4 = ck_tile::pk_fp4_t; +using E8M0 = ck_tile::e8m0_t; +using BQuantGrouped = std::integral_constant; +using GroupSize128 = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant tests - 1D GroupSize 128 +// Tuple format: +// clang-format off +using BQuant1D128Types = ::testing::Types< + // RCR BQ: C + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, PkFP4, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128>, + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, BF8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128>, + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128>, + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP16, FP16, E8M0, FP16, BQuantGrouped, GemmConfigMx, GroupSize128>, + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, FP8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128>, + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP16, FP8, E8M0, FP16, BQuantGrouped, GemmConfigMx, GroupSize128>, + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP16, BF8, E8M0, FP16, BQuantGrouped, GemmConfigMx, GroupSize128>, + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP16, PkFP4, E8M0, FP16, BQuantGrouped, GemmConfigMx, GroupSize128>, + // RCR BQ: R + std::tuple< RowMajor, ColumnMajor, RowMajor, RowMajor, BF16, BF8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128>, + std::tuple< RowMajor, ColumnMajor, RowMajor, RowMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128> +>; +// clang-format on + +// Test suite for BQuant 1D 128 +TYPED_TEST_SUITE(TestCkTileGemmBQuant, 
BQuant1D128Types); + +// BQuant tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_rcr_1d_64.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_rcr_1d_64.cpp new file mode 100644 index 0000000000..aa960ca16e --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_rcr_1d_64.cpp @@ -0,0 +1,51 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using FP16 = ck_tile::fp16_t; +using BF16 = ck_tile::bf16_t; +using PkFP4 = ck_tile::pk_fp4_t; +using E8M0 = ck_tile::e8m0_t; +using BQuantGrouped = std::integral_constant; +using GroupSize64 = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant tests - 1D GroupSize 64 +// Tuple format: +// clang-format off +using BQuant1D64Types = ::testing::Types< + // RCR BQ: C + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, PkFP4, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>, + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, BF8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>, + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>, + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP16, FP16, E8M0, FP16, BQuantGrouped, GemmConfigMx, GroupSize64>, + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, FP8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>, + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP16, FP8, E8M0, FP16, BQuantGrouped, GemmConfigMx, GroupSize64>, + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP16, BF8, E8M0, FP16, BQuantGrouped, GemmConfigMx, GroupSize64>, + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP16, PkFP4, E8M0, FP16, BQuantGrouped, GemmConfigMx, GroupSize64>, + // RCR BQ: R + std::tuple< RowMajor, ColumnMajor, RowMajor, RowMajor, BF16, BF8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>, + std::tuple< RowMajor, ColumnMajor, RowMajor, RowMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64> +>; +// clang-format on + +// Test suite for BQuant 1D 64 +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D64Types); + +// BQuant tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_rrr_1d_128.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_rrr_1d_128.cpp new file mode 100644 index 0000000000..f181b432d4 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_rrr_1d_128.cpp @@ -0,0 +1,43 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using BF8 = ck_tile::bf8_t; +using BF16 = ck_tile::bf16_t; +using PkFP4 = ck_tile::pk_fp4_t; +using E8M0 = ck_tile::e8m0_t; +using BQuantGrouped = std::integral_constant; +using GroupSize128 = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant tests - 1D GroupSize 128 +// Tuple format: +// clang-format off +using BQuant1D128Types = ::testing::Types< + // RRR BQ: C + std::tuple< RowMajor, RowMajor, RowMajor, ColumnMajor, BF16, BF8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128>, + std::tuple< RowMajor, RowMajor, RowMajor, ColumnMajor, BF16, PkFP4, E8M0, BF16, BQuantGrouped, GemmConfigMxFP4, GroupSize128>, + std::tuple< RowMajor, RowMajor, RowMajor, ColumnMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128>, + // RRR BQ: R + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128> +>; +// clang-format on + +// Test suite for BQuant 1D 128 +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D128Types); + +// BQuant tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_rrr_1d_64.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_rrr_1d_64.cpp new file mode 100644 index 0000000000..a02136b7db --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_rrr_1d_64.cpp @@ -0,0 +1,43 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using BF8 = ck_tile::bf8_t; +using BF16 = ck_tile::bf16_t; +using PkFP4 = ck_tile::pk_fp4_t; +using E8M0 = ck_tile::e8m0_t; +using BQuantGrouped = std::integral_constant; +using GroupSize64 = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant tests - 1D GroupSize 64 +// Tuple format: +// clang-format off +using BQuant1D64Types = ::testing::Types< + // RRR BQ: C + std::tuple< RowMajor, RowMajor, RowMajor, ColumnMajor, BF16, BF8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>, + std::tuple< RowMajor, RowMajor, RowMajor, ColumnMajor, BF16, PkFP4, E8M0, BF16, BQuantGrouped, GemmConfigMxFP4, GroupSize64>, + std::tuple< RowMajor, RowMajor, RowMajor, ColumnMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>, + // RRR BQ: R + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64> +>; +// clang-format on + +// Test suite for BQuant 1D 64 +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D64Types); + +// BQuant tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 11fa6e038a..5a26034182 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -102,13 +102,24 @@ struct GemmConfigDecodeInterwave : public GemmConfigBase static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; -struct GemmConfigMxFp4 : public GemmConfigBase +struct GemmConfigMx : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; static constexpr ck_tile::index_t K_Tile = 128; }; +// This configuration uses K_Warp_Tile = 64 on CDNA. In this way, on gfx950 we can use +// LDS load transpose on matrix B (FP4) because the instruction requires each +// lane to load 16 4bits elements +struct GemmConfigMxFP4 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); +}; + struct GemmConfigPreshuffleQuant : public GemmConfigBase { static constexpr bool APreshuffleQuant = true; @@ -666,8 +677,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase ? (K / 2) : K; + const ck_tile::index_t stride_B = K; const ck_tile::index_t stride_C = N; // BQuant uses block/grouped quantization for B matrix @@ -678,24 +688,36 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase a_m_k( ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{}))); - ck_tile::HostTensor b_k_n(ck_tile::host_tensor_descriptor( - std::is_same_v ? 
K / 2 : K, - N, - stride_B, - this->is_row_major(BLayout{}))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{}))); ck_tile::HostTensor bq_bqk_bqn( ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, this->is_row_major(BQLayout{}))); // Initialize data with random values ck_tile::FillUniformDistribution{-0.5f, 0.5f}(a_m_k); - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { ck_tile::FillUniformDistribution{-5.0f, 5.0f}(b_k_n); - ck_tile::FillUniformDistribution{125.f, 130.f}(bq_bqk_bqn); } else { ck_tile::FillUniformDistribution{0.f, 1.f}(b_k_n); + } + + if constexpr(std::is_same_v) + { + auto gen_scales = [&](auto& scales, float range_min, float range_max) { + // e8m0_t is basically an exponent of float32 + ck_tile::HostTensor pow2(scales.get_lengths()); + ck_tile::FillUniformDistributionIntegerValue{range_min, range_max}(pow2); + scales.ForEach([&](auto& self, const auto& i) { + self(i) = static_cast(std::exp2(pow2(i))); + }); + }; + gen_scales(bq_bqk_bqn, -2, 2); + } + else + { ck_tile::FillUniformDistribution{-1.0f, 1.0f}(bq_bqk_bqn); } @@ -780,14 +802,15 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase) - ck_tile::reference_mxfp4gemm_quant(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref); + if constexpr(std::is_same_v) + ck_tile::reference_mx_gemm_bquant(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref); else ck_tile::reference_gemm_quant + ? ck_tile::CastPolicy::BeforeLDSWrite + : ck_tile::CastPolicy::AfterLDSRead; using PipelineProblem = ck_tile::GemmBQuantPipelineProblem; + tail_number_v, + b_cast_policy_v>; using GemmPipeline = std::conditional_t< PreshuffleB == false, - std::conditional_t, - ck_tile::MxFp4GemmPipelineAgBgCrCompV3, + std::conditional_t, + ck_tile::MicroscaleGemmPipelineAgBgCrCompV3, ck_tile::BQuantGemmPipelineAgBgCrCompV3>, ck_tile::WPQuantBPipelineAgBgCrV2>; using GemmEpilogue = ck_tile::CShuffleEpilogue, + std::conditional_t, ADataType, BDataType>, ck_tile::tuple<>,
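
A sketch of the compile-time policy choice wired up just above (stand-in types, not the CK ones; the exact condition is elided in this diff, and the RowMajor test is my assumption based on the problem-base fallback and the register-transpose comment in gemm_quant_pipeline_problem.hpp):

    #include <type_traits>

    enum class CastPolicy { BeforeLDSWrite, AfterLDSRead };

    struct pk_fp4_t {}; // stand-in for ck_tile::pk_fp4_t
    struct RowMajor {}; // stand-in for the tensor layout tag

    template <typename BDataType, typename BLayout>
    constexpr CastPolicy pick_b_cast_policy()
    {
        // Assumption: packed-fp4 B stored row-major is dequantized to the
        // compute type before the LDS write, because the in-register
        // transpose between the global read and the LDS write does not
        // handle 4-bit types.
        if constexpr(std::is_same_v<BDataType, pk_fp4_t> && std::is_same_v<BLayout, RowMajor>)
            return CastPolicy::BeforeLDSWrite;
        else
            return CastPolicy::AfterLDSRead;
    }

On gfx950 the AfterLDSRead default can be kept even for fp4, since the LDS transposed-load instruction handles 4-bit types directly, which is what the `#ifdef __gfx950__` branch in the problem base expresses.
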