diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d88da364a..38669385f3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for elementwise kernel. * Added benchmarking support for tile engine GEMM Multi D. * Added block scaling support in CK_TILE GEMM, allowing flexible use of quantization matrices from either A or B operands. +* Added the row-wise column-wise quantization for CK_TILE GEMM & CK_TILE Grouped GEMM. ### Optimized diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 0a6e65c345..9cf43a986e 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -276,6 +276,8 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2; static constexpr bool Preshuffle = true; static constexpr bool DoubleSmemBuffer = true; + static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; + static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; }; template @@ -298,6 +300,8 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2; static constexpr bool Preshuffle = true; static constexpr bool DoubleSmemBuffer = true; + static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; + static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; }; template diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 7289d54742..49d9a34f17 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -241,6 +241,26 @@ auto shuffle_b(const ck_tile::HostTensor& t) return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); } +template +auto shuffle_b_permuteN(const ck_tile::HostTensor& t) +{ + assert(t.get_lengths().size() == 2); 
+ + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; + constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp; + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Tile, + GemmConfig::N_Warp, + GemmConfig::N_Warp_Tile, + NRepeat, + k_ / GemmConfig::K_Warp_Tile, + divisor, + GemmConfig::K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6}); +} + template bool do_verify(const ck_tile::HostTensor& c_m_n_dev_result, const ck_tile::HostTensor& c_m_n_ref, @@ -346,7 +366,18 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser, if constexpr(preshuffle) { - ck_tile::HostTensor b_shuffle_host = shuffle_b(b_k_n); + ck_tile::HostTensor b_shuffle_host = [&]() { + if constexpr(GemmConfig::TiledMMAPermuteN) + { + std::cout << "Run with PermuteN" << std::endl; + return shuffle_b_permuteN(b_k_n); + } + else + { + std::cout << "Run without PermuteN" << std::endl; + return shuffle_b(b_k_n); + } + }(); // shuffled buffer B for device implementation b_k_n_dev_buf.ToDevice(b_shuffle_host.data()); } diff --git a/include/ck_tile/core/container/sequence.hpp b/include/ck_tile/core/container/sequence.hpp index 905b32dd15..cfec2237f9 100644 --- a/include/ck_tile/core/container/sequence.hpp +++ b/include/ck_tile/core/container/sequence.hpp @@ -175,6 +175,9 @@ struct sequence return sequence{})...>{}; } + CK_TILE_HOST_DEVICE static constexpr auto sum() { return (Is + ... + 0); } + CK_TILE_HOST_DEVICE static constexpr auto product() { return (Is * ... 
* 1); } + // modify element at index "I" with value "X" template CK_TILE_HOST_DEVICE static constexpr auto modify(number, number) diff --git a/include/ck_tile/core/utility/debug.hpp b/include/ck_tile/core/utility/debug.hpp index 15f0718dc2..9f0f931bc8 100644 --- a/include/ck_tile/core/utility/debug.hpp +++ b/include/ck_tile/core/utility/debug.hpp @@ -153,4 +153,28 @@ struct CK_PRINTF_WARP0 : public CK_PRINTF base_t::operator()(buf); } }; + +/* + * RAII struct which inserts start/end markers into the generated assembly. + * + * Usage: + * - Create an `AsmScopeMarker` object at the beginning of a scope or code block. + * - Its constructor will emit a "CK_ASM_SCOPE_START" marker into the assembly. + * - When the object goes out of scope (end of block, return, exception, etc.), + * the destructor will emit a "CK_ASM_SCOPE_END" marker. + * + * Example: + * { + * [[maybe_unused]] AsmScopeMarker marker; // Emits CK_ASM_SCOPE_START + * // ... code you want to delimit in assembly ... + * } // marker goes out of scope → Emits CK_ASM_SCOPE_END + * + */ +struct AsmScopeMarker +{ + // in some future version of clang we might be able to use string_view to customize + CK_TILE_HOST_DEVICE AsmScopeMarker() { asm volatile(";;# CK_ASM_SCOPE_START"); } + CK_TILE_HOST_DEVICE ~AsmScopeMarker() { asm volatile(";;# CK_ASM_SCOPE_END"); } +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index c6cfe84ac5..ed73f7e9f4 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -31,7 +31,8 @@ template + index_t VectorSizeC_ = 1, + bool TiledMMAPermuteN_ = false> struct CShuffleEpilogueProblem { using ADataType = remove_cvref_t; @@ -54,6 +55,7 @@ struct CShuffleEpilogueProblem static constexpr memory_operation_enum MemoryOperation = MemoryOperation_; static constexpr bool FixedVectorSize = FixedVectorSize_; static constexpr index_t 
VectorSizeC = VectorSizeC_; + static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_; static constexpr index_t kNumWaveGroups = kNumWaveGroups_; static constexpr index_t NumDTensor = DsDataType::size(); @@ -89,10 +91,13 @@ struct CShuffleEpilogue static constexpr index_t KPerXdl = Problem::KPerXdl; static constexpr index_t isCTransposed = Problem::isCTransposed; static constexpr bool FixedVectorSize = Problem::FixedVectorSize; + static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN; static constexpr index_t VectorSizeC = Problem::VectorSizeC; static constexpr index_t MPerIteration = MPerXdl * MWave; static constexpr index_t NPerIteration = NPerXdl * NWave; static constexpr index_t NumDTensor = Problem::NumDTensor; + static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave); + static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave); static_assert(NumDTensor == DsLayout::size(), "The size of DsDataType and DsLayout should be the same"); @@ -367,11 +372,152 @@ struct CShuffleEpilogue struct EmptyScale { }; + template + typename ScaleM = EmptyScale, + typename ScaleN = EmptyScale, + int EnablePermuateN_ = TiledMMAPermuteN, + std::enable_if_t = 0> + CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window, + const OAccTile& o_acc_tile, + const DsDramWindows& ds_dram_windows, + void* /*p_smem*/, + const ScaleM& scale_m = {}, + const ScaleN& scale_n = {}) + { + constexpr int kM0 = MWave; + constexpr int kM2 = 4; + constexpr int kM1 = MPerXdl / kM2; + + constexpr int kN0 = NWave; + constexpr int kN1 = NPerXdl; + constexpr int kN2 = NRepeat; + + using IntrThreadShuffleEncode = + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 1>>, + sequence<1, 2>, + sequence<2, 2>>; + constexpr auto dram_tile_distribution = + make_static_tile_distribution(IntrThreadShuffleEncode{}); + + auto d_dram_windows = generate_tuple( + [&](auto idx) { + return make_tile_window(ds_dram_windows[idx], 
dram_tile_distribution); + }, + number{}); + + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + auto shuffle_acc = make_static_distributed_tensor(dram_tile_distribution); + auto c_out_tensor = make_static_distributed_tensor(dram_tile_distribution); + + // Optional scales (must share the same distribution to match per-thread indexing) + constexpr bool has_scales = + !std::is_same::value && !std::is_same::value; + + // Tiles to hold row/col scales when present + using SMType = + std::conditional_t, float>; + using SNType = + std::conditional_t, float>; + + auto sm_tile = make_static_distributed_tensor(dram_tile_distribution); + auto sn_tile = make_static_distributed_tensor(dram_tile_distribution); + + // Build windows only if scales are provided + auto scale_m_window = [&]() { + if constexpr(has_scales) + { + return make_tile_window(scale_m, dram_tile_distribution); + } + else + { + return EmptyScale{}; + } + }(); + auto scale_n_window = [&]() { + if constexpr(has_scales) + { + return make_tile_window(scale_n, dram_tile_distribution); + } + else + { + return EmptyScale{}; + } + }(); + + static_for<0, MRepeat, 1>{}([&](auto mIter) { + // Slice accumulators for this M repeat into the permuted layout + shuffle_acc.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths)); + + // If scales provided, load them with identical distribution + if constexpr(has_scales) + { + sm_tile = load_tile(scale_m_window); // row scales in permuted layout + sn_tile = load_tile(scale_n_window); // col scales in permuted layout + } + + // Pack four rows per lane into the output tile + static_for<0, NRepeat, 1>{}([&](auto n_idx) { + // source indices in shuffle_acc: (n_idx * product(Y) + row) + const index_t base = n_idx * c_warp_y_lengths.product(); + + // local
lambda to fuse scale (if present) and convert + auto emit = [&](index_t out_idx, index_t src_row) { + AccDataType v = shuffle_acc.get_thread_buffer()[base + src_row]; + + if constexpr(has_scales) + { + // same linear index mapping on the permuted distribution + const auto s_m = static_cast(sm_tile.get_thread_buffer()[out_idx]); + const auto s_n = static_cast(sn_tile.get_thread_buffer()[out_idx]); + v = static_cast(v * s_m * s_n); + } + + c_out_tensor.get_thread_buffer()[out_idx] = type_convert(v); + }; + + // Packing pattern: rows 0..3, spaced by NRepeat + emit(n_idx + 0 * NRepeat, 0); + emit(n_idx + 1 * NRepeat, 1); + emit(n_idx + 2 * NRepeat, 2); + emit(n_idx + 3 * NRepeat, 3); + }); + + // store/update + if constexpr(MemoryOperation == memory_operation_enum::set) + { + store_tile(out_dram_window, c_out_tensor); + } + else + { + update_tile(out_dram_window, c_out_tensor); + } + + // advance output (and any D-tensors) by one MPerXdl*MWave chunk + move_tile_window(out_dram_window, {number{}, number<0>{}}); + static_for<0, NumDTensor, 1>{}([&](auto idx) { + move_tile_window(d_dram_windows[idx], {number{}, number<0>{}}); + }); + }); + } + + template = 0> CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window, + const OAccTile& o_acc_tile, + const DsDramWindows& ds_dram_windows,