mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
WIP
This commit is contained in:
@@ -180,92 +180,115 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
|
||||
using SplitKBatchOffset = typename Underlying::SplitKBatchOffset;
|
||||
|
||||
// Create C block window following UniversalGemmKernel pattern
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, typename ScaleM, typename ScaleN>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTensorViews(const std::array<const ADataType*, NumATensor>& as_ptr,
|
||||
const std::array<const BDataType*, NumBTensor>& bs_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
const KernelArgs<ScaleM, ScaleN>& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
CK_TILE_DEVICE static auto MakeCBlockWindows(EDataType* e_ptr,
|
||||
const KernelArgs<ScaleM, ScaleN>& kargs,
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
{
|
||||
// Get tensor views from the UniversalGemmKernel
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
Underlying::template MakeGemmTensorViews<DstInMemOp>(
|
||||
as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k);
|
||||
// Create tensor view for E/C tensor
|
||||
constexpr index_t vector_size = EpiloguePipeline::GetVectorSizeC();
|
||||
const auto& e_tensor_view = [&]() -> auto {
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
e_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_E, 1),
|
||||
number<vector_size>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
e_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(1, kargs.stride_E),
|
||||
number<1>{},
|
||||
number<vector_size>{});
|
||||
}
|
||||
}();
|
||||
|
||||
// Create padded view
|
||||
const auto& e_pad_view = [&]() {
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, false>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, false>{});
|
||||
}
|
||||
}();
|
||||
|
||||
// Create block window
|
||||
auto c_block_window = make_tile_window(
|
||||
e_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
return c_block_window;
|
||||
}
|
||||
|
||||
// Create scale A block windows following the pattern of MakeABlockWindows
|
||||
template <typename ScaleM, typename ScaleN>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeScaleABlockWindows(const KernelArgs<ScaleM, ScaleN>& kargs, const index_t block_idx_m)
|
||||
{
|
||||
auto scale_a = kargs.scale_m_ptr;
|
||||
auto scale_b = kargs.scale_n_ptr;
|
||||
|
||||
static_assert(ScaleM::GranularityK == ScaleN::GranularityK, "M and N scales must have same K granularity!");
|
||||
static constexpr int BlockScaleSize = ScaleM::GranularityK;
|
||||
|
||||
// With 1D K-only packing: each int32 contains KXdlPack consecutive e8m0 values
|
||||
// Scale A layout: [M, K/BlockScaleSize/KXdlPack] where each element is int32
|
||||
// Scale B layout: [N, K/BlockScaleSize/KXdlPack] where each element is int32
|
||||
const auto&& scale_k_packed = kargs.K / BlockScaleSize / KXdlPack;
|
||||
static constexpr int BlockScaleSize = ScaleM::GranularityK;
|
||||
const auto scale_k_packed = kargs.K / BlockScaleSize / KXdlPack;
|
||||
|
||||
// A scale tensor view - simple 2D layout [M, K/32/4]
|
||||
const auto& scale_a_tensor_view = [&]() {
|
||||
const auto scale_a_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(kargs.M, scale_k_packed));
|
||||
// A scale tensor view - simple 2D layout [M, K/BlockScaleSize/KXdlPack]
|
||||
const auto scale_a_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(kargs.M, scale_k_packed));
|
||||
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const int32_t*>(scale_a.ptr), scale_a_desc);
|
||||
}();
|
||||
const auto scale_a_tensor_view = make_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const int32_t*>(scale_a.ptr), scale_a_desc);
|
||||
|
||||
// B scale tensor view - layout [K/32/4, N] to match reference
|
||||
// Reference provides scale_b(k/32, n), so it's [K/32, N] in e8m0
|
||||
// With KXdlPack=4, we pack 4 e8m0 into 1 int32, so it's [K/32/4, N]
|
||||
const auto& scale_b_tensor_view = [&]() {
|
||||
const auto scale_b_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(scale_k_packed, kargs.N));
|
||||
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const int32_t*>(scale_b.ptr), scale_b_desc);
|
||||
}();
|
||||
|
||||
return concat_tuple(gemm_tensor_views_tuple, make_tuple(scale_a_tensor_view, scale_b_tensor_view));
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
|
||||
{
|
||||
const auto& padded_views = Underlying::template MakeGemmPadViews<TensorView>(views);
|
||||
|
||||
return make_tuple(
|
||||
padded_views.at(I0), padded_views.at(I1), padded_views.at(I2), padded_views.at(I3), views.at(I4), views.at(I5));
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
|
||||
{
|
||||
const auto& tile_windows = Underlying::template MakeGemmTileWindows<PadView>(views, i_m, i_n);
|
||||
|
||||
static constexpr int BlockScaleSize = 32;
|
||||
|
||||
// With 1D K-only packing: MXdlPack=1, NXdlPack=1, KXdlPack=4
|
||||
// Each int32 contains KXdlPack consecutive e8m0 scales
|
||||
// Create block window for scale A
|
||||
auto scale_a_block_window = make_tile_window(
|
||||
views.at(I4),
|
||||
scale_a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock / BlockScaleSize / KXdlPack>{}),
|
||||
{i_m, 0});
|
||||
{block_idx_m, 0});
|
||||
|
||||
// Scale B window matches [K/32/4, N] layout from reference
|
||||
return scale_a_block_window;
|
||||
}
|
||||
|
||||
// Create scale B block windows following the pattern of MakeBBlockWindows
|
||||
template <typename ScaleM, typename ScaleN>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeScaleBBlockWindows(const KernelArgs<ScaleM, ScaleN>& kargs, const index_t block_idx_n)
|
||||
{
|
||||
auto scale_b = kargs.scale_n_ptr;
|
||||
|
||||
static constexpr int BlockScaleSize = ScaleN::GranularityK;
|
||||
const auto scale_k_packed = kargs.K / BlockScaleSize / KXdlPack;
|
||||
|
||||
// B scale tensor view - layout [K/BlockScaleSize/KXdlPack, N]
|
||||
const auto scale_b_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(scale_k_packed, kargs.N));
|
||||
|
||||
const auto scale_b_tensor_view = make_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const int32_t*>(scale_b.ptr), scale_b_desc);
|
||||
|
||||
// Create block window for scale B
|
||||
auto scale_b_block_window = make_tile_window(
|
||||
views.at(I5),
|
||||
scale_b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock / BlockScaleSize / KXdlPack>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{0, i_n});
|
||||
{0, block_idx_n});
|
||||
|
||||
return make_tuple(tile_windows.at(I0),
|
||||
tile_windows.at(I1),
|
||||
tile_windows.at(I2),
|
||||
tile_windows.at(I3),
|
||||
scale_a_block_window,
|
||||
scale_b_block_window);
|
||||
return scale_b_block_window;
|
||||
}
|
||||
|
||||
template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
|
||||
@@ -281,22 +304,19 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
// Create block windows directly, following the new pattern from UniversalGemmKernel
|
||||
const auto& a_block_window =
|
||||
Underlying::MakeABlockWindows(as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
|
||||
const auto& b_block_window =
|
||||
Underlying::MakeBBlockWindows(bs_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
|
||||
const auto& d_block_window = Underlying::MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
// Create scale block windows using our new functions
|
||||
const auto& scale_a_block_window = MakeScaleABlockWindows(kargs, block_idx_m);
|
||||
const auto& scale_b_block_window = MakeScaleBBlockWindows(kargs, block_idx_n);
|
||||
|
||||
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
const auto& scale_a_block_window = gemm_tile_windows.at(I4);
|
||||
const auto& scale_b_block_window = gemm_tile_windows.at(I5);
|
||||
|
||||
static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same granK
|
||||
|| ScaleM::GranularityMN == -1 // or ScaleA is disable
|
||||
|| ScaleN::GranularityMN == -1, // or ScaleB is disable
|
||||
@@ -310,8 +330,9 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
smem_ptr_ping,
|
||||
smem_ptr_pong);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
// Run Epilogue Pipeline - create C block window directly
|
||||
auto c_block_window =
|
||||
MakeCBlockWindows<EpiloguePipeline::MemoryOperation>(e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
|
||||
}
|
||||
|
||||
@@ -338,7 +359,8 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
const SplitKBatchOffset splitk_batch_offset(kargs);
|
||||
// Cast to base class for SplitKBatchOffset construction
|
||||
const SplitKBatchOffset splitk_batch_offset(static_cast<const typename Underlying::KernelArgs&>(kargs));
|
||||
// options
|
||||
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
|
||||
|
||||
|
||||
@@ -5,19 +5,22 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
// MX scaling support with OpSel
|
||||
template <typename Problem>
|
||||
struct BaseGemmPipelineAgBgCrCompAsync
|
||||
struct BaseMXGemmPipelineAgBgCrCompAsync
|
||||
{
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
|
||||
static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
@@ -87,15 +90,16 @@ struct BaseGemmPipelineAgBgCrCompAsync
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Compute optimized pipeline version async; which is based on V4.
|
||||
* @brief MX GEMM compute optimized pipeline version async; which is based on V4.
|
||||
*
|
||||
* This pipeline introduces asynchronous load from global memory to LDS,
|
||||
* skipping the intermediate loading into pipeline registers.
|
||||
* Supports MX scaling with e8m0 packed values and OpSel.
|
||||
*/
|
||||
template <typename Problem, typename Policy = GemmPipelineAgBgCrCompAsyncDefaultPolicy>
|
||||
struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Problem>
|
||||
template <typename Problem, typename Policy = MXGemmPipelineAgBgCrCompAsyncDefaultPolicy>
|
||||
struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<Problem>
|
||||
{
|
||||
using Base = BaseGemmPipelineAgBgCrCompAsync<Problem>;
|
||||
using Base = BaseMXGemmPipelineAgBgCrCompAsync<Problem>;
|
||||
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
@@ -117,6 +121,11 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
|
||||
|
||||
static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
|
||||
|
||||
// MX scaling packing constants
|
||||
static constexpr int MXdlPack = Policy::MXdlPack;
|
||||
static constexpr int NXdlPack = Policy::NXdlPack;
|
||||
static constexpr int KXdlPack = Policy::KXdlPack;
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
@@ -317,7 +326,6 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
constexpr index_t KPerXdl = WarpTile::at(I2{});
|
||||
|
||||
constexpr index_t ScaleBlockSize = 32;
|
||||
constexpr index_t KXdlPack = Policy::KXdlPack;
|
||||
|
||||
// Scale A DRAM Window: [MWarp * MPerXdl, kKPerBlock / 32 / KXdlPack]
|
||||
auto scale_a_dram_window = make_tile_window(
|
||||
@@ -520,19 +528,28 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto n_iter) {
|
||||
constexpr auto OpSelB = kScaleInPack;
|
||||
|
||||
// Extract A/B values for this iteration
|
||||
auto a_val = a_block_tile.get_y_sliced_thread_data(
|
||||
// Extract A/B values for this iteration - create warp tensors
|
||||
typename WarpGemm::AWarpTensor a_warp_tensor{};
|
||||
const auto a_thread_data = a_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<m_iter, k_iter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
auto b_val = b_block_tile.get_y_sliced_thread_data(
|
||||
static_for<0, a_warp_tensor.get_thread_buffer_size(), 1>{}([&](auto i) {
|
||||
a_warp_tensor.get_thread_buffer()(i) = a_thread_data[i];
|
||||
});
|
||||
|
||||
typename WarpGemm::BWarpTensor b_warp_tensor{};
|
||||
const auto b_thread_data = b_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<n_iter, k_iter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
static_for<0, b_warp_tensor.get_thread_buffer_size(), 1>{}([&](auto i) {
|
||||
b_warp_tensor.get_thread_buffer()(i) = b_thread_data[i];
|
||||
});
|
||||
|
||||
WarpGemm{}.template operator()<OpSelA, OpSelB>(
|
||||
c_warp_tensors(m_iter)(n_iter),
|
||||
bit_cast<typename WarpGemm::AWarpTensor>(a_val),
|
||||
a_warp_tensor,
|
||||
b_warp_tensor,
|
||||
scale_a(m_iter)(number<kScalePacked>{}).get_thread_buffer()[0],
|
||||
bit_cast<typename WarpGemm::BWarpTensor>(b_val),
|
||||
scale_b(n_iter)(number<kScalePacked>{}).get_thread_buffer()[0]);
|
||||
});
|
||||
});
|
||||
@@ -742,9 +759,9 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
element_wise::PassThrough{},
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
element_wise::PassThrough{},
|
||||
scale_a_window,
|
||||
scale_b_window,
|
||||
num_loop,
|
||||
|
||||
@@ -9,11 +9,12 @@
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
// Default policy for GemmPipelineAgBgCrCompAsync
|
||||
// Customized methods: MakeALdsBlockDescriptor, MakeBLdsBlockDescriptor
|
||||
// Default policy for MXGemmPipelineAgBgCrCompAsync
|
||||
// Customized methods: MakeALdsBlockDescriptor, MakeBLdsBlockDescriptor
|
||||
// GetBlockGemm implementation is copied from GemmPipelineAgBgCrCompV4DefaultPolicy
|
||||
struct GemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
: public UniversalGemmBasePolicy<GemmPipelineAgBgCrCompAsyncDefaultPolicy>
|
||||
// Adds MX scale tile distributions
|
||||
struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
: public UniversalGemmBasePolicy<MXGemmPipelineAgBgCrCompAsyncDefaultPolicy>
|
||||
{
|
||||
static constexpr auto ATileAccessPattern = tile_distribution_pattern::warp_raked;
|
||||
static constexpr auto BTileAccessPattern = tile_distribution_pattern::warp_raked;
|
||||
@@ -134,7 +135,8 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
return BlockGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
// MX Scale tile distributions for loading from global memory
|
||||
// MX Scale tile distributions for loading from global memory
|
||||
// Using the proven "Flat" patterns from v1 policy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution()
|
||||
{
|
||||
@@ -147,16 +149,17 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
constexpr index_t MPerXdl = WarpTile::at(number<0>{});
|
||||
constexpr index_t K_Lane = get_warp_size() / MPerXdl; // 4 for 16x16 mfma
|
||||
|
||||
// Scale A: [MWarp * MPerXdl, K/32/KXdlPack] for warp-level tile
|
||||
// Scale A: [MWarp * MPerXdl, K/32/KXdlPack] for warp-level tile
|
||||
// Distribution: simple 2D for loading int32 packed scales
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<NWarp>, // repeat over NWarp
|
||||
tile_distribution_encoding<sequence<NWarp>, // repeat over NWarps
|
||||
tuple<sequence<MWarp, MPerXdl>, // M dimension
|
||||
sequence<K_Lane, 1>>, // K dimension (int32 vec load)
|
||||
sequence<K_Lane, 1>>, // K dimension (int32 vec load)
|
||||
tuple<sequence<1, 0>, sequence<2, 1>>, // which direction
|
||||
tuple<sequence<0, 0>, sequence<0, 1>>, // which index
|
||||
sequence<2>, // repeat
|
||||
sequence<1>>{}); // vec_load
|
||||
// <repeat, vec_load>
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -170,17 +173,22 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
constexpr index_t NWarp = BlockWarps::at(number<1>{});
|
||||
constexpr index_t NPerXdl = WarpTile::at(number<1>{});
|
||||
constexpr index_t K_Lane = get_warp_size() / NPerXdl; // 4 for 16x16 mfma
|
||||
|
||||
static_assert(K_Lane == 4, "K_Lane must be 4 for 16x16 mfma");
|
||||
static_assert(NPerXdl == 16, "NPerXdl must be 16 for 16x16 mfma");
|
||||
static_assert(MWarp == 1, "MWarp must be 1 for 16x16 mfma");
|
||||
|
||||
// Scale B: [K/32/KXdlPack, NWarp * NPerXdl] for warp-level tile
|
||||
// Layout is [K, N] where K is packed int32
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarp>, // repeat over MWarp
|
||||
tuple<sequence<K_Lane, 1>, // K dimension (int32 vec load)
|
||||
tile_distribution_encoding<sequence<MWarp>, // repeat over MWarps
|
||||
tuple<sequence<K_Lane, 1>, // K dimension (int32 vec load)
|
||||
sequence<NWarp, NPerXdl>>, // N dimension
|
||||
tuple<sequence<2, 1>, sequence<0, 1>>, // which direction
|
||||
tuple<sequence<0, 1>, sequence<0, 0>>, // which index
|
||||
sequence<1>, // repeat
|
||||
sequence<2>>{}); // vec_load
|
||||
// <repeat, vec_load>
|
||||
sequence<1>,
|
||||
sequence<1>>{});
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user