mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 05:19:20 +00:00
try to enable scale loading in kernel and pipeline
This commit is contained in:
@@ -152,30 +152,40 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
CK_TILE_HOST static constexpr auto
|
||||
GridSize(const KernelArgs<ScaleM, ScaleN>& kargs)
|
||||
{
|
||||
hipDeviceProp_t prop;
|
||||
int deviceId = 0; // default device
|
||||
const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
|
||||
|
||||
if constexpr(UsePersistentKernel)
|
||||
{
|
||||
hipDeviceProp_t prop;
|
||||
int deviceId = 0; // default device
|
||||
|
||||
int dync_smem_size = 0;
|
||||
int maxActiveBlocksPerCU = 0;
|
||||
int dync_smem_size = 0;
|
||||
int maxActiveBlocksPerCU = 0;
|
||||
|
||||
if(hipGetDeviceProperties(&prop, deviceId) != hipSuccess)
|
||||
throw std::runtime_error(std::string("hipGetDeviceProperties failed: ") +
|
||||
hipGetErrorName(hipGetLastError()));
|
||||
if(hipGetDeviceProperties(&prop, deviceId) != hipSuccess)
|
||||
throw std::runtime_error(std::string("hipGetDeviceProperties failed: ") +
|
||||
hipGetErrorName(hipGetLastError()));
|
||||
|
||||
if(hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&maxActiveBlocksPerCU,
|
||||
reinterpret_cast<void*>(
|
||||
kentry<1, MXGemmKernel, remove_cvref_t<decltype(kargs)>>),
|
||||
KernelBlockSize,
|
||||
dync_smem_size) != hipSuccess)
|
||||
throw std::runtime_error(
|
||||
std::string("hipOccupancyMaxActiveBlocksPerMultiprocessor failed: ") +
|
||||
hipGetErrorName(hipGetLastError()));
|
||||
if(hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&maxActiveBlocksPerCU,
|
||||
reinterpret_cast<void*>(
|
||||
kentry<1, MXGemmKernel, remove_cvref_t<decltype(kargs)>>),
|
||||
KernelBlockSize,
|
||||
dync_smem_size) != hipSuccess)
|
||||
throw std::runtime_error(
|
||||
std::string("hipOccupancyMaxActiveBlocksPerMultiprocessor failed: ") +
|
||||
hipGetErrorName(hipGetLastError()));
|
||||
|
||||
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
|
||||
const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
|
||||
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
|
||||
const int actual_grid_size = min(persistent_block_size, total_work_tile_cnt);
|
||||
|
||||
return dim3(min(persistent_block_size, total_work_tile_cnt), 1, 1);
|
||||
return dim3(actual_grid_size, 1, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Non-persistent: use full grid size based on number of tiles
|
||||
return dim3(total_work_tile_cnt, 1, 1);
|
||||
}
|
||||
}
|
||||
|
||||
using SplitKBatchOffset = typename Underlying::SplitKBatchOffset;
|
||||
@@ -240,26 +250,36 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
// Create scale A block windows following the pattern of MakeABlockWindows
|
||||
template <typename ScaleM, typename ScaleN>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeScaleABlockWindows(const KernelArgs<ScaleM, ScaleN>& kargs, const index_t block_idx_m)
|
||||
MakeScaleABlockWindows(const KernelArgs<ScaleM, ScaleN>& kargs, const index_t i_m)
|
||||
{
|
||||
auto scale_a = kargs.scale_m_ptr;
|
||||
|
||||
static constexpr int BlockScaleSize = ScaleM::GranularityK;
|
||||
const auto scale_k_packed = kargs.K / BlockScaleSize / KXdlPack;
|
||||
const auto scale_k_size = kargs.K / BlockScaleSize;
|
||||
const auto scale_k_size_packed = scale_k_size / KXdlPack;
|
||||
|
||||
// A scale tensor view - simple 2D layout [M, K/BlockScaleSize/KXdlPack]
|
||||
const auto scale_a_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(kargs.M, scale_k_packed));
|
||||
// A scale tensor view - layout [M, scale_k_size_packed] with packed int32_t
|
||||
// Host packs 4 consecutive e8m0_t scales into one int32_t
|
||||
// const auto scale_a_desc = make_naive_tensor_descriptor(
|
||||
// make_tuple(kargs.M, scale_k_size_packed),
|
||||
// make_tuple(scale_k_size_packed, 1));
|
||||
|
||||
const auto scale_a_tensor_view = make_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const int32_t*>(scale_a.ptr), scale_a_desc);
|
||||
// const auto scale_a_tensor_view = make_tensor_view<address_space_enum::global>(
|
||||
// reinterpret_cast<const int32_t*>(scale_a.ptr), scale_a_desc);
|
||||
|
||||
const auto scale_a_tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const int32_t*>(scale_a.ptr),
|
||||
make_tuple(kargs.M, scale_k_size_packed),
|
||||
make_tuple(scale_k_size_packed, 1));
|
||||
|
||||
// Create block window for scale A
|
||||
// K dimension: KIterPerWarp int32s, each int32 contains 4 scales for K_Lane threads
|
||||
// i_m is element offset (iM * MPerBlock), not tile index
|
||||
auto scale_a_block_window = make_tile_window(
|
||||
scale_a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock / BlockScaleSize / KXdlPack>{}),
|
||||
{block_idx_m, 0});
|
||||
{i_m, 0});
|
||||
|
||||
return scale_a_block_window;
|
||||
}
|
||||
@@ -267,26 +287,35 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
// Create scale B block windows following the pattern of MakeBBlockWindows
|
||||
template <typename ScaleM, typename ScaleN>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeScaleBBlockWindows(const KernelArgs<ScaleM, ScaleN>& kargs, const index_t block_idx_n)
|
||||
MakeScaleBBlockWindows(const KernelArgs<ScaleM, ScaleN>& kargs, const index_t i_n)
|
||||
{
|
||||
auto scale_b = kargs.scale_n_ptr;
|
||||
|
||||
static constexpr int BlockScaleSize = ScaleN::GranularityK;
|
||||
const auto scale_k_packed = kargs.K / BlockScaleSize / KXdlPack;
|
||||
const auto scale_k_size = kargs.K / BlockScaleSize;
|
||||
const auto scale_k_size_packed = scale_k_size / KXdlPack;
|
||||
|
||||
// B scale tensor view - layout [K/BlockScaleSize/KXdlPack, N]
|
||||
const auto scale_b_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(scale_k_packed, kargs.N));
|
||||
// B scale tensor view - layout [scale_k_size_packed, N] with packed int32_t
|
||||
// Host packs 4 consecutive e8m0_t scales into one int32_t
|
||||
// const auto scale_b_desc = make_naive_tensor_descriptor(
|
||||
// make_tuple(kargs.N, scale_k_size_packed),
|
||||
// make_tuple(scale_k_size_packed, 1));
|
||||
|
||||
const auto scale_b_tensor_view = make_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const int32_t*>(scale_b.ptr), scale_b_desc);
|
||||
// const auto scale_b_tensor_view = make_tensor_view<address_space_enum::global>(
|
||||
// reinterpret_cast<const int32_t*>(scale_b.ptr), scale_b_desc);
|
||||
|
||||
const auto scale_b_tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const int32_t*>(scale_b.ptr),
|
||||
make_tuple(kargs.N, scale_k_size_packed),
|
||||
make_tuple(scale_k_size_packed, 1));
|
||||
|
||||
// Create block window for scale B
|
||||
// i_n is element offset (iN * NPerBlock), not tile index
|
||||
auto scale_b_block_window = make_tile_window(
|
||||
scale_b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock / BlockScaleSize / KXdlPack>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{0, block_idx_n});
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock / BlockScaleSize / KXdlPack>{}),
|
||||
{i_n, 0});
|
||||
|
||||
return scale_b_block_window;
|
||||
}
|
||||
@@ -301,19 +330,20 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
void* smem_ptr_pong,
|
||||
const KernelArgs<ScaleM, ScaleN>& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
{
|
||||
// Create block windows directly, following the new pattern from UniversalGemmKernel
|
||||
// i_m and i_n are element offsets (iM * MPerBlock, iN * NPerBlock), not tile indices
|
||||
const auto& a_block_window =
|
||||
Underlying::MakeABlockWindows(as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
|
||||
Underlying::MakeABlockWindows(as_ptr, kargs, splitk_batch_offset.splitted_k, i_m);
|
||||
const auto& b_block_window =
|
||||
Underlying::MakeBBlockWindows(bs_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
|
||||
const auto& d_block_window = Underlying::MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
|
||||
Underlying::MakeBBlockWindows(bs_ptr, kargs, splitk_batch_offset.splitted_k, i_n);
|
||||
const auto& d_block_window = Underlying::MakeDBlockWindows(ds_ptr, kargs, i_m, i_n);
|
||||
|
||||
// Create scale block windows using our new functions
|
||||
const auto& scale_a_block_window = MakeScaleABlockWindows(kargs, block_idx_m);
|
||||
const auto& scale_b_block_window = MakeScaleBBlockWindows(kargs, block_idx_n);
|
||||
const auto& scale_a_block_window = MakeScaleABlockWindows(kargs, i_m);
|
||||
const auto& scale_b_block_window = MakeScaleBBlockWindows(kargs, i_n);
|
||||
|
||||
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
|
||||
|
||||
@@ -322,6 +352,7 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
|| ScaleN::GranularityMN == -1, // or ScaleB is disable
|
||||
"ScaleM and ScaleN should have the same GranularityK");
|
||||
|
||||
|
||||
const auto& c_block_tile = MXGemmPipeline{}(a_block_window[number<0>{}],
|
||||
b_block_window[number<0>{}],
|
||||
scale_a_block_window,
|
||||
@@ -332,7 +363,7 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
|
||||
// Run Epilogue Pipeline - create C block window directly
|
||||
auto c_block_window =
|
||||
MakeCBlockWindows<EpiloguePipeline::MemoryOperation>(e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
MakeCBlockWindows<EpiloguePipeline::MemoryOperation>(e_ptr, kargs, i_m, i_n);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
|
||||
}
|
||||
|
||||
@@ -352,6 +383,11 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
{
|
||||
const int total_work_tile_cnt = amd_wave_read_first_lane(TilePartitioner::GridSize(kargs.M, kargs.N));
|
||||
|
||||
// Allocate shared memory OUTSIDE the loop - __shared__ variables must be at function scope
|
||||
__shared__ char smem_ptr_ping[GetSmemPingSize()];
|
||||
__shared__ char smem_ptr_pong[GetSmemPongSize()];
|
||||
|
||||
// Support both persistent and non-persistent modes
|
||||
do
|
||||
{
|
||||
const auto [iM, iN] =
|
||||
@@ -377,10 +413,6 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
splitk_batch_offset.bs_k_split_offset[i] / BPackedSize;
|
||||
});
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_ping[GetSmemPingSize()];
|
||||
__shared__ char smem_ptr_pong[GetSmemPongSize()];
|
||||
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/load_tile.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp"
|
||||
@@ -294,7 +295,7 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
|
||||
////////////// global window & register /////////////////
|
||||
// A DRAM tile window(s) for load
|
||||
// A DRAM tile window(s) for load
|
||||
|
||||
auto a_tile_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
@@ -410,33 +411,35 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
constexpr index_t NWarp = BlockWarps::at(I1{});
|
||||
constexpr index_t MPerXdl = WarpTile::at(I0{});
|
||||
constexpr index_t NPerXdl = WarpTile::at(I1{});
|
||||
constexpr index_t KPerXdl = WarpTile::at(I2{});
|
||||
|
||||
constexpr index_t ScaleBlockSize = 32;
|
||||
constexpr index_t ScaleBlockSize = 32; // Each scale covers 32 K elements
|
||||
|
||||
// Scale A DRAM Window: [MWarp * MPerXdl, kKPerBlock / 32 / KXdlPack]
|
||||
// Calculate scale dimensions: KPerBlock elements need KPerBlock/32 scales
|
||||
// Each int32 packs KXdlPack=4 scales, so we need KPerBlock/32/4 int32s per block
|
||||
constexpr index_t ScaleKDimPerBlock = KPerBlock / ScaleBlockSize / KXdlPack; // Packed int32s per block
|
||||
static_assert(ScaleBlockSize == 32, "Scale block size must be 32 for MX format");
|
||||
|
||||
// Scale A DRAM Window: [MWarp * MPerXdl, ScaleKDimPerBlock]
|
||||
// With strided packing: KXdlPack kIters share each int32 via OpSel
|
||||
auto scale_a_dram_window = make_tile_window(
|
||||
scale_a_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<MWarp * MPerXdl>{}, number<KPerBlock / ScaleBlockSize / KXdlPack>{}),
|
||||
make_tuple(number<MWarp * MPerXdl>{}, number<ScaleKDimPerBlock>{}),
|
||||
scale_a_window.get_window_origin(),
|
||||
Policy::template MakeMX_ScaleA_DramTileDistribution<Problem>());
|
||||
|
||||
const auto scale_a_dram_step_m = amd_wave_read_first_lane(
|
||||
scale_a_dram_window.get_load_offset(tuple<number<MWarp * MPerXdl>, number<0>>{}));
|
||||
const auto scale_a_dram_step_k = amd_wave_read_first_lane(
|
||||
scale_a_dram_window.get_load_offset(tuple<number<0>, number<KPerBlock / ScaleBlockSize / KXdlPack>>{}));
|
||||
|
||||
// Scale B DRAM Window: [kKPerBlock / 32 / KXdlPack, NWarp * NPerXdl]
|
||||
// Scale B DRAM Window: [ScaleKDimPerBlock, NWarp * NPerXdl]
|
||||
// With strided packing: KXdlPack kIters share each int32 via OpSel
|
||||
auto scale_b_dram_window = make_tile_window(
|
||||
scale_b_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<KPerBlock / ScaleBlockSize / KXdlPack>{}, number<NWarp * NPerXdl>{}),
|
||||
make_tuple(number<ScaleKDimPerBlock>{}, number<NWarp * NPerXdl>{}),
|
||||
scale_b_window.get_window_origin(),
|
||||
Policy::template MakeMX_ScaleB_DramTileDistribution<Problem>());
|
||||
|
||||
const auto scale_b_dram_step_k = amd_wave_read_first_lane(
|
||||
scale_b_dram_window.get_load_offset(tuple<number<KPerBlock / ScaleBlockSize / KXdlPack>, number<0>>{}));
|
||||
const auto scale_b_dram_step_n = amd_wave_read_first_lane(
|
||||
scale_b_dram_window.get_load_offset(tuple<number<0>, number<NWarp * NPerXdl>>{}));
|
||||
scale_b_dram_window.get_load_offset(tuple<number<NWarp * NPerXdl>, number<0>>{}));
|
||||
|
||||
// this pipeline has a pair of LDS buffers per logical tile
|
||||
// TODO: check for packed size - are these blocks too big?
|
||||
@@ -447,6 +450,7 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
|
||||
auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1);
|
||||
|
||||
|
||||
// set up LDS tile shapes - always use STORAGE dimensions for K
|
||||
/// NOTE: flatmm style byte tensor approach:
|
||||
// constexpr auto a_lds_shape = []() {
|
||||
@@ -544,23 +548,29 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
static_assert(Policy::template GetSmemSizeB<Problem>() == (KPerBlock * sizeof(BDataType) / BPackedSize) * NPerBlock, "SmemSizeB size is wrong!");
|
||||
|
||||
////////////// MX Scale register tiles (ping-pong buffers) /////////////////
|
||||
// Calculate scale iterations: each scale covers 32 elements in K
|
||||
// Each K iteration processes KPerXdl elements
|
||||
// Each packed int32 contains KXdlPack scales
|
||||
// Calculate scale iterations for M/N dimensions
|
||||
constexpr index_t KPerXdl = WarpTile::at(I2{});
|
||||
constexpr index_t KIterPerWarp = KPerBlock / KPerXdl;
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * MPerXdl);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * NPerXdl);
|
||||
constexpr index_t ScaleKPackedPerIter = (KIterPerWarp * KPerXdl) / (ScaleBlockSize * KXdlPack);
|
||||
static_assert(ScaleKPackedPerIter > 0, "ScaleKPackedPerIter is wrong!");
|
||||
|
||||
// Load a sample scale tile to get the type
|
||||
// ScaleKPackedPerIter: number of int32s needed to cover all KIterPerWarp iterations
|
||||
// Each int32 packs 4 scales (via strided packing), OpSel selects byte for kIter
|
||||
// KXdlPack kIters share one int32, so we need KIterPerWarp/KXdlPack int32s total
|
||||
constexpr index_t ScaleKPackedPerIter = KIterPerWarp / KXdlPack;
|
||||
static_assert(ScaleKPackedPerIter > 0, "ScaleKPackedPerIter must be positive!");
|
||||
|
||||
// Load a sample scale tile to get the type after distribution
|
||||
auto scale_a_sample = load_tile_with_offset(scale_a_dram_window, tuple<number<0>, number<0>>{});
|
||||
auto scale_b_sample = load_tile_with_offset(scale_b_dram_window, tuple<number<0>, number<0>>{});
|
||||
|
||||
using ScaleTileElementA = remove_cvref_t<decltype(scale_a_sample)>;
|
||||
using ScaleTileElementB = remove_cvref_t<decltype(scale_b_sample)>;
|
||||
using ScaleATileType = statically_indexed_array<statically_indexed_array<ScaleTileElementA, ScaleKPackedPerIter>, MIterPerWarp>;
|
||||
using ScaleBTileType = statically_indexed_array<statically_indexed_array<ScaleTileElementB, ScaleKPackedPerIter>, NIterPerWarp>;
|
||||
|
||||
// ScaleATileType: array of distributed tensors, one per M/N iteration
|
||||
// Each distributed tensor holds ScaleKPackedPerIter int32 elements across threads
|
||||
using ScaleATileType = statically_indexed_array<ScaleTileElementA, MIterPerWarp>;
|
||||
using ScaleBTileType = statically_indexed_array<ScaleTileElementB, NIterPerWarp>;
|
||||
|
||||
ScaleATileType scale_a_tile_ping, scale_a_tile_pong;
|
||||
ScaleBTileType scale_b_tile_ping, scale_b_tile_pong;
|
||||
@@ -569,20 +579,22 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
auto load_scales_ = [&](auto& scale_a, auto& scale_b) {
|
||||
// Load scales for each M/N iteration
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, ScaleKPackedPerIter, 1>{}([&](auto kPacked) {
|
||||
scale_a(mIter)(kPacked) = load_tile_with_offset(
|
||||
scale_a_dram_window, mIter * scale_a_dram_step_m + kPacked * scale_a_dram_step_k);
|
||||
});
|
||||
// static_for<0, ScaleKPackedPerIter, 1>{}([&](auto kPacked) {
|
||||
// scale_a(mIter)(kPacked) = load_tile_with_offset(
|
||||
// scale_a_dram_window, mIter * scale_a_dram_step_m + kPacked * scale_a_dram_step_k);
|
||||
// });
|
||||
scale_a(mIter) = load_tile_with_offset(scale_a_dram_window, make_tuple(mIter * scale_a_dram_step_m, number<0>{}));
|
||||
});
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, ScaleKPackedPerIter, 1>{}([&](auto kPacked) {
|
||||
// Scale B is [K/32/KXdlPack, N], so K is first dimension
|
||||
scale_b(nIter)(kPacked) = load_tile_with_offset(
|
||||
scale_b_dram_window, kPacked * scale_b_dram_step_k + nIter * scale_b_dram_step_n);
|
||||
});
|
||||
// static_for<0, ScaleKPackedPerIter, 1>{}([&](auto kPacked) {
|
||||
// // Scale B is [K/32/KXdlPack, N], so K is first dimension
|
||||
// scale_b(nIter)(kPacked) = load_tile_with_offset(
|
||||
// scale_b_dram_window, kPacked * scale_b_dram_step_k + nIter * scale_b_dram_step_n);
|
||||
// });
|
||||
scale_b(nIter) = load_tile_with_offset(scale_b_dram_window, make_tuple(nIter * scale_b_dram_step_n, number<0>{}));
|
||||
});
|
||||
move_tile_window(scale_a_dram_window, {0, KPerBlock / ScaleBlockSize / KXdlPack});
|
||||
move_tile_window(scale_b_dram_window, {KPerBlock / ScaleBlockSize / KXdlPack, 0});
|
||||
move_tile_window(scale_b_dram_window, {0, KPerBlock / ScaleBlockSize / KXdlPack});
|
||||
};
|
||||
|
||||
// constexpr auto a_lds_input_tile_distr = [ALdsTileDistr]() {
|
||||
@@ -734,7 +746,7 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
b_tile_windows[number<0>{}],
|
||||
b_dram_tile_window_step);
|
||||
// C(i-3) = A(i-3) @ B(i-3) with MX scaling
|
||||
// warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
|
||||
// block_gemm(c_block_tile, a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
/// TODO: remove these after creating a block gemm with scales
|
||||
ignore = scale_a_tile_ping;
|
||||
@@ -763,11 +775,11 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
b_tile_windows[number<0>{}],
|
||||
b_dram_tile_window_step);
|
||||
// C(i-2) = A(i-2) @ B(i-2) with MX scaling
|
||||
// warp_gemm_loop(a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong);
|
||||
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
|
||||
/// TODO: remove these after creating a block gemm with scales
|
||||
ignore = scale_a_tile_pong;
|
||||
ignore = scale_b_tile_pong;
|
||||
block_gemm(c_block_tile, a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong);
|
||||
// block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
|
||||
// /// TODO: remove these after creating a block gemm with scales
|
||||
// ignore = scale_a_tile_pong;
|
||||
// ignore = scale_b_tile_pong;
|
||||
HotLoopScheduler();
|
||||
// Load scales for iteration i+2 (pong)
|
||||
/// TODO: check condition
|
||||
@@ -787,11 +799,11 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
|
||||
// C(num_loop-2) = A(num_loop-2) @ B(num_loop-2) with MX scaling
|
||||
// warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
/// TODO: remove these after creating a block gemm with scales
|
||||
ignore = scale_a_tile_ping;
|
||||
ignore = scale_b_tile_ping;
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
|
||||
// block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
// /// TODO: remove these after creating a block gemm with scales
|
||||
// ignore = scale_a_tile_ping;
|
||||
// ignore = scale_b_tile_ping;
|
||||
/// TODO: load next scales to ping for the last iteration
|
||||
}
|
||||
{
|
||||
@@ -801,19 +813,19 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
|
||||
// C(num_loop-1) = A(num_loop-1) @ B(num_loop-1) with MX scaling
|
||||
// warp_gemm_loop(a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong);
|
||||
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
|
||||
/// TODO: remove these after creating a block gemm with scales
|
||||
ignore = scale_a_tile_pong;
|
||||
ignore = scale_b_tile_pong;
|
||||
block_gemm(c_block_tile, a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong);
|
||||
// block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
|
||||
// /// TODO: remove these after creating a block gemm with scales
|
||||
// ignore = scale_a_tile_pong;
|
||||
// ignore = scale_b_tile_pong;
|
||||
}
|
||||
{
|
||||
// C(num_loop) = A(num_loop) @ B(num_loop) with MX scaling
|
||||
// warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
/// TODO: remove these after creating a block gemm with scales
|
||||
ignore = scale_a_tile_ping;
|
||||
ignore = scale_b_tile_ping;
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
|
||||
// block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
// /// TODO: remove these after creating a block gemm with scales
|
||||
// ignore = scale_a_tile_ping;
|
||||
// ignore = scale_b_tile_ping;
|
||||
}
|
||||
}
|
||||
else if(TailNum == TailNumber::Two)
|
||||
@@ -824,30 +836,30 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
|
||||
// C(num_loop-1) = A(num_loop-1) @ B(num_loop-1) with MX scaling
|
||||
// warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
/// TODO: remove these after creating a block gemm with scales
|
||||
ignore = scale_a_tile_ping;
|
||||
ignore = scale_b_tile_ping;
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
|
||||
// block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
// /// TODO: remove these after creating a block gemm with scales
|
||||
// ignore = scale_a_tile_ping;
|
||||
// ignore = scale_b_tile_ping;
|
||||
}
|
||||
{
|
||||
// C(num_loop) = A(num_loop) @ B(num_loop) with MX scaling
|
||||
// warp_gemm_loop(a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong);
|
||||
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
|
||||
/// TODO: remove these after creating a block gemm with scales
|
||||
ignore = scale_a_tile_pong;
|
||||
ignore = scale_b_tile_pong;
|
||||
block_gemm(c_block_tile, a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong);
|
||||
// block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
|
||||
// /// TODO: remove these after creating a block gemm with scales
|
||||
// ignore = scale_a_tile_pong;
|
||||
// ignore = scale_b_tile_pong;
|
||||
}
|
||||
}
|
||||
else if(TailNum == TailNumber::One)
|
||||
{
|
||||
block_sync_lds();
|
||||
// C(num_loop) = A(num_loop) @ B(num_loop) with MX scaling
|
||||
// warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
/// TODO: remove these after creating a block gemm with scales
|
||||
ignore = scale_a_tile_ping;
|
||||
ignore = scale_b_tile_ping;
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
|
||||
// block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
// /// TODO: remove these after creating a block gemm with scales
|
||||
// ignore = scale_a_tile_ping;
|
||||
// ignore = scale_b_tile_ping;
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
static constexpr int MXdlPack = 1; // No M packing
|
||||
static constexpr int NXdlPack = 1; // No N packing
|
||||
static constexpr int KXdlPack = 4; // Pack 4 consecutive e8m0 scales in K = 4 bytes = 1 int32
|
||||
static constexpr int BlockScaleSize = 32; // Each e8m0 scale covers 32 elements in K
|
||||
|
||||
// Override vector size methods to ensure compatibility with async buffer operations
|
||||
// Valid sizes for amd_async_buffer_load are 4, 12, or 16 bytes
|
||||
@@ -72,222 +73,6 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
return vector_size;
|
||||
}
|
||||
|
||||
// // DRAM tile distributions use STORAGE dimensions (for the storage tensor view)
|
||||
// template <typename Problem>
|
||||
// CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
|
||||
// {
|
||||
// // using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
// // using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
|
||||
|
||||
// // constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
// // constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
// // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
// // constexpr index_t APackedSize = numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
|
||||
// // constexpr index_t K2 = 16; // 16 bytes
|
||||
// // constexpr index_t K1 = 128 / K2; // 8
|
||||
// // constexpr index_t K0 = KPerBlock / (K1 * K2 * APackedSize); // KPerBlock/256/packsize
|
||||
|
||||
// // constexpr index_t M2 = get_warp_size() / K1; // 8
|
||||
// // constexpr index_t M1 = BlockSize / get_warp_size(); // 4
|
||||
// // constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
|
||||
// // static_assert(M0 * M1 * M2 == MPerBlock, "M0, M1, M2 must cover whole MPerBlock!");
|
||||
// // static_assert(K0 * K1 * K2 * APackedSize == KPerBlock,
|
||||
// // "K0, K1, K2 must cover whole KPerBlock!");
|
||||
|
||||
// // return make_static_tile_distribution(
|
||||
// // tile_distribution_encoding< //
|
||||
// // sequence<1>,
|
||||
// // tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>, // ?,4,8 1,8,32 or 2,8,16
|
||||
// // tuple<sequence<1>, sequence<1, 2>>, // M1 M2,K1
|
||||
// // tuple<sequence<1>, sequence<2, 1>>,
|
||||
// // sequence<1, 2, 2>, // M0,K0,K2
|
||||
// // sequence<0, 0, 2>>{});
|
||||
// constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
// constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
// /// NOTE: for flatmm style byte tensor, divide KPerBlock by APackedSize to get STORAGE dimensions
|
||||
// // using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
// // using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
|
||||
// // constexpr index_t APackedSize = numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
// // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / APackedSize; // Use STORAGE dimensions
|
||||
// /// NOTE: use original KPerBlock
|
||||
// constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
// constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
|
||||
// constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
// using ALayout = remove_cvref_t<
|
||||
// std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::AsLayoutTuple>>>;
|
||||
|
||||
|
||||
// if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
// {
|
||||
// using TileEncodingPattern =
|
||||
// tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
// MPerBlock,
|
||||
// KPerBlock, // Use storage dimensions
|
||||
// VecLoadSize,
|
||||
// getATileAccessPattern(),
|
||||
// NumWaveGroups>;
|
||||
// return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// static_assert(false, "Not implemented");
|
||||
// // using TileEncodingPattern =
|
||||
// // tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
// // KPerBlock,
|
||||
// // MPerBlock,
|
||||
// // VecLoadSize,
|
||||
// // getATileAccessPattern(),
|
||||
// // NumWaveGroups>;
|
||||
// // return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
// }
|
||||
// }
|
||||
|
||||
// template <typename Problem>
|
||||
// CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
|
||||
// {
|
||||
// /// NOTE: flatmm style dstr
|
||||
// // using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
// // using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
|
||||
// // constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
// // constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
// // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
// // constexpr index_t BPackedSize = numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
// // constexpr index_t K2 = 16; // 16 bytes
|
||||
// // constexpr index_t K1 = 128 / K2; // 8
|
||||
// // constexpr index_t K0 = KPerBlock / (K1 * K2 * BPackedSize); // KPerBlock/256/packsize
|
||||
|
||||
// // constexpr index_t N2 = get_warp_size() / K1; // 8
|
||||
// // constexpr index_t N1 = BlockSize / get_warp_size(); // 4
|
||||
// // constexpr index_t N0 = NPerBlock / (N2 * N1);
|
||||
|
||||
// // static_assert(N0 * N1 * N2 == NPerBlock, "N0, N1, N2 must cover whole NPerBlock!");
|
||||
// // static_assert(K0 * K1 * K2 * BPackedSize == KPerBlock,
|
||||
// // "K0, K1, K2 must cover whole KPerBlock!");
|
||||
|
||||
// // return make_static_tile_distribution(
|
||||
// // tile_distribution_encoding< //
|
||||
// // sequence<1>,
|
||||
// // tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2>>, // ?,4,8 1,8,32 or 2,8,16
|
||||
// // tuple<sequence<1>, sequence<1, 2>>, // M1 M2,K1
|
||||
// // tuple<sequence<1>, sequence<2, 1>>,
|
||||
// // sequence<1, 2, 2>, // N0,K0,K2
|
||||
// // sequence<0, 0, 2>>{});
|
||||
// constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
// constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
// /// NOTE: for flatmm style byte tensor, divide KPerBlock by BPackedSize to get STORAGE dimensions
|
||||
// // using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
// // using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
// // constexpr index_t BPackedSize = numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
// // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / BPackedSize; // Use STORAGE dimensions
|
||||
// /// NOTE: use original KPerBlock
|
||||
// constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
// constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
|
||||
// constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
// using BLayout = remove_cvref_t<
|
||||
// std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;
|
||||
|
||||
|
||||
// if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
// {
|
||||
// static_assert(false, "Not implemented");
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// using TileEncodingPattern =
|
||||
// tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
// NPerBlock,
|
||||
// KPerBlock, // Use storage dimensions
|
||||
// VecLoadSize,
|
||||
// getBTileAccessPattern(),
|
||||
// NumWaveGroups>;
|
||||
// return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
// }
|
||||
// }
|
||||
|
||||
// template <typename Problem>
|
||||
// CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ALDSBytes_TileDistribution()
|
||||
// {
|
||||
// // static_assert(BlockWarps::at(I0) == 1, "requires Wave_M == 1");
|
||||
// using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
// using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
|
||||
// constexpr index_t APackedSize = numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
// using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
// constexpr index_t MWarps = BlockWarps::at(number<0>{});
|
||||
// constexpr index_t NWarps = BlockWarps::at(number<1>{});
|
||||
// constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
|
||||
// // constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
|
||||
// constexpr index_t KPerXdl = Problem::BlockGemmShape::WarpTile::at(I2);
|
||||
// constexpr index_t K_Lane = get_warp_size() / 16; // 4
|
||||
// constexpr index_t K_Thread = KPerXdl / K_Lane; // 32
|
||||
// constexpr index_t DWORDx4 = 16;
|
||||
// constexpr index_t AK1 = DWORDx4 * APackedSize;
|
||||
|
||||
// if constexpr(K_Thread == AK1)
|
||||
// return make_static_tile_distribution(
|
||||
// tile_distribution_encoding< //
|
||||
// sequence<NWarps>,
|
||||
// tuple<sequence<MWarps, 1, MPerXdl>, sequence<K_Lane, AK1 / APackedSize>>,
|
||||
// tuple<sequence<1, 0>, sequence<2, 1>>,
|
||||
// tuple<sequence<0, 0>, sequence<0, 2>>,
|
||||
// sequence<2>,
|
||||
// sequence<1>>{});
|
||||
// else
|
||||
// return make_static_tile_distribution(
|
||||
// tile_distribution_encoding< //
|
||||
// sequence<NWarps>,
|
||||
// tuple<sequence<MWarps, 1, MPerXdl>,
|
||||
// sequence<K_Thread / AK1, K_Lane, AK1 / APackedSize>>,
|
||||
// tuple<sequence<1, 0>, sequence<2, 1>>,
|
||||
// tuple<sequence<0, 0>, sequence<1, 2>>,
|
||||
// sequence<2, 2>,
|
||||
// sequence<0, 2>>{});
|
||||
// }
|
||||
|
||||
// template <typename Problem>
|
||||
// CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BLDSBytes_TileDistribution()
|
||||
// {
|
||||
// // static_assert(BlockWarps::at(I0) == 1, "requires Wave_M == 1");
|
||||
// using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
// using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
// constexpr index_t BPackedSize = numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
// using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
// constexpr index_t MWarps = BlockWarps::at(number<0>{});
|
||||
// constexpr index_t NWarps = BlockWarps::at(number<1>{});
|
||||
// // constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
|
||||
// constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
|
||||
// constexpr index_t KPerXdl = Problem::BlockGemmShape::WarpTile::at(I2);
|
||||
// constexpr index_t K_Lane = get_warp_size() / 16; // 4
|
||||
// constexpr index_t K_Thread = KPerXdl / K_Lane; // 32
|
||||
// constexpr index_t DWORDx4 = 16;
|
||||
// constexpr index_t BK1 = DWORDx4 * BPackedSize;
|
||||
|
||||
// if constexpr(K_Thread == BK1)
|
||||
// return make_static_tile_distribution(
|
||||
// tile_distribution_encoding< //
|
||||
// sequence<MWarps>,
|
||||
// tuple<sequence<NWarps, 1, NPerXdl>, sequence<K_Lane, BK1 / BPackedSize>>,
|
||||
// tuple<sequence<1, 0>, sequence<2, 1>>,
|
||||
// tuple<sequence<0, 0>, sequence<0, 2>>,
|
||||
// sequence<2>,
|
||||
// sequence<1>>{});
|
||||
// else
|
||||
// return make_static_tile_distribution(
|
||||
// tile_distribution_encoding< //
|
||||
// sequence<MWarps>,
|
||||
// tuple<sequence<NWarps, 1, NPerXdl>,
|
||||
// sequence<K_Thread / BK1, K_Lane, BK1 / BPackedSize>>,
|
||||
// tuple<sequence<1, 0>, sequence<2, 1>>,
|
||||
// tuple<sequence<0, 0>, sequence<1, 2>>,
|
||||
// sequence<2, 2>,
|
||||
// sequence<0, 2>>{});
|
||||
// }
|
||||
|
||||
template <typename Problem,
|
||||
typename OverrideADataType = remove_cvref_t<typename Problem::ADataType>>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
@@ -413,8 +198,7 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
return BlockGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
// MX Scale tile distributions for loading from global memory
|
||||
// Using the proven "Flat" patterns from v1 policy
|
||||
// MX Scale tile distributions for loading from global memory
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution()
|
||||
{
|
||||
@@ -425,20 +209,23 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
constexpr index_t MWarp = BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = BlockWarps::at(number<1>{});
|
||||
constexpr index_t MPerXdl = WarpTile::at(number<0>{});
|
||||
constexpr index_t K_Lane = get_warp_size() / MPerXdl; // 4 for 16x16 mfma
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t ScaleKDimPerBlock = KPerBlock / BlockScaleSize / KXdlPack; // int32s per block
|
||||
constexpr index_t K_Lane = get_warp_size() / MPerXdl; // 64/16 = 4 threads in K dimension
|
||||
// constexpr index_t KPackedElementsPerThread = ScaleKDimPerBlock / K_Lane; // 4/4 = 1 for K=512
|
||||
|
||||
// Scale A: [MWarp * MPerXdl, K/32/KXdlPack] for warp-level tile
|
||||
// Distribution: simple 2D for loading int32 packed scales
|
||||
// TODO: check which layout to actually use (could use KxN)
|
||||
// Scale A: [MWarp * MPerXdl, ScaleKDimPerBlock] warp-level tile
|
||||
// For K=512: [16, 4], distribute 4 int32s across 4 K_Lane threads (1 each)
|
||||
// Strided packing: thread at K_lane=k gets one int32 with scales for all kIters at K position k
|
||||
// Distribution: Replicate in M dimension, distribute in K dimension (no vectorization - scalar loads)
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<NWarp>, // repeat over NWarps
|
||||
tuple<sequence<MWarp, MPerXdl>, // M dimension
|
||||
sequence<K_Lane, 1>>, // K dimension (int32 vec load)
|
||||
tuple<sequence<1, 0>, sequence<2, 1>>, // which direction
|
||||
tuple<sequence<0, 0>, sequence<0, 1>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
tile_distribution_encoding<sequence<NWarp>, // repeat over NWarps
|
||||
tuple<sequence<MWarp, MPerXdl>, // M dimension
|
||||
sequence<ScaleKDimPerBlock, K_Lane>>, // K dimension
|
||||
tuple<sequence<1, 0>, sequence<2, 1>>, // <MWarp, NWarp>, <K_Lane, MPerXdl>
|
||||
tuple<sequence<0, 0>, sequence<1, 1>>,
|
||||
sequence<2>, // ScaleKDimPerBlock, all int32 needed to cover KPerBlock
|
||||
sequence<0>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -451,20 +238,23 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
constexpr index_t MWarp = BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = BlockWarps::at(number<1>{});
|
||||
constexpr index_t NPerXdl = WarpTile::at(number<1>{});
|
||||
constexpr index_t K_Lane = get_warp_size() / NPerXdl; // 4 for 16x16 mfma
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t ScaleKDimPerBlock = KPerBlock / BlockScaleSize / KXdlPack; // int32s per block
|
||||
constexpr index_t K_Lane = get_warp_size() / NPerXdl; // 64/16 = 4 threads in K dimension
|
||||
// constexpr index_t KPackedElementsPerThread = ScaleKDimPerBlock / K_Lane; // 4/4 = 1 for K=512
|
||||
|
||||
// Scale B: [K/32/KXdlPack, NWarp * NPerXdl] for warp-level tile
|
||||
// Layout is [K, N] where K is packed int32
|
||||
// TODO: check which layout to actually use (could use KxN)
|
||||
// Scale B: [ScaleKDimPerBlock, NWarp * NPerXdl] warp-level tile
|
||||
// For K=512: [4, 64], distribute 4 int32s across 4 K_Lane threads (1 each)
|
||||
// Strided packing: thread at K_lane=k gets one int32 with scales for all kIters at K position k
|
||||
// Distribution: Distribute in K dimension (no vectorization - scalar loads), replicate in N dimension
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarp>, // repeat over MWarps
|
||||
tuple<sequence<K_Lane, 1>, // K dimension (int32 vec load)
|
||||
sequence<NWarp, NPerXdl>>, // N dimension
|
||||
tuple<sequence<2, 1>, sequence<0, 1>>, // which direction
|
||||
tuple<sequence<0, 1>, sequence<0, 0>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<1>,
|
||||
sequence<1>>{});
|
||||
tile_distribution_encoding<sequence<MWarp>, // repeat over MWarps
|
||||
tuple<sequence<NWarp, NPerXdl>, // N dimension
|
||||
sequence<ScaleKDimPerBlock, K_Lane>>, // K dimension
|
||||
tuple<sequence<0, 1>, sequence<2, 1>>, // <MWarp, NWarp>, <K_Lane, NPerXdl>
|
||||
tuple<sequence<0, 0>, sequence<1, 1>>,
|
||||
sequence<2>, // ScaleKDimPerBlock, all int32 needed to cover KPerBlock
|
||||
sequence<0>>{});
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user