[rocm-libraries] ROCm/rocm-libraries#4594 (commit 1fce4cb)

[CK_TILE] MX GEMM non-preshuffled RCR layout

## Motivation

Implements a GEMM with MX scaling for fp4 and fp8 inputs in non-preshuffled
layouts, using an asynchronous pipeline.
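
For context: MX scaling shares one e8m0 exponent across every 32 consecutive
K elements (ScaleBlockSize = 32 in the pipeline below). A minimal scalar
reference of the computation, with illustrative names and already-dequantized
fp4/fp8 inputs rather than the CK_TILE API:

```cpp
// Scalar reference sketch of MX-scaled GEMM; A is [M, K] row-major, B is
// [N, K] (column-major B in the RCR layout), one e8m0 scale per 32 K elements.
#include <cmath>
#include <cstdint>
#include <vector>

constexpr int kScaleBlockK = 32;

// e8m0 is an 8-bit exponent-only format; assuming the usual bias of 127.
inline float decode_e8m0(uint8_t e) { return std::ldexp(1.0f, int(e) - 127); }

void mx_gemm_ref(int M, int N, int K,
                 const std::vector<float>& a,         // dequantized fp4/fp8 A, [M, K]
                 const std::vector<float>& b,         // dequantized fp4/fp8 B, [N, K]
                 const std::vector<uint8_t>& scale_a, // [M, K / 32] e8m0
                 const std::vector<uint8_t>& scale_b, // [N, K / 32] e8m0
                 std::vector<float>& c)                // [M, N]
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
            {
                const float sa = decode_e8m0(scale_a[m * (K / kScaleBlockK) + k / kScaleBlockK]);
                const float sb = decode_e8m0(scale_b[n * (K / kScaleBlockK) + k / kScaleBlockK]);
                acc += (a[m * K + k] * sa) * (b[n * K + k] * sb);
            }
            c[m * N + n] = acc;
        }
}
```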

## Technical Details

<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
Author: Sami Remes
Committed: 2026-03-10 20:12:43 +00:00 by assistant-librarian[bot]
Parent: b8def2c724
Commit: 8f27f65d44
40 changed files with 2729 additions and 43 deletions

View File

@@ -0,0 +1,413 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
#include "ck_tile/ops/gemm_mx/kernel/scale_pointer.hpp"
namespace ck_tile {
template <typename ScaleM = MXScalePointer<e8m0_t, -1>,
typename ScaleN = MXScalePointer<e8m0_t, -1>,
index_t NumATensor = 1,
index_t NumBTensor = 1,
index_t NumDTensor = 0>
struct MXGemmKernelArgs : UniversalGemmKernelArgs<NumATensor, NumBTensor, NumDTensor>
{
using Base = UniversalGemmKernelArgs<NumATensor, NumBTensor, NumDTensor>;
CK_TILE_HOST MXGemmKernelArgs(const std::array<const void*, NumATensor>& as_ptr_,
const std::array<const void*, NumBTensor>& bs_ptr_,
const std::array<const void*, NumDTensor>& ds_ptr_,
void* e_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
const std::array<index_t, NumATensor>& stride_As_,
const std::array<index_t, NumBTensor>& stride_Bs_,
const std::array<index_t, NumDTensor>& stride_Ds_,
index_t stride_E_,
ScaleM scale_m_ptr_,
ScaleN scale_n_ptr_)
: Base{as_ptr_,
bs_ptr_,
ds_ptr_,
e_ptr_,
M_,
N_,
K_,
stride_As_,
stride_Bs_,
stride_Ds_,
stride_E_,
k_batch_},
scale_m_ptr(scale_m_ptr_),
scale_n_ptr(scale_n_ptr_)
{
}
ScaleM scale_m_ptr;
ScaleN scale_n_ptr;
};
template <typename TilePartitioner_, typename MXGemmPipeline_, typename EpiloguePipeline_>
struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, EpiloguePipeline_>
{
using Underlying = UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, EpiloguePipeline_>;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using MXGemmPipeline = remove_cvref_t<MXGemmPipeline_>;
using BlockGemmShape = remove_cvref_t<typename MXGemmPipeline::BlockGemmShape>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename MXGemmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename MXGemmPipeline::BLayout>;
using ELayout = remove_cvref_t<typename MXGemmPipeline::CLayout>;
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
static constexpr index_t KernelBlockSize = MXGemmPipeline::BlockSize;
static constexpr bool UsePersistentKernel = MXGemmPipeline::UsePersistentKernel;
// The type below is actually the accumulation data type - the output of the block GEMM.
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
static constexpr auto I2 = number<2>();
static constexpr auto I3 = number<3>();
static constexpr auto I4 = number<4>();
static constexpr auto I5 = number<5>();
static constexpr index_t NumATensor = Underlying::AsDataType::size();
static constexpr index_t NumBTensor = Underlying::BsDataType::size();
static constexpr index_t NumDTensor = Underlying::DsDataType::size();
using ADataType = remove_cvref_t<std::tuple_element_t<I0, typename Underlying::AsDataType>>;
using BDataType = remove_cvref_t<std::tuple_element_t<I0, typename Underlying::BsDataType>>;
static constexpr auto MThreadPerXdl = BlockGemmShape::WarpTile::at(number<0>{});
static constexpr auto NThreadPerXdl = BlockGemmShape::WarpTile::at(number<1>{});
static constexpr auto KThreadPerXdl = 64 / MThreadPerXdl;
static constexpr auto APackedSize = numeric_traits<ADataType>::PackedSize;
static constexpr auto BPackedSize = numeric_traits<BDataType>::PackedSize;
static constexpr int kBlockPerCu = 1;
static_assert(DsLayout::size() == DsDataType::size(),
"The size of DsLayout and DsDataType should be the same");
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "mx_gemm", gemm_prec_str<ADataType, BDataType>, MXGemmPipeline::GetName());
// clang-format on
}
template <typename ScaleM, typename ScaleN>
using KernelArgs = MXGemmKernelArgs<ScaleM, ScaleN, NumATensor, NumBTensor, NumDTensor>;
template <typename ScaleM, typename ScaleN>
CK_TILE_HOST static auto MakeKernelArgs(const std::array<const void*, NumATensor>& as_ptr,
const std::array<const void*, NumBTensor>& bs_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
void* e_ptr,
index_t k_batch,
index_t M,
index_t N,
index_t K,
const std::array<index_t, NumATensor>& stride_As,
const std::array<index_t, NumBTensor>& stride_Bs,
const std::array<index_t, NumDTensor>& stride_Ds,
index_t stride_E,
ScaleM scale_m_ptr,
ScaleN scale_n_ptr)
{
return KernelArgs<ScaleM, ScaleN>(as_ptr,
bs_ptr,
ds_ptr,
e_ptr,
k_batch,
M,
N,
K,
stride_As,
stride_Bs,
stride_Ds,
stride_E,
scale_m_ptr,
scale_n_ptr);
}
template <class ScaleM, class ScaleN>
CK_TILE_HOST static constexpr auto GridSize(const KernelArgs<ScaleM, ScaleN>& kargs)
{
const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
if constexpr(UsePersistentKernel)
{
hipDeviceProp_t prop;
int deviceId = 0; // default device
int dyn_smem_size = 0;
int maxActiveBlocksPerCU = 0;
if(hipGetDeviceProperties(&prop, deviceId) != hipSuccess)
throw std::runtime_error(std::string("hipGetDeviceProperties failed: ") +
hipGetErrorName(hipGetLastError()));
if(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&maxActiveBlocksPerCU,
reinterpret_cast<void*>(
kentry<1, MXGemmKernel, remove_cvref_t<decltype(kargs)>>),
KernelBlockSize,
dyn_smem_size) != hipSuccess)
throw std::runtime_error(
std::string("hipOccupancyMaxActiveBlocksPerMultiprocessor failed: ") +
hipGetErrorName(hipGetLastError()));
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
const int actual_grid_size = min(persistent_block_size, total_work_tile_cnt);
return dim3(actual_grid_size, 1, 1);
}
else
{
// Non-persistent: use full grid size based on number of tiles
return dim3(total_work_tile_cnt, 1, 1);
}
}
using SplitKBatchOffset = typename Underlying::SplitKBatchOffset;
// Create C block window following UniversalGemmKernel pattern
template <memory_operation_enum DstInMemOp = memory_operation_enum::set,
typename ScaleM,
typename ScaleN>
CK_TILE_DEVICE static auto MakeCBlockWindows(EDataType* e_ptr,
const KernelArgs<ScaleM, ScaleN>& kargs,
const index_t i_m,
const index_t i_n)
{
// Create tensor view for E/C tensor
constexpr index_t vector_size = EpiloguePipeline::GetVectorSizeC();
const auto& e_tensor_view = [&]() -> auto {
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_E, 1),
number<vector_size>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_E),
number<1>{},
number<vector_size>{});
}
}();
// Create padded view; both layouts currently apply the same (no-op) padding
const auto& e_pad_view = pad_tensor_view(e_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, false>{});
// Create block window
auto c_block_window = make_tile_window(
e_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
return c_block_window;
}
// Create scale A block windows following the pattern of MakeABlockWindows
template <typename ScaleM, typename ScaleN>
CK_TILE_DEVICE static auto MakeScaleABlockWindows(const KernelArgs<ScaleM, ScaleN>& kargs,
const index_t i_m)
{
auto scale_a = kargs.scale_m_ptr;
static constexpr int BlockScaleSize = ScaleM::GranularityK;
const auto scale_k_size = kargs.K / BlockScaleSize;
// A scale tensor view - layout [M, scale_k_size] with e8m0_t elements
// Use e8m0_t directly without packing
const auto scale_a_tensor_view = make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const e8m0_t*>(scale_a.ptr),
make_tuple(kargs.M, scale_k_size),
make_tuple(scale_k_size, 1));
// Create block window for scale A
// K dimension: scale_k_size e8m0_t elements
// i_m is element offset (iM * MPerBlock), not tile index
auto scale_a_block_window =
make_tile_window(scale_a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock / BlockScaleSize>{}),
{i_m, 0});
return scale_a_block_window;
}
// Create scale B block windows following the pattern of MakeBBlockWindows
template <typename ScaleM, typename ScaleN>
CK_TILE_DEVICE static auto MakeScaleBBlockWindows(const KernelArgs<ScaleM, ScaleN>& kargs,
const index_t i_n)
{
auto scale_b = kargs.scale_n_ptr;
static constexpr int BlockScaleSize = ScaleN::GranularityK;
const auto scale_k_size = kargs.K / BlockScaleSize;
// B scale tensor view
// Host stores as [K/32, N] col-major = [N, K/32] row-major from access perspective
const auto scale_b_tensor_view = make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const e8m0_t*>(scale_b.ptr),
make_tuple(kargs.N, scale_k_size), // [N, K/32] for access
make_tuple(scale_k_size, 1)); // stride to match col-major storage
// Create block window for scale B
// Tile window shape matches access pattern: [NPerBlock, KPerBlock/32]
// i_n is element offset (iN * NPerBlock)
auto scale_b_block_window =
make_tile_window(scale_b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock / BlockScaleSize>{}),
{i_n, 0});
return scale_b_block_window;
}
template <class ScaleM, class ScaleN>
CK_TILE_DEVICE static void RunMxGemm(const std::array<const ADataType*, NumATensor>& as_ptr,
const std::array<const BDataType*, NumBTensor>& bs_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
void* smem_ptr_ping,
void* smem_ptr_pong,
const KernelArgs<ScaleM, ScaleN>& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t i_m,
const index_t i_n)
{
// Create block windows directly, following the new pattern from UniversalGemmKernel
// i_m and i_n are element offsets (iM * MPerBlock, iN * NPerBlock), not tile indices
const auto& a_block_window =
Underlying::MakeABlockWindows(as_ptr, kargs, splitk_batch_offset.splitted_k, i_m);
const auto& b_block_window =
Underlying::MakeBBlockWindows(bs_ptr, kargs, splitk_batch_offset.splitted_k, i_n);
const auto& d_block_window = Underlying::MakeDBlockWindows(ds_ptr, kargs, i_m, i_n);
// Create scale block windows using our new functions
const auto& scale_a_block_window = MakeScaleABlockWindows(kargs, i_m);
const auto& scale_b_block_window = MakeScaleBBlockWindows(kargs, i_n);
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same GranularityK
|| ScaleM::GranularityMN == -1 // or ScaleA is disabled
|| ScaleN::GranularityMN == -1, // or ScaleB is disabled
"ScaleM and ScaleN should have the same GranularityK");
const auto& c_block_tile = MXGemmPipeline{}(a_block_window[number<0>{}],
b_block_window[number<0>{}],
scale_a_block_window,
scale_b_block_window,
num_loop,
smem_ptr_ping,
smem_ptr_pong);
// Run Epilogue Pipeline - create C block window directly
auto c_block_window = MakeCBlockWindows(e_ptr, kargs, i_m, i_n);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPingSize()
{
return max(MXGemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPongSize()
{
return MXGemmPipeline::GetSmemSize();
}
template <class ScaleM, class ScaleN>
CK_TILE_DEVICE void operator()(KernelArgs<ScaleM, ScaleN> kargs,
int partition_idx = get_block_id()) const
{
const int total_work_tile_cnt =
amd_wave_read_first_lane(TilePartitioner::GridSize(kargs.M, kargs.N));
// Allocate shared memory for ping pong buffers
__shared__ char smem_ptr_ping[GetSmemPingSize()];
__shared__ char smem_ptr_pong[GetSmemPongSize()];
// Support both persistent and non-persistent modes
do
{
const auto [iM, iN] =
TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
// Cast to base class for SplitKBatchOffset construction
const SplitKBatchOffset splitk_batch_offset(
static_cast<const typename Underlying::KernelArgs&>(kargs));
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
// Apply per-tensor split-K offsets to the A/B base pointers
std::array<const ADataType*, NumATensor> as_ptr;
static_for<0, NumATensor, 1>{}([&](auto i) {
as_ptr[i] = static_cast<const ADataType*>(kargs.as_ptr[i]) +
splitk_batch_offset.as_k_split_offset[i] / APackedSize;
});
std::array<const BDataType*, NumBTensor> bs_ptr;
static_for<0, NumBTensor, 1>{}([&](auto i) {
bs_ptr[i] = static_cast<const BDataType*>(kargs.bs_ptr[i]) +
splitk_batch_offset.bs_k_split_offset[i] / BPackedSize;
});
RunMxGemm<ScaleM, ScaleN>(as_ptr,
bs_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr_ping,
smem_ptr_pong,
kargs,
splitk_batch_offset,
i_m,
i_n);
partition_idx += gridDim.x;
} while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
}
};
} // namespace ck_tile
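
A hypothetical host-side sketch of assembling these kernel arguments for the
single-A/B case. `Kernel` stands for a concrete `MXGemmKernel` instantiation;
the scale-pointer types (per-row e8m0 scales with GranularityK = 32) are
assumptions for illustration:

```cpp
// Sketch only; not a helper from this PR.
#include "ck_tile/ops/gemm_mx/kernel/scale_pointer.hpp"

template <typename Kernel,
          typename ScaleM = ck_tile::MXScalePointer<ck_tile::e8m0_t, 1, 32>,
          typename ScaleN = ck_tile::MXScalePointer<ck_tile::e8m0_t, 1, 32>>
auto make_mx_gemm_args(const void* a_dev, const void* b_dev, void* e_dev,
                       ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K,
                       ck_tile::index_t stride_a, ck_tile::index_t stride_b,
                       ck_tile::index_t stride_e,
                       ScaleM scale_a, ScaleN scale_b)
{
    // Single A/B tensor, no D tensors, no split-K (k_batch = 1).
    return Kernel::template MakeKernelArgs<ScaleM, ScaleN>(
        {a_dev}, {b_dev}, {}, e_dev,
        /*k_batch=*/1, M, N, K,
        {stride_a}, {stride_b}, {}, stride_e,
        scale_a, scale_b);
}
```

GridSize(kargs) then returns either one block per output tile or, for the
persistent variant, a grid capped by the occupancy query above.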

View File

@@ -0,0 +1,113 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename ScaleType, int SharedGranularityMN, int SharedGranularityK = 0>
struct MXScalePointer
{
static constexpr int GranularityMN = SharedGranularityMN;
static constexpr int GranularityK = SharedGranularityK;
static_assert(GranularityK != 0,
"GranularityK cannot be zero in primary template; "
"use the partial specialization for GranularityK == 0");
const ScaleType* ptr;
CK_TILE_HOST_DEVICE MXScalePointer() = default;
CK_TILE_HOST_DEVICE MXScalePointer(const ScaleType* ptr_) : ptr(ptr_) {}
CK_TILE_HOST_DEVICE MXScalePointer(const ScaleType* ptr_, [[maybe_unused]] index_t length_)
: ptr(ptr_)
{
}
CK_TILE_HOST_DEVICE MXScalePointer operator+(index_t offset) const
{
MXScalePointer ret;
if constexpr(GranularityMN == 0)
{
ret.ptr = ptr + offset / GranularityK;
}
else
{
ret.ptr = ptr + offset / GranularityMN / GranularityK;
}
return ret;
}
CK_TILE_HOST_DEVICE ScaleType operator[](index_t i) const = delete;
};
template <typename ScaleType, int SharedGranularityMN>
struct MXScalePointer<ScaleType, SharedGranularityMN, 0>
{
static constexpr int GranularityMN = SharedGranularityMN;
static constexpr int GranularityK = 0;
static_assert(GranularityMN != 0);
const ScaleType* ptr;
index_t length;
CK_TILE_HOST_DEVICE MXScalePointer() = default;
CK_TILE_HOST_DEVICE MXScalePointer(const ScaleType* ptr_) : ptr(ptr_), length(1) {}
CK_TILE_HOST_DEVICE MXScalePointer(const ScaleType* ptr_, index_t length_)
: ptr(ptr_), length(length_)
{
}
CK_TILE_HOST_DEVICE MXScalePointer operator+(index_t offset) const
{
MXScalePointer ret;
if constexpr(GranularityMN == 1)
{
ret.ptr = ptr + offset;
ret.length = length - offset;
}
else
{
ret.ptr = ptr + offset / GranularityMN;
ret.length = length - offset / GranularityMN;
}
return ret;
}
CK_TILE_HOST_DEVICE ScaleType operator[](index_t i) const
{
// with additional oob check
if constexpr(GranularityMN == 1)
return i < length ? ptr[i] : 0;
else
return i / GranularityMN < length ? ptr[i / GranularityMN] : 0;
}
};
// SharedGranularityMN == -1 means scaling is disabled
template <typename ScaleType>
struct MXScalePointer<ScaleType, -1, 0>
{
static constexpr int GranularityMN = -1;
static constexpr int GranularityK = 0;
const ScaleType* ptr = nullptr;
CK_TILE_HOST_DEVICE constexpr MXScalePointer() = default;
CK_TILE_HOST_DEVICE constexpr MXScalePointer(const ScaleType*) {}
CK_TILE_HOST_DEVICE constexpr MXScalePointer(const ScaleType*, index_t) {}
CK_TILE_HOST_DEVICE constexpr MXScalePointer operator+(index_t) const
{
return MXScalePointer{};
}
CK_TILE_HOST_DEVICE constexpr ScaleType operator[](index_t) const
{
return 1; // always returns 1, so it doesn't change the result
}
};
} // namespace ck_tile
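
A small sketch of the offset arithmetic across the specializations above
(host-side, hypothetical values):

```cpp
#include "ck_tile/ops/gemm_mx/kernel/scale_pointer.hpp"

// With GranularityMN = 1 and GranularityK = 32, advancing by `offset` K
// elements moves the underlying e8m0 pointer by offset / 32 scales; the
// GranularityMN == -1 specialization ignores its inputs and always reads 1.
void scale_pointer_demo(const ck_tile::e8m0_t* scales)
{
    ck_tile::MXScalePointer<ck_tile::e8m0_t, 1, 32> sp{scales};
    auto sp2 = sp + 64; // sp2.ptr == scales + 2

    ck_tile::MXScalePointer<ck_tile::e8m0_t, -1> none{nullptr};
    auto one = none[123]; // decodes as 1: scaling disabled
    (void)sp2;
    (void)one;
}
```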

View File

@@ -0,0 +1,723 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/tensor/load_tile.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp"
namespace ck_tile {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
// MX scaling support with OpSel
template <typename Problem>
struct BaseMXGemmPipelineAgBgCrCompAsync
{
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
{
if(num_loop == 1)
{
return TailNumber::One;
}
if(num_loop % PrefetchStages == 1)
{
return TailNumber::Three;
}
else
{
return TailNumber::Two;
}
}
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
{
// Handle all the valid cases.
if(has_hot_loop)
{
if(tail_number == TailNumber::Three)
{
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Three>{});
}
else if(tail_number == TailNumber::Two)
{
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Two>{});
}
}
else
{
if(tail_number == TailNumber::Three)
{
return run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::Three>{});
}
else if(tail_number == TailNumber::Two)
{
return run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::Two>{});
}
else
{
return (run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::One>{}));
}
}
// If execution reaches here, it's an invalid tail_number because it wasn't handled above.
#if defined(__HIP_DEVICE_COMPILE__)
__builtin_unreachable();
#else
throw std::logic_error(
"Invalid TailNumber: Only TailNumber::Three and TailNumber::Two are supported");
#endif
}
};
/**
* @brief MX GEMM compute-optimized pipeline, asynchronous variant (based on V4).
*
* This pipeline loads asynchronously from global memory directly into LDS,
* skipping the intermediate copy through pipeline registers.
* Supports MX scaling with e8m0 packed values and OpSel.
*/
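// Simplified schedule of this pipeline (two LDS buffers, ping/pong register tiles):
//   prologue : G->LDS(0); G->LDS(1); LDS(0)->reg(0); G->LDS(0); load ping/pong scales
//   hot loop : { LDS(1)->reg(1); G->LDS(1); gemm(reg(0), ping scales);   // ping
//                LDS(0)->reg(0); G->LDS(0); gemm(reg(1), pong scales); } // pong
//   tail     : one to three remaining block gemms, selected by TailNumber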
template <typename Problem, typename Policy = MXGemmPipelineAgBgCrCompAsyncDefaultPolicy>
struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<Problem>
{
using Base = BaseMXGemmPipelineAgBgCrCompAsync<Problem>;
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
// Each scale covers 32 K elements
static constexpr index_t ScaleBlockSize = 32;
static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
using I0 = number<0>;
using I1 = number<1>;
using I2 = number<2>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
}
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeB()
{
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
}
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
static constexpr index_t Preshuffle = Problem::Preshuffle;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
// clang-format off
return "COMPUTE_ASYNC";
// clang-format on
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
{
return Policy::template IsTransposeC<Problem>();
}
template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase
{
};
template <>
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
{
using Base = PipelineImplBase;
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{
constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{});
constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{});
constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{});
constexpr index_t WaveSize = get_warp_size();
constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) /
(MPerXDL * NPerXDL * KPerXDL);
constexpr auto num_buffer_load_inst = A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num;
constexpr auto num_issue = num_buffer_load_inst;
static_for<0, num_buffer_load_inst, 1>{}([&](auto i) {
// TODO: this will likely need to be redesigned after (1) changes to reading from
// LDS and (2) re-profiling
ignore = i;
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA : 1
__builtin_amdgcn_sched_group_barrier(
LLVMSchedGroupMask::DS_READ, 1, 0); // DS read : 1
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA: 1
__builtin_amdgcn_sched_group_barrier(
LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read :1
__builtin_amdgcn_sched_group_barrier(
LLVMSchedGroupMask::MFMA, C_MFMA_Inst_Num / num_issue - 2, 0); // MFMA : 6
});
__builtin_amdgcn_sched_barrier(0);
}
template <bool HasHotLoop,
TailNumber TailNum,
typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename ScaleADramBlockWindowTmp,
typename ScaleBDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
const ScaleADramBlockWindowTmp& scale_a_window,
const ScaleBDramBlockWindowTmp& scale_b_window,
index_t num_loop,
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
{
// TODO support multi-ABD
static_assert(1 == std::tuple_size_v<AsDramBlockWindowTmp>);
static_assert(1 == std::tuple_size_v<BsDramBlockWindowTmp>);
using ADramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
using BDramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
// TODO currently fused elementwise are not supported
ignore = a_element_func;
ignore = b_element_func;
static_assert(std::is_same_v<remove_cvref_t<decltype(a_element_func)>,
element_wise::PassThrough>);
static_assert(std::is_same_v<remove_cvref_t<decltype(b_element_func)>,
element_wise::PassThrough>);
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType,
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"Data Type conflict on A and B matrix input data type.");
constexpr bool is_a_col_major =
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_a_col_major
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(is_b_row_major
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
////////////// global window & register /////////////////
// A DRAM tile window(s) for load
auto a_tile_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(
a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
a_dram_block_window_tmp[number<idx>{}].get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
},
number<AsLayout::size()>{});
// B DRAM window(s) for load
auto b_tile_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(
b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
b_dram_block_window_tmp[number<idx>{}].get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
},
number<BsLayout::size()>{});
////////////// MX Scale windows /////////////////
// Get WarpGemm configuration
using BlockWarps = typename BlockGemmShape::BlockWarps;
constexpr index_t MWarp = BlockWarps::at(I0{});
constexpr index_t NWarp = BlockWarps::at(I1{});
// Calculate scale dimensions: KPerBlock elements need KPerBlock/32 e8m0_t scales
constexpr index_t ScaleKDimPerBlock = KPerBlock / ScaleBlockSize;
// Scale tensor views and base origins for creating tile windows per iteration
const auto& scale_a_tensor_view = scale_a_window.get_bottom_tensor_view();
const auto& scale_b_tensor_view = scale_b_window.get_bottom_tensor_view();
auto scale_a_base_origin = scale_a_window.get_window_origin();
auto scale_b_base_origin = scale_b_window.get_window_origin();
// Create sample scale windows to determine tile types
auto scale_a_dram_window =
make_tile_window(scale_a_tensor_view,
make_tuple(number<MPerBlock>{}, number<ScaleKDimPerBlock>{}),
scale_a_base_origin,
Policy::template MakeMX_ScaleA_DramTileDistribution<Problem>());
auto scale_b_dram_window =
make_tile_window(scale_b_tensor_view,
make_tuple(number<NPerBlock>{}, number<ScaleKDimPerBlock>{}),
scale_b_base_origin,
Policy::template MakeMX_ScaleB_DramTileDistribution<Problem>());
// this pipeline has a pair of LDS buffers per logical tile
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1);
constexpr auto a_lds_shape = []() {
if constexpr(is_a_load_tr_v)
return make_tuple(number<KPerBlock>{}, number<MPerBlock>{});
else
return make_tuple(number<MPerBlock>{}, number<KPerBlock>{});
}();
constexpr auto b_lds_shape = []() {
if constexpr(is_b_load_tr_v)
return make_tuple(number<KPerBlock>{}, number<NPerBlock>{});
else
return make_tuple(number<NPerBlock>{}, number<KPerBlock>{});
}();
// LDS tile windows for storing, one per LDS buffer
auto a_copy_lds_window0 = make_tile_window(a_lds_block0, a_lds_shape, {0, 0});
auto a_copy_lds_window1 = make_tile_window(a_lds_block1, a_lds_shape, {0, 0});
auto b_copy_lds_window0 = make_tile_window(b_lds_block0, b_lds_shape, {0, 0});
auto b_copy_lds_window1 = make_tile_window(b_lds_block1, b_lds_shape, {0, 0});
// initialize DRAM window steps, used to advance the DRAM windows
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
// read A(0), B(0) from DRAM to LDS window(0)
// and advance the DRAM windows
Base::GlobalPrefetchAsync(
a_copy_lds_window0, a_tile_windows[number<0>{}], a_dram_tile_window_step);
Base::GlobalPrefetchAsync(
b_copy_lds_window0, b_tile_windows[number<0>{}], b_dram_tile_window_step);
// Initialize block gemm and C block tile
auto block_gemm = BlockGemm();
auto c_block_tile = block_gemm.MakeCBlockTile();
clear_tile(c_block_tile);
// read A(1), B(1) from DRAM to LDS window(1)
// and advance the DRAM windows
Base::GlobalPrefetchAsync(
a_copy_lds_window1, a_tile_windows[number<0>{}], a_dram_tile_window_step);
Base::GlobalPrefetchAsync(
b_copy_lds_window1, b_tile_windows[number<0>{}], b_dram_tile_window_step);
// tile distribution for the register tiles
constexpr auto ALdsTileDistr =
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
constexpr auto BLdsTileDistr =
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
// register tiles; double buffering -> a register tile corresponds to a LDS tile window
ALdsTile a_block_tile0, a_block_tile1;
BLdsTile b_block_tile0, b_block_tile1;
// Some sanity checks on the LDS tile sizes
static_assert(sizeof(ALdsTile) == MPerBlock *
(KPerBlock * sizeof(ADataType) / APackedSize) *
NWarp / BlockSize,
"ALdsTile size is wrong!");
static_assert(sizeof(BLdsTile) == NPerBlock *
(KPerBlock * sizeof(BDataType) / BPackedSize) *
MWarp / BlockSize,
"BLdsTile size is wrong!");
static_assert(Policy::template GetSmemSizeA<Problem>() ==
MPerBlock * (KPerBlock * sizeof(ADataType) / APackedSize),
"SmemSizeA size is wrong!");
static_assert(Policy::template GetSmemSizeB<Problem>() ==
(KPerBlock * sizeof(BDataType) / BPackedSize) * NPerBlock,
"SmemSizeB size is wrong!");
////////////// MX Scale register tiles (ping-pong buffers) /////////////////
// No packing needed - each thread gets e8m0_t elements directly
// Each thread will cast e8m0_t to int32_t for WarpGemm with OpSel=0
using ScaleATileType = decltype(load_tile(scale_a_dram_window));
using ScaleBTileType = decltype(load_tile(scale_b_dram_window));
ScaleATileType scale_a_tile_ping, scale_a_tile_pong;
ScaleBTileType scale_b_tile_ping, scale_b_tile_pong;
// initialize Scale DRAM window steps, used to advance the Scale DRAM windows
using ScaleADramTileWindowStep = typename ScaleADramBlockWindowTmp::BottomTensorIndex;
using ScaleBDramTileWindowStep = typename ScaleBDramBlockWindowTmp::BottomTensorIndex;
constexpr ScaleADramTileWindowStep scale_a_dram_tile_window_step =
make_array(0, ScaleKDimPerBlock);
constexpr ScaleBDramTileWindowStep scale_b_dram_tile_window_step =
make_array(0, ScaleKDimPerBlock);
// Helper function to load scales
auto load_scales_from_dram = [&](auto& scale_a, auto& scale_b) {
scale_a = load_tile(scale_a_dram_window);
scale_b = load_tile(scale_b_dram_window);
move_tile_window(scale_a_dram_window, scale_a_dram_tile_window_step);
move_tile_window(scale_b_dram_window, scale_b_dram_tile_window_step);
};
/// TODO: enable transpose
// constexpr auto a_lds_input_tile_distr = [ALdsTileDistr]() {
// if constexpr(is_a_load_tr_v)
// return make_static_tile_distribution(
// typename InputTileDistributionTraits<
// typename decltype(ALdsTileDistr)::DstrEncode,
// typename Problem::ADataType>::TransposedDstrEncode{});
// else
// return ALdsTileDistr;
// }();
// constexpr auto b_lds_input_tile_distr = [BLdsTileDistr]() {
// if constexpr(is_b_load_tr_v)
// return make_static_tile_distribution(
// typename InputTileDistributionTraits<
// typename decltype(BLdsTileDistr)::DstrEncode,
// typename Problem::BDataType>::TransposedDstrEncode{});
// else
// return BLdsTileDistr;
// }();
// LDS tile windows for reading;
// they share the data pointer with the LDS windows for storing
// but also associate with a distribution to produce a register tile when reading
auto a_lds_ld_window0 =
make_tile_window(a_lds_block0, a_lds_shape, {0, 0}, ALdsTileDistr);
auto a_lds_ld_window1 =
make_tile_window(a_lds_block1, a_lds_shape, {0, 0}, ALdsTileDistr);
auto b_lds_ld_window0 =
make_tile_window(b_lds_block0, b_lds_shape, {0, 0}, BLdsTileDistr);
auto b_lds_ld_window1 =
make_tile_window(b_lds_block1, b_lds_shape, {0, 0}, BLdsTileDistr);
static_assert(!(is_tile_window_linear_v<decltype(a_lds_ld_window0)>) &&
!(is_tile_window_linear_v<decltype(a_lds_ld_window1)>) &&
!(is_tile_window_linear_v<decltype(b_lds_ld_window0)>) &&
!(is_tile_window_linear_v<decltype(b_lds_ld_window1)>),
"LDS windows must not be linear");
// write to LDS window(0) must complete before the local prefetch
block_sync_lds_direct_load();
// read A(0), B(0) from LDS window(0) to pipeline registers(0)
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
// LDS window(0) contents are overwritten below by global prefetch, need to sync
block_sync_lds();
// read A(2), B(2) from DRAM to LDS window(0)
// and advance the DRAM windows
Base::GlobalPrefetchAsync(
a_copy_lds_window0, a_tile_windows[number<0>{}], a_dram_tile_window_step);
Base::GlobalPrefetchAsync(
b_copy_lds_window0, b_tile_windows[number<0>{}], b_dram_tile_window_step);
// Load scales for iteration 0 (ping)
load_scales_from_dram(scale_a_tile_ping, scale_b_tile_ping);
// Load scales for iteration 1 (pong) if needed
if(num_loop > 1)
{
load_scales_from_dram(scale_a_tile_pong, scale_b_tile_pong);
}
if constexpr(HasHotLoop)
{
// we have had 3 global prefetches so far, indexed (0, 1, 2).
index_t i_global_read = amd_wave_read_first_lane(3);
// alternate ping: (read to register tile(1), use register tile(0) as gemm input)
// pong: (read to register tile(0), use register tile(1) as gemm input)
do
{
// ping
{
// read A(i-1), B(i-1) from LDS window(1) to pipeline registers(1)
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
// LDS window(1) contents are overwritten by global prefetch, need to sync
block_sync_lds();
// read A(i), B(i) from DRAM to LDS window(1)
// and advance the DRAM windows
Base::GlobalPrefetchAsync(a_copy_lds_window1,
a_tile_windows[number<0>{}],
a_dram_tile_window_step);
Base::GlobalPrefetchAsync(b_copy_lds_window1,
b_tile_windows[number<0>{}],
b_dram_tile_window_step);
// C(i-3) = A(i-3) @ B(i-3) with MX scaling
block_gemm(c_block_tile,
a_block_tile0,
b_block_tile0,
scale_a_tile_ping,
scale_b_tile_ping);
HotLoopScheduler();
// Load next scales after using current scales above
load_scales_from_dram(scale_a_tile_ping, scale_b_tile_ping);
}
// pong
{
// write to LDS window(0) must complete before the local prefetch
block_sync_lds_direct_load();
// read A(i), B(i) from LDS window(0) to pipeline registers(0)
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
// LDS window(0) contents are overwritten by global prefetch, need to sync
block_sync_lds();
// read A(i+1), B(i+1) from DRAM to LDS window(0)
// and advance the DRAM windows
Base::GlobalPrefetchAsync(a_copy_lds_window0,
a_tile_windows[number<0>{}],
a_dram_tile_window_step);
Base::GlobalPrefetchAsync(b_copy_lds_window0,
b_tile_windows[number<0>{}],
b_dram_tile_window_step);
// C(i-2) = A(i-2) @ B(i-2) with MX scaling
block_gemm(c_block_tile,
a_block_tile1,
b_block_tile1,
scale_a_tile_pong,
scale_b_tile_pong);
HotLoopScheduler();
// Load next scales after using current scales above
load_scales_from_dram(scale_a_tile_pong, scale_b_tile_pong);
}
i_global_read += 2;
} while(i_global_read < num_loop);
}
// 3 block gemms remaining
if constexpr(TailNum == TailNumber::Three)
{
{
// read A(num_loop-1), B(num_loop-1) from LDS window(1) to pipeline registers(1)
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
// C(num_loop-2) = A(num_loop-2) @ B(num_loop-2) with MX scaling
block_gemm(c_block_tile,
a_block_tile0,
b_block_tile0,
scale_a_tile_ping,
scale_b_tile_ping);
// load the last scales into the ping buffers for the final iteration
load_scales_from_dram(scale_a_tile_ping, scale_b_tile_ping);
}
{
// write to LDS window(0) must complete before the local prefetch
block_sync_lds_direct_load();
// read A(num_loop), B(num_loop) from LDS window(0) to pipeline registers(0)
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
// C(num_loop-1) = A(num_loop-1) @ B(num_loop-1) with MX scaling
block_gemm(c_block_tile,
a_block_tile1,
b_block_tile1,
scale_a_tile_pong,
scale_b_tile_pong);
}
{
// C(num_loop) = A(num_loop) @ B(num_loop) with MX scaling
block_gemm(c_block_tile,
a_block_tile0,
b_block_tile0,
scale_a_tile_ping,
scale_b_tile_ping);
}
}
// 2 block gemms remaining
else if constexpr(TailNum == TailNumber::Two)
{
{
// read A(num_loop), B(num_loop) from LDS window(1) to pipeline registers(1)
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
block_gemm(c_block_tile,
a_block_tile0,
b_block_tile0,
scale_a_tile_ping,
scale_b_tile_ping);
}
{
// C(num_loop) = A(num_loop) @ B(num_loop) with MX scaling
block_gemm(c_block_tile,
a_block_tile1,
b_block_tile1,
scale_a_tile_pong,
scale_b_tile_pong);
}
}
else if constexpr(TailNum == TailNumber::One)
{
block_sync_lds();
// C(num_loop) = A(num_loop) @ B(num_loop) with MX scaling
block_gemm(c_block_tile,
a_block_tile0,
b_block_tile0,
scale_a_tile_ping,
scale_b_tile_ping);
__builtin_amdgcn_sched_barrier(0);
}
return c_block_tile;
}
};
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename ScaleADramBlockWindowTmp,
typename ScaleBDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
const ScaleADramBlockWindowTmp& scale_a_window,
const ScaleBDramBlockWindowTmp& scale_b_window,
index_t num_loop,
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
{
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
scale_a_window,
scale_b_window,
num_loop,
p_smem_0,
p_smem_1);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
public:
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename ScaleADramBlockWindowTmp,
typename ScaleBDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const ScaleADramBlockWindowTmp& scale_a_window,
const ScaleBDramBlockWindowTmp& scale_b_window,
const index_t num_loop,
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
{
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
make_tuple(a_dram_block_window_tmp),
element_wise::PassThrough{},
make_tuple(b_dram_block_window_tmp),
element_wise::PassThrough{},
scale_a_window,
scale_b_window,
num_loop,
p_smem_0,
p_smem_1);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
};
} // namespace ck_tile
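
To make the tail dispatch concrete, here is a host-side mirror of
`BlockHasHotloop`/`GetBlockLoopTailNum` for `PrefetchStages == 2` (a sketch
for reasoning, not code from this PR):

```cpp
enum class Tail { One, Two, Three };

// Mirrors BlockHasHotloop: the hot loop runs only when more than
// PrefetchStages (= 2) K-loop iterations remain.
constexpr bool has_hot_loop(int num_loop) { return num_loop > 2; }

// Mirrors GetBlockLoopTailNum: odd counts take the Three tail, even counts Two.
constexpr Tail tail_of(int num_loop)
{
    if(num_loop == 1)
        return Tail::One;
    return (num_loop % 2 == 1) ? Tail::Three : Tail::Two;
}

static_assert(tail_of(1) == Tail::One && !has_hot_loop(1));
static_assert(tail_of(4) == Tail::Two && has_hot_loop(4));
static_assert(tail_of(5) == Tail::Three && has_hot_loop(5));
```

Each hot-loop iteration retires two tiles (ping + pong), so the tail consumes
whatever remains after the do/while exits.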

View File

@@ -0,0 +1,195 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include <type_traits>
namespace ck_tile {
// Default policy for MXGemmPipelineAgBgCrCompAsync
// Customized methods: MakeALdsBlockDescriptor, MakeBLdsBlockDescriptor
// GetBlockGemm implementation is copied from GemmPipelineAgBgCrCompV4DefaultPolicy
// Adds MX scale tile distributions
struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
: public UniversalGemmBasePolicy<MXGemmPipelineAgBgCrCompAsyncDefaultPolicy>
{
static constexpr auto ATileAccessPattern = tile_distribution_pattern::warp_raked;
static constexpr auto BTileAccessPattern = tile_distribution_pattern::warp_raked;
// MX scaling configuration: each e8m0 scale covers 32 elements in K
static constexpr int BlockScaleSize = 32;
template <typename Problem,
typename OverrideADataType = remove_cvref_t<typename Problem::ADataType>>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(is_a_load_tr<Problem>)
{
// TODO: better LDS descriptor for performance
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( //
make_tuple(number<KPerBlock>{}, number<MPerBlock>{}),
make_tuple(number<MPerBlock>{}, number<1>{}),
number<MPerBlock>{},
number<1>{});
return a_lds_block_desc_0;
}
else
{
constexpr index_t KPack = GetSmemPackA<Problem>();
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
number<KPack>{},
number<1>{});
return transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<MPerBlock>{}),
make_merge_transform(make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(is_b_load_tr<Problem>)
{
// TODO: better LDS descriptor for performance
constexpr auto b_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<KPerBlock>{}, number<NPerBlock>{}),
make_tuple(number<NPerBlock>{}, number<1>{}),
number<NPerBlock>{},
number<1>{});
return b_lds_block_desc_0;
}
else
{
constexpr index_t KPack = GetSmemPackB<Problem>();
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<NPerBlock>{}, number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
number<KPack>{},
number<1>{});
return transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<NPerBlock>{}),
make_merge_transform(make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using ADataType = typename Problem::ADataType;
using BDataType = typename Problem::BDataType;
using CDataType = typename Problem::CDataType;
// FP4 and FP8 require different layouts for the scaled mfma instructions
constexpr auto wg_attr_num_access =
(std::is_same_v<ADataType, fp8_t> || std::is_same_v<BDataType, fp8_t>)
? WGAttrNumAccessEnum::Double
: WGAttrNumAccessEnum::Single;
using WarpGemm = WarpGemmDispatcher<ADataType,
BDataType,
CDataType, // AccDataType
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
false,
wg_attr_num_access>;
using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<ADataType,
BDataType,
CDataType,
BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
}
// MX Scale tile distributions for loading from global memory
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution()
{
using BlockGemmShape = typename Problem::BlockGemmShape;
using BlockWarps = typename BlockGemmShape::BlockWarps;
using WarpTile = typename BlockGemmShape::WarpTile;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t MWarp = BlockWarps::at(number<0>{});
constexpr index_t NWarp = BlockWarps::at(number<1>{});
constexpr index_t MPerXdl = WarpTile::at(number<0>{});
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K_Lane = get_warp_size() / MPerXdl; // 64/16 = 4 threads in K dimension
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * MPerXdl);
constexpr index_t KPerXdl = WarpTile::at(number<2>{});
constexpr index_t KIterPerWarp = KPerBlock / KPerXdl;
constexpr index_t KPerLane = KPerXdl / BlockScaleSize / K_Lane;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<NWarp>, // replicate across NWarp
tuple<sequence<MIterPerWarp, MWarp, MPerXdl>, // M dimension (first)
sequence<KIterPerWarp, K_Lane, KPerLane>>, // K dimension (second)
tuple<sequence<0, 1>, sequence<2, 1>>, // <NWarp, MWarp>, <K_Lane, MPerXdl>
tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<2, 1, 2>, // <KIterPerWarp, MIterPerWarp, KPerLane>
sequence<0, 0, 2>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_DramTileDistribution()
{
using BlockGemmShape = typename Problem::BlockGemmShape;
using BlockWarps = typename BlockGemmShape::BlockWarps;
using WarpTile = typename BlockGemmShape::WarpTile;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t MWarp = BlockWarps::at(number<0>{});
constexpr index_t NWarp = BlockWarps::at(number<1>{});
constexpr index_t NPerXdl = WarpTile::at(number<1>{});
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K_Lane = get_warp_size() / NPerXdl; // 64/16 = 4 threads in K dimension
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * NPerXdl);
constexpr index_t KPerXdl = WarpTile::at(number<2>{});
constexpr index_t KIterPerWarp = KPerBlock / KPerXdl;
constexpr index_t KPerLane = KPerXdl / BlockScaleSize / K_Lane;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<MWarp>, // repeat over MWarps
tuple<sequence<NIterPerWarp, NWarp, NPerXdl>, // N dimension (first)
sequence<KIterPerWarp, K_Lane, KPerLane>>, // K dimension (second)
tuple<sequence<0, 1>, sequence<2, 1>>, // <MWarp, NWarp>, <K_Lane, NPerXdl>
tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<2, 1, 2>, // <KIterPerWarp, NIterPerWarp, KPerLane>
sequence<0, 0, 2>>{});
}
};
} // namespace ck_tile
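
As a worked example of the scale-distribution arithmetic above, assuming a
hypothetical 16x16x128 warp tile on a wave64 target (illustrative sizes, not
the only supported configuration):

```cpp
constexpr int warp_size      = 64;
constexpr int MPerXdl        = 16;  // WarpTile::at(0), assumed
constexpr int KPerXdl        = 128; // WarpTile::at(2), assumed
constexpr int BlockScaleSize = 32;  // one e8m0 scale per 32 K elements

constexpr int K_Lane   = warp_size / MPerXdl;               // 4 lanes along K
constexpr int KPerLane = KPerXdl / BlockScaleSize / K_Lane; // 128 / 32 / 4 = 1

static_assert(K_Lane == 4 && KPerLane == 1,
              "each lane holds one e8m0 scale per KPerXdl slab");
```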