[rocm-libraries] ROCm/rocm-libraries#5552 (commit 369c7a2)

[CK Tile] Eight Waves pipeline for MX GEMM (#5552)

## Motivation

Integrate Eight Waves pipeline in MX GEMM

## Technical Details

 - EightWaves pipeline:
- Add pipeline, policy and block gemm (internally using existing
implementation used by GEMM and ABQuant)
   - Extend support of EightWaves policy for FP4 (packed types)
 - Async pipeline:
- Fix pipeline with packed scales (requires MRepeat and NRepeat to be
contiguous)
- block gemm specific for MX GEMM is defined because distribution
encodings have changed
 - CShuffle:
- Add new functionality to support MRepeat and NRepeat contiguous
(defined by `TilesPacked`)
 - Examples:
- Refactor examples to easily switch different configurations (similar
to GEMM universal)
- Scales values generated consistently with other microscale
implementations in CK Tile
   - Add configuration for EightWaves pipeline
 - Tests:
   - Unify existing FP8 and FP4 tests
   - Add tests for EightWaves pipeline
- Scales values generated consistently with other microscale
implementations in CK Tile

Note: FP6 support for MX GEMM was added later and the support for the
Eight Waves pipeline will be done in following PR

## Test Plan

Add new pipeline to tests: `test_ck_tile_mx_gemm_async` for both FP4 and
FP8

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Enrico Degregori
2026-05-19 20:53:19 +02:00
committed by GitHub
parent f01a8cb28d
commit 9565ca21ec
29 changed files with 1472 additions and 383 deletions

View File

@@ -14,6 +14,8 @@ endforeach()
if(has_supported_gpu)
add_executable(tile_example_mx_gemm mx_gemm.cpp)
set(EXAMPLE_MX_GEMM_COMPILE_OPTIONS -Wno-undefined-func-template)
list(APPEND EXAMPLE_MX_GEMM_COMPILE_OPTIONS "SHELL: -mllvm -enable-noalias-to-md-conversion=1")
list(APPEND EXAMPLE_MX_GEMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1")
if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_MX_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()

View File

@@ -102,9 +102,9 @@ float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf,
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "4096", "m dimension")
.insert("n", "4096", "n dimension")
.insert("k", "4096", "k dimension")
arg_parser.insert("m", "1024", "m dimension")
.insert("n", "1024", "n dimension")
.insert("k", "2048", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Row by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
@@ -125,4 +125,4 @@ auto create_args(int argc, char* argv[])
#include "run_mx_gemm.inc"
int main(int argc, char* argv[]) { return run_mx_gemm_example(argc, argv); }
int main(int argc, char* argv[]) { return run_mx_gemm_example<MX_GemmConfig16>(argc, argv); }

View File

@@ -9,6 +9,7 @@
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm_mx.hpp"
#include "ck_tile/ops/gemm_mx/kernel/scale_pointer.hpp"
template <typename ScaleM, typename ScaleN>
@@ -83,17 +84,23 @@ struct MxGemmConfig
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = false;
};
struct MXfp4_GemmConfig16 : MxGemmConfig
struct MX_GemmConfigEightWaves : MxGemmConfig
{
static constexpr ck_tile::index_t M_Tile = 64;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 4;
static constexpr ck_tile::index_t N_Warp = 2; // NWarps == 2 for ping-pong!
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128 * N_Warp;
static constexpr ck_tile::index_t K_Tile = 128 * K_Warp;
static constexpr int kBlockPerCu = 2;
};
// GEMM config with 16x16 warp tile
struct MXfp8_GemmConfig16 : MxGemmConfig
struct MX_GemmConfig16 : MxGemmConfig
{
static constexpr ck_tile::index_t M_Tile = 64;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 256;
};

View File

@@ -57,7 +57,12 @@ float mx_gemm_calc(const MXGemmHostArgs<ScaleM, ScaleN>& args, const ck_tile::st
GemmConfig::Scheduler>;
// Use the new MX comp_async pipeline with MX scaling support
using MXGemmPipeline = ck_tile::MXGemmPipelineAgBgCrCompAsync<MXPipelineProblem>;
constexpr bool IsEightWave =
(GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp) == 8;
using MXGemmPipeline =
std::conditional_t<IsEightWave,
ck_tile::MXGemmPipelineAgBgCrCompAsyncEightWaves<MXPipelineProblem>,
ck_tile::MXGemmPipelineAgBgCrCompAsync<MXPipelineProblem>>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
@@ -80,7 +85,15 @@ float mx_gemm_calc(const MXGemmHostArgs<ScaleM, ScaleN>& args, const ck_tile::st
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
MXPipelineProblem::TransposeC>>;
MXPipelineProblem::TransposeC,
1, // kNumWaveGroups_ (Default)
false, // FixedVectorSize_ (Default)
1, // VectorSizeC_ (Default)
1, // BlockedXDLN_PerWarp_ (Default)
false, // DoubleSmemBuffer_ (Default)
ComputeDataType, // AComputeDataType
ComputeDataType, // BComputeDataType
true>>; // TilesPacked_ (because of packed scales)
using Kernel = ck_tile::MXGemmKernel<TilePartitioner, MXGemmPipeline, GemmEpilogue>;

View File

@@ -124,29 +124,42 @@ int run_mx_gemm_with_layouts(int argc, char* argv[], ALayout, BLayout, CLayout)
ck_tile::host_tensor_descriptor(M, scale_k_size, stride_scale_a, is_row_major(ALayout{})));
ck_tile::HostTensor<ScaleType> scale_b_host(
ck_tile::host_tensor_descriptor(scale_k_size, N, stride_scale_b, is_row_major(BLayout{})));
int seed = 1234;
std::mt19937 gen(42);
std::uniform_int_distribution<std::uint32_t> fill_seed(0, 500);
auto gen_scales = [&](auto& scales, float range_min, float range_max) {
// e8m0_t is basically an exponent of float32
ck_tile::HostTensor<float> pow2(scales.get_lengths());
ck_tile::FillUniformDistributionIntegerValue<float>{range_min, range_max, fill_seed(gen)}(
pow2);
scales.ForEach([&](auto& self, const auto& i) {
self(i) = static_cast<ScaleType>(std::exp2(pow2(i)));
});
};
switch(init_method)
{
case 0:
// Initialize A, B, and scales to random values
ck_tile::FillUniformDistribution<ADataType>{-2.f, 2.f, seed++}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-2.f, 2.f, seed++}(b_host);
ck_tile::FillUniformDistribution<ScaleType>{0.001f, 10.f, seed++}(scale_a_host);
ck_tile::FillUniformDistribution<ScaleType>{0.001f, 10.f, seed++}(scale_b_host);
ck_tile::FillUniformDistribution<ADataType>{-2.f, 2.f, fill_seed(gen)}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-2.f, 2.f, fill_seed(gen)}(b_host);
gen_scales(scale_a_host, -2, 2);
gen_scales(scale_b_host, -2, 2);
break;
case 1:
// Initialize A, B, and scales to 1.0
ck_tile::FillConstant<ADataType>{ADataType(1.f)}(a_host);
ck_tile::FillConstant<BDataType>{BDataType(1.f)}(b_host);
ck_tile::FillConstant<ScaleType>{ScaleType(1.f)}(scale_a_host);
ck_tile::FillConstant<ScaleType>{ScaleType(1.f)}(scale_b_host);
gen_scales(scale_a_host, 0, 0);
gen_scales(scale_b_host, 0, 0);
break;
case 2:
// Initialize A and B with random values but with constant 1.0 scales
ck_tile::FillUniformDistribution<ADataType>{-2.f, 2.f, seed++}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-2.f, 2.f, seed++}(b_host);
ck_tile::FillConstant<ScaleType>{ScaleType(0.1f)}(scale_a_host);
ck_tile::FillConstant<ScaleType>{ScaleType(0.1f)}(scale_b_host);
ck_tile::FillUniformDistribution<ADataType>{-2.f, 2.f, fill_seed(gen)}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-2.f, 2.f, fill_seed(gen)}(b_host);
gen_scales(scale_a_host, 0, 0);
gen_scales(scale_b_host, 0, 0);
break;
}
@@ -248,6 +261,7 @@ int run_mx_gemm_with_layouts(int argc, char* argv[], ALayout, BLayout, CLayout)
return pass ? 0 : -1;
}
template <typename GemmConfig>
int run_mx_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
@@ -268,7 +282,7 @@ int run_mx_gemm_example(int argc, char* argv[])
return run_mx_gemm_with_layouts<ck_tile::pk_fp4_t,
ck_tile::pk_fp4_t,
float,
MXfp4_GemmConfig16,
GemmConfig,
true>(argc, argv, Row{}, Col{}, Row{});
}
else if(mx_prec == "fp8" || mx_prec == "fp8xfp8")
@@ -276,7 +290,7 @@ int run_mx_gemm_example(int argc, char* argv[])
return run_mx_gemm_with_layouts<ck_tile::fp8_t,
ck_tile::fp8_t,
float,
MXfp8_GemmConfig16,
GemmConfig,
true>(argc, argv, Row{}, Col{}, Row{});
}
else

View File

@@ -2793,13 +2793,15 @@ template <typename T,
CK_TILE_DEVICE void amd_async_buffer_load_with_oob(CK_TILE_LDS_ADDR T* smem,
const __amdgpu_buffer_rsrc_t rsrc,
index_t src_thread_element_offset,
index_t src_wave_addr_offset,
index_t src_wave_element_offset,
linear_offset_t,
bool is_valid_element,
bool_constant<oob_conditional_check> = {})
{
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
constexpr index_t src_linear_addr_offset = static_cast<index_t>(linear_offset_t{}) * sizeof(T);
constexpr index_t PackedSize = numeric_traits<T>::PackedSize;
index_t src_wave_addr_offset = src_wave_element_offset * sizeof(T) / PackedSize;
amd_async_buffer_load<T, N, coherence>(smem,
rsrc,

View File

@@ -36,7 +36,8 @@ template <typename AsDataType_,
index_t BlockedXDLN_PerWarp_ = 1, // The number of continuous xdl_output per warp
bool DoubleSmemBuffer_ = false,
typename AComputeDataType_ = void,
typename BComputeDataType_ = void>
typename BComputeDataType_ = void,
bool TilesPacked_ = false>
struct CShuffleEpilogueProblem
{
using AsDataType = remove_cvref_t<AsDataType_>;
@@ -64,7 +65,7 @@ struct CShuffleEpilogueProblem
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
static constexpr index_t NumDTensor = DsDataType::size();
static constexpr bool TilesPacked = TilesPacked_;
static_assert(NumDTensor == DsLayout::size(),
"The size of DsDataType and DsLayout should be the same");
};
@@ -140,15 +141,19 @@ struct CShuffleEpilogue
static constexpr bool EightWave = false;
#endif
// If the wave tiles computed by a single wave are packed
// This implies that in the block gemm MRepeat and NRepeat are contiguous
static constexpr bool TilesPacked = Problem::TilesPacked;
static constexpr index_t BlockedXDLN_PerWarp =
EightWave ? kNPerBlock / NWave / NPerXdl : Problem::BlockedXDLN_PerWarp;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr index_t MPerIteration = MPerXdl * MWave;
static constexpr index_t NPerIteration = NPerXdl * NWave;
static constexpr index_t NumDTensor = Problem::NumDTensor;
static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave);
static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave);
(EightWave || TilesPacked) ? kNPerBlock / NWave / NPerXdl : Problem::BlockedXDLN_PerWarp;
static constexpr index_t BlockedXDLM_PerWarp = (TilesPacked) ? kMPerBlock / MWave / MPerXdl : 1;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr index_t MPerIteration = MPerXdl * MWave;
static constexpr index_t NPerIteration = NPerXdl * NWave;
static constexpr index_t NumDTensor = Problem::NumDTensor;
static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave);
static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave);
CDElementwise elfunc_;
@@ -288,7 +293,8 @@ struct CShuffleEpilogue
}
}
}();
static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple);
static constexpr index_t NumMXdlPerWavePerShuffle =
max(BlockedXDLM_PerWarp, std::get<0>(shuffle_tile_tuple));
static constexpr index_t NumNXdlPerWavePerShuffle =
max(BlockedXDLN_PerWarp, std::get<1>(shuffle_tile_tuple));
@@ -447,64 +453,96 @@ struct CShuffleEpilogue
CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode()
{
constexpr auto block_outer_dstr_encoding = [] {
if constexpr(BlockedXDLN_PerWarp == 1)
if constexpr(TilesPacked)
{
return tile_distribution_encoding<sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<NumNXdlPerWavePerShuffle, NWave>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
if constexpr(EightWave)
{
constexpr int RakedXDLN_PerWarp =
NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp;
return tile_distribution_encoding<
sequence<>,
tuple<sequence<MWave, NumMXdlPerWavePerShuffle>,
sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<1, 2, 2>,
sequence<1, 0, 2>>{};
}
else
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<MWave, NumMXdlPerWavePerShuffle>,
sequence<NWave, NumNXdlPerWavePerShuffle>>,
tuple<sequence<1, 2>>,
tuple<sequence<0, 0>>,
sequence<1, 2>,
sequence<1, 1>>{};
}
}
else
{
#if defined(__gfx950__) || defined(__gfx12__)
constexpr auto UseBlockedLayout = true;
#else
constexpr auto UseBlockedLayout = false;
#endif
constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp;
// BlockedLayout
// this branch is for original a16w4
if constexpr(UseBlockedLayout ||
is_any_of<ADataTypeBuf, pk_int4_t, pk_fp4_t>::value ||
is_any_of<BDataTypeBuf, pk_int4_t, pk_fp4_t>::value)
if constexpr(BlockedXDLN_PerWarp == 1)
{
if constexpr(EightWave)
return tile_distribution_encoding<
sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<NumNXdlPerWavePerShuffle, NWave>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
}
else
{
#if defined(__gfx950__) || defined(__gfx12__)
constexpr auto UseBlockedLayout = true;
#else
constexpr auto UseBlockedLayout = false;
#endif
constexpr int RakedXDLN_PerWarp =
NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp;
// BlockedLayout
// this branch is for original a16w4
if constexpr(UseBlockedLayout ||
is_any_of<ADataTypeBuf, pk_int4_t, pk_fp4_t>::value ||
is_any_of<BDataTypeBuf, pk_int4_t, pk_fp4_t>::value)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 1>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{};
if constexpr(EightWave)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 1>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{};
}
else
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{};
}
}
else
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
sequence<RakedXDLN_PerWarp, BlockedXDLN_PerWarp, NWave>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{};
sequence<0, 0, 1>>{};
}
}
else
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<RakedXDLN_PerWarp, BlockedXDLN_PerWarp, NWave>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 2>>,
sequence<1, 2, 2>,
sequence<0, 0, 1>>{};
}
}
}();
constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(

View File

@@ -388,150 +388,6 @@ struct BlockGemmARegBRegCRegV1
});
}
// C += A * B with MX scaling and packed-in-two (XdlPack) optimization
// Scale tensors contain pre-packed int32_t: each int32_t holds MXdlPack * KXdlPack e8m0_t
// values (for A) or NXdlPack * KXdlPack (for B), packed on the host.
// Uses OpSel (0-3) to select which byte within the packed int32_t for each MFMA call.
// XdlPack template parameters default to 2; fall back to 1 when iteration count is too small.
template <typename CBlockTensor,
typename ABlockTensor,
typename BBlockTensor,
typename ScaleATensor,
typename ScaleBTensor,
index_t MXdlPack_ = 2,
index_t NXdlPack_ = 2,
index_t KXdlPack_ = 2>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ABlockTensor& a_block_tensor,
const BBlockTensor& b_block_tensor,
const ScaleATensor& scale_a_tensor,
const ScaleBTensor& scale_b_tensor) const
{
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
// check ABC-block-distribution
static_assert(
std::is_same_v<remove_cvref_t<decltype(MakeABlockDistributionEncode())>,
remove_cvref_t<decltype(ABlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"A distribution is wrong!");
static_assert(
std::is_same_v<remove_cvref_t<decltype(MakeBBlockDistributionEncode())>,
remove_cvref_t<decltype(BBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"B distribution is wrong!");
static_assert(
std::is_same_v<remove_cvref_t<decltype(MakeCBlockDistributionEncode())>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"C distribution is wrong!");
using AWarpDstr = typename WarpGemm::AWarpDstr;
using BWarpDstr = typename WarpGemm::BWarpDstr;
using CWarpDstr = typename WarpGemm::CWarpDstr;
using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto b_warp_y_lengths =
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// Effective XdlPack: fall back to 1 when iteration count is insufficient
constexpr index_t MXdlPack =
(MIterPerWarp >= MXdlPack_ && MIterPerWarp % MXdlPack_ == 0) ? MXdlPack_ : 1;
constexpr index_t NXdlPack =
(NIterPerWarp >= NXdlPack_ && NIterPerWarp % NXdlPack_ == 0) ? NXdlPack_ : 1;
constexpr index_t KXdlPack =
(KIterPerWarp >= KXdlPack_ && KIterPerWarp % KXdlPack_ == 0) ? KXdlPack_ : 1;
constexpr index_t MPackIterPerWarp = MIterPerWarp / MXdlPack;
constexpr index_t NPackIterPerWarp = NIterPerWarp / NXdlPack;
constexpr index_t KPackIterPerWarp = KIterPerWarp / KXdlPack;
// hot loop with MX scaling and pre-packed int32_t scales:
// Outer loops iterate over pack groups (scale tile indices)
static_ford<sequence<KPackIterPerWarp, MPackIterPerWarp>>{}([&](auto ii) {
constexpr auto ikpack = number<ii[number<0>{}]>{};
constexpr auto impack = number<ii[number<1>{}]>{};
// Get pre-packed int32_t A scale (already contains MXdlPack*KXdlPack e8m0_t)
auto scale_a_slice = scale_a_tensor.get_y_sliced_thread_data(
sequence<ikpack, impack, 0>{}, sequence<1, 1, 1>{});
const int32_t a_scale_packed = bit_cast<int32_t>(scale_a_slice[number<0>{}]);
static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) {
// Get pre-packed int32_t B scale
auto scale_b_slice = scale_b_tensor.get_y_sliced_thread_data(
sequence<ikpack, inpack, 0>{}, sequence<1, 1, 1>{});
const int32_t b_scale_packed = bit_cast<int32_t>(scale_b_slice[number<0>{}]);
// Inner loops: issue MFMAs within the pack group using OpSel
static_ford<sequence<KXdlPack, MXdlPack>>{}([&](auto jj) {
constexpr auto ikxdl = number<jj[number<0>{}]>{};
constexpr auto imxdl = number<jj[number<1>{}]>{};
constexpr auto kIter = ikpack * KXdlPack + ikxdl;
constexpr auto mIter = impack * MXdlPack + imxdl;
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
// OpSel for A: selects byte within packed int32_t
constexpr index_t kOpSelA = ikxdl * MXdlPack + imxdl;
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
constexpr auto nIter = inpack * NXdlPack + inxdl;
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// OpSel for B: selects byte within packed int32_t
constexpr index_t kOpSelB = ikxdl * NXdlPack + inxdl;
// read C warp tensor from C block tensor
using c_iter_idx = std::conditional_t<TransposeC,
sequence<nIter, mIter>,
sequence<mIter, nIter>>;
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM with MX scaling using pre-packed scale and OpSel
WarpGemm{}.template operator()<OpSelA<kOpSelA>, OpSelB<kOpSelB>>(
c_warp_tensor,
a_warp_tensor,
b_warp_tensor,
a_scale_packed,
b_scale_packed);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
});
}
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
using c_distr_ys_major = std::conditional_t<TransposeC, sequence<2, 1>, sequence<1, 2>>;

View File

@@ -118,12 +118,7 @@ struct GemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrComp
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
// We are not storing the original packed type in LDS, so we need to multiply the smem size
// by the packed size.
constexpr index_t smem_size_a = Policy::template GetSmemSizeA<Problem>() * APackedSize;
constexpr index_t smem_size_b = Policy::template GetSmemSizeB<Problem>() * BPackedSize;
return 2 * (smem_size_a + smem_size_b);
return Policy::template GetSmemSize<Problem>();
}
static constexpr index_t MFMA_INST = MIterPerWarp * NIterPerWarp * KIterPerWarp;

View File

@@ -17,10 +17,9 @@ namespace detail {
template <typename Problem>
struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
{
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr auto WGAccessDouble = WGAttrNumAccessEnum::Double;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
@@ -29,14 +28,21 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using AComputeDataType = remove_cvref_t<typename Problem::AComputeDataType>;
using BComputeDataType = remove_cvref_t<typename Problem::BComputeDataType>;
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>, "Wrong!");
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::ColumnMajor>, "Wrong!");
static_assert(std::is_same_v<AComputeDataType, fp8_t> ||
std::is_same_v<AComputeDataType, bf8_t>);
static_assert(std::is_same_v<BComputeDataType, fp8_t> ||
std::is_same_v<BComputeDataType, bf8_t>);
using ComputeDataType = AComputeDataType;
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>,
"ALayout must be RowMajor!");
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::ColumnMajor>,
"BLayout must be ColumnMajor!");
static_assert(is_any_of<AComputeDataType, fp8_t, bf8_t, pk_fp4_t>::value);
static_assert(is_any_of<BComputeDataType, fp8_t, bf8_t, pk_fp4_t>::value);
static_assert(std::is_same_v<AComputeDataType, BComputeDataType>);
static_assert(std::is_same_v<CDataType, float>);
static constexpr auto WGAccess = std::is_same_v<ComputeDataType, fp8_t>
? WGAttrNumAccessEnum::Double
: WGAttrNumAccessEnum::Single;
static constexpr auto PackedSize = numeric_traits<ComputeDataType>::PackedSize;
using BlockGemmShape = typename Problem::BlockGemmShape;
using BlockWarps = typename BlockGemmShape::BlockWarps;
using WarpTile = typename BlockGemmShape::WarpTile;
@@ -88,7 +94,7 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
static constexpr index_t NIterPerWarp = NWarpTiles / NWarps;
static constexpr index_t KPerWarp = KPerBlock / KWarps;
static constexpr index_t NPerWarp = NPerBlock / NWarps;
static_assert(NWarps == 2, "KWarps == 2 for ping-pong!");
static_assert(NWarps == 2, "NWarps == 2 for ping-pong!");
static_assert(KWarpTiles == KWarps, "Wrong!");
static constexpr index_t warp_size = get_warp_size();
@@ -98,8 +104,8 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
static_assert(sizeof(ADataType) == sizeof(BDataType), "Wrong!");
static constexpr index_t ElementSize = sizeof(ADataType);
static constexpr index_t K2 = Problem::VectorLoadSize / ElementSize; // 16
static constexpr index_t K1 = WarpTile::at(I2) / K2; // 8
static constexpr index_t K2 = Problem::VectorLoadSize / ElementSize * PackedSize; // 16
static constexpr index_t K1 = WarpTile::at(I2) / K2; // 8
static constexpr index_t K0 = KPerWarp / (K1 * K2);
static_assert(K0 * K1 * K2 == KPerWarp, "Wrong!");
static_assert(K0 == 1, "Wrong!");
@@ -176,7 +182,7 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
const index_t k_tiles = cols / (KWarps * K1 * K2);
const auto col_lens = make_tuple(k_tiles, number<KWarps>{}, number<K1>{}, number<K2>{});
constexpr index_t M1 = warp_size / static_cast<index_t>(WGAccessDouble) / K1; // 4
constexpr index_t M1 = warp_size / static_cast<index_t>(WGAccess) / K1; // 4
const index_t M0 = integer_divide_ceil(rows, M1);
const auto row_lens = make_tuple(M0, number<M1>{});
@@ -227,9 +233,9 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
template <index_t MNPerBlock, index_t warp_groups_>
CK_TILE_DEVICE static constexpr auto MakeABLdsBlockDescriptor_()
{
constexpr index_t M4 = warp_size / static_cast<index_t>(WGAccessDouble) / K1; // 4
constexpr index_t M3 = static_cast<index_t>(WGAccessDouble); // 2
constexpr index_t M2 = WarpTileM / M4 / M3; // 2
constexpr index_t M4 = warp_size / static_cast<index_t>(WGAccess) / K1; // 4
constexpr index_t M3 = static_cast<index_t>(WGAccess); // 2
constexpr index_t M2 = WarpTileM / M4 / M3; // 2
constexpr index_t M1 = (warp_num / warp_groups_) / M2;
constexpr index_t M0 = MNPerBlock / M1 / M2 / M3 / M4;
@@ -337,12 +343,14 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
CK_TILE_DEVICE static constexpr index_t GetSmemSizeA()
{
constexpr index_t desc_size = MakeALdsBlockDescriptor().get_element_space_size();
return integer_least_multiple(sizeof(typename Problem::ADataType) * desc_size, 16);
return integer_least_multiple(sizeof(typename Problem::ADataType) * desc_size / PackedSize,
16);
}
CK_TILE_DEVICE static constexpr index_t GetSmemSizeB()
{
constexpr index_t desc_size = MakeBLdsBlockDescriptor().get_element_space_size();
return integer_least_multiple(sizeof(typename Problem::BDataType) * desc_size, 16);
return integer_least_multiple(sizeof(typename Problem::BDataType) * desc_size / PackedSize,
16);
}
CK_TILE_DEVICE static constexpr index_t GetSmemSize()
@@ -361,7 +369,7 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
// TODO: Fix for transpose
constexpr auto wg_attr_num_access = WGAttrNumAccessEnum::Double;
constexpr auto wg_attr_num_access = WGAccess;
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
typename Problem::BDataType,

View File

@@ -199,6 +199,22 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase<
return 0;
}
template <typename AQDramBlockWindowTmp,
typename std::enable_if_t<std::is_same_v<AQDramBlockWindowTmp, NullTileWindowType>,
bool>* = nullptr>
CK_TILE_DEVICE static constexpr auto GetInstCountAQ(const AQDramBlockWindowTmp&)
{
return 0;
}
template <typename BQDramBlockWindowTmp,
typename std::enable_if_t<std::is_same_v<BQDramBlockWindowTmp, NullTileWindowType>,
bool>* = nullptr>
CK_TILE_DEVICE static constexpr auto GetInstCountBQ(const BQDramBlockWindowTmp&)
{
return 0;
}
// A/B Quant
template <typename AQDramBlockWindowTmp,
typename std::enable_if_t<!std::is_same_v<AQDramBlockWindowTmp, NullTileWindowType>,
@@ -234,6 +250,22 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase<
return Policy::template GetKStepBQ<Problem>();
}
template <typename AQDramBlockWindowTmp,
typename std::enable_if_t<!std::is_same_v<AQDramBlockWindowTmp, NullTileWindowType>,
bool>* = nullptr>
CK_TILE_DEVICE static constexpr auto GetInstCountAQ(const AQDramBlockWindowTmp&)
{
return Policy::template GetInstCountAQ<Problem>();
}
template <typename BQDramBlockWindowTmp,
typename std::enable_if_t<!std::is_same_v<BQDramBlockWindowTmp, NullTileWindowType>,
bool>* = nullptr>
CK_TILE_DEVICE static constexpr auto GetInstCountBQ(const BQDramBlockWindowTmp&)
{
return Policy::template GetInstCountBQ<Problem>();
}
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
@@ -258,14 +290,6 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase<
: 0;
static_assert(N_LOOP >= 1, "wrong!");
// Instructions Count
constexpr index_t VectorSizeB = Policy::template GetVectorSizeB<Problem>();
constexpr index_t B_LOAD_INST = NPerBlock * KPerBlock / BlockSize / VectorSizeB;
constexpr index_t AQ_LOAD_INST =
std::is_same_v<AQDramBlockWindowTmp, NullTileWindowType> ? 0 : MIterPerWarp;
constexpr index_t BQ_LOAD_INST =
std::is_same_v<BQDramBlockWindowTmp, NullTileWindowType> ? 0 : 1;
// -----
// Setup
// -----
@@ -314,6 +338,12 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase<
constexpr AQDramTileWindowStep aq_move_step = {0, GetKStepAQ(aq_copy_dram_window)};
constexpr BQDramTileWindowStep bq_move_step = {0, GetKStepBQ(bq_copy_dram_window)};
// Instructions Count
constexpr index_t VectorSizeB = Policy::template GetVectorSizeB<Problem>();
constexpr index_t B_LOAD_INST = NPerBlock * KPerBlock / BlockSize / VectorSizeB;
constexpr index_t AQ_LOAD_INST = GetInstCountAQ(aq_copy_dram_window);
constexpr index_t BQ_LOAD_INST = GetInstCountBQ(bq_copy_dram_window);
// -------
// Lambdas
// -------

View File

@@ -2,10 +2,14 @@
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_eight_waves_v1.hpp"
#include "ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp"
#include "ck_tile/ops/gemm_mx/kernel/scale_pointer.hpp"
#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp"
#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp"
#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp"
#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"

View File

@@ -0,0 +1,310 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
namespace ck_tile {
// A is block distributed tensor
// B is block distributed tensor
// C is block distributed tensor
template <typename Problem_, typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy>
struct BlockMXGemmARegBRegCRegEightWavesV1
{
private:
template <typename PipelineProblem_, typename GemmPolicy_>
struct GemmTraits_
{
using Problem = remove_cvref_t<PipelineProblem_>;
using Policy = remove_cvref_t<GemmPolicy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using AComputeDataType = remove_cvref_t<typename Problem::AComputeDataType>;
using BComputeDataType = remove_cvref_t<typename Problem::BComputeDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
static constexpr index_t MWarp = config.template at<1>();
static constexpr index_t NWarp = config.template at<2>();
static constexpr index_t KWarp = Problem::BlockGemmShape::BlockWarps::at(number<2>{});
using I0 = number<0>;
using I1 = number<1>;
static_assert(MWarp == BlockGemmShape::BlockWarps::at(I0{}),
"Error! WarpGemm's MWarp is not consistent with BlockGemmShape!");
static_assert(NWarp == BlockGemmShape::BlockWarps::at(I1{}),
"Error! WarpGemm's NWarp is not consistent with BlockGemmShape!");
static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(I0{}),
"Error! WarpGemm's M is not consistent with BlockGemmShape!");
static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(I1{}),
"Error! WarpGemm's N is not consistent with BlockGemmShape!");
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
static constexpr index_t KIterPerWarp = KPerBlock / (KWarp * WarpGemm::kK);
// Controls how many MAC clusters (MFMA blocks) we have per wave
// If InterWaveSchedulingMacClusters = 1;
// Then we group all WarpGemms into single MAC cluster.
// But if InterWaveSchedulingMacClusters = 2, then we
// split the warp gemms into two groups.
static constexpr index_t InterWaveSchedulingMacClusters = 1;
static constexpr index_t KPackA = WarpGemm::kAKPack;
static constexpr index_t KPackB = WarpGemm::kBKPack;
static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
static constexpr bool TransposeC = Problem::TransposeC;
};
public:
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using Traits = GemmTraits_<Problem, Policy>;
using WarpGemm = typename Traits::WarpGemm;
using BlockGemmShape = typename Traits::BlockGemmShape;
using ADataType = remove_cvref_t<typename Traits::ADataType>;
using BDataType = remove_cvref_t<typename Traits::BDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;
using AComputeDataType = remove_cvref_t<typename Traits::AComputeDataType>;
using BComputeDataType = remove_cvref_t<typename Traits::BComputeDataType>;
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;
static constexpr index_t MWarp = Traits::MWarp;
static constexpr index_t NWarp = Traits::NWarp;
static constexpr index_t KWarp = Traits::KWarp;
static constexpr auto Scheduler = Traits::Scheduler;
static constexpr bool TransposeC = Traits::TransposeC;
using AWarpDstr = typename WarpGemm::AWarpDstr;
using BWarpDstr = typename WarpGemm::BWarpDstr;
using CWarpDstr = typename WarpGemm::CWarpDstr;
using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
static_assert(std::is_same_v<typename WarpGemm::CDataType, float>);
static constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
static constexpr auto b_warp_y_lengths =
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
static constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
using I0 = number<0>;
using I1 = number<1>;
// Note: distribution encodings have MIterPerWarp and NIterPerWarp contiguous because of scale
// packing.
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
{
constexpr index_t KPerThread = Traits::KPerThread;
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
constexpr index_t KPerInnerLoop =
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
sequence<KWarp, KIterInterwave>,
sequence<KWarp, KIterPerWarp>>;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<2, NWarp / 2>,
tuple<sequence<MWarp, MIterPerWarp>, KIterSeq>,
tuple<sequence<0, 2, 1, 0>>,
tuple<sequence<0, 0, 0, 1>>,
sequence<1, 2>,
sequence<1, 1>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
return a_block_dstr_encode;
}
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
{
constexpr index_t KPerThread = Traits::KPerThread;
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
constexpr index_t KPerInnerLoop =
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
sequence<KWarp, KIterInterwave>,
sequence<KWarp, KIterPerWarp>>;
constexpr auto b_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<2, NIterPerWarp, NWarp / 2>, KIterSeq>,
tuple<sequence<2, 1, 0, 1>>,
tuple<sequence<0, 0, 0, 2>>,
sequence<>,
sequence<>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
return b_block_dstr_encode;
}
CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode()
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<KWarp>,
tuple<sequence<MWarp, MIterPerWarp>, sequence<2, NIterPerWarp, NWarp / 2>>,
tuple<sequence<2, 0, 1, 2>>,
tuple<sequence<0, 0, 0, 2>>,
sequence<1, 2>,
sequence<1, 1>>{};
constexpr auto c_block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
return c_block_dstr_encoding;
}
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
return make_static_distributed_tensor<CDataType>(
make_static_tile_distribution(MakeCBlockDistributionEncode()));
}
using ALdsTile = decltype(make_static_distributed_tensor<AComputeDataType>(
make_static_tile_distribution(MakeABlockDistributionEncode())));
using BLdsTiles = statically_indexed_array<
statically_indexed_array<decltype(make_static_distributed_tensor<BComputeDataType>(
make_static_tile_distribution(
MakeBBlockDistributionEncode()))),
KIterPerWarp>,
NIterPerWarp>;
// C += A * B
template <typename CBlockTensor,
typename ScaleATensor,
typename ScaleBTensor,
index_t MXdlPack_ = 2,
index_t NXdlPack_ = 2,
index_t KXdlPack_ = 2>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ALdsTile& a_warp_tile_,
const BLdsTiles& b_warp_tiles_,
const ScaleATensor& scale_a_tensor,
const ScaleBTensor& scale_b_tensor) const
{
// checks
static_assert(std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"CDataType must be same as CBlockTensor::DataType!");
static_assert(
std::is_same_v<remove_cvref_t<decltype(MakeCBlockDistributionEncode())>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"C distribution is wrong!");
// Effective XdlPack: fall back to 1 when iteration count is insufficient
constexpr index_t MXdlPack =
(MIterPerWarp >= MXdlPack_ && MIterPerWarp % MXdlPack_ == 0) ? MXdlPack_ : 1;
constexpr index_t NXdlPack =
(NIterPerWarp >= NXdlPack_ && NIterPerWarp % NXdlPack_ == 0) ? NXdlPack_ : 1;
constexpr index_t KXdlPack =
(KIterPerWarp >= KXdlPack_ && KIterPerWarp % KXdlPack_ == 0) ? KXdlPack_ : 1;
constexpr index_t MPackIterPerWarp = MIterPerWarp / MXdlPack;
constexpr index_t NPackIterPerWarp = NIterPerWarp / NXdlPack;
constexpr index_t KPackIterPerWarp = KIterPerWarp / KXdlPack;
// hot loop:
static_for_product<number<KPackIterPerWarp>,
number<NPackIterPerWarp>,
number<MPackIterPerWarp>>{}([&](auto ikpack, auto inpack, auto impack) {
// get A scale for this M-K tile using get_y_sliced_thread_data
auto scale_a_slice = scale_a_tensor.get_y_sliced_thread_data(
sequence<ikpack, impack, 0>{}, sequence<1, 1, 1>{});
const int32_t a_scale_packed = bit_cast<int32_t>(scale_a_slice[number<0>{}]);
// get B scale for this N-K tile using get_y_sliced_thread_data
auto scale_b_slice = scale_b_tensor.get_y_sliced_thread_data(
sequence<ikpack, inpack, 0>{}, sequence<1, 1, 1>{});
const int32_t b_scale_packed = bit_cast<int32_t>(scale_b_slice[number<0>{}]);
// Inner loops: issue MFMAs within the pack group using OpSel
static_for_product<number<KXdlPack>, number<NXdlPack>, number<MXdlPack>>{}(
[&](auto ikxdl, auto inxdl, auto imxdl) {
constexpr auto kIter = ikpack * KXdlPack + ikxdl;
constexpr auto mIter = impack * MXdlPack + imxdl;
constexpr auto nIter = inpack * NXdlPack + inxdl;
// OpSel for A: selects byte within packed int32_t
constexpr index_t kOpSelA = ikxdl * MXdlPack + imxdl;
// OpSel for B: selects byte within packed int32_t
constexpr index_t kOpSelB = ikxdl * NXdlPack + inxdl;
// read A warp tensor from A Block window
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() =
b_warp_tiles_[number<nIter>{}][number<kIter>{}].get_thread_buffer();
// read C warp tensor from C block tensor
using c_iter_idx = sequence<mIter, nIter>;
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM with MX scaling
WarpGemm{}.template operator()<OpSelA<kOpSelA>, OpSelB<kOpSelB>>(
c_warp_tensor,
a_warp_tensor,
b_warp_tensor,
a_scale_packed,
b_scale_packed);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,324 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
namespace ck_tile {
// A is block distributed tensor
// B is block distributed tensor
// C is block distributed tensor
template <typename Problem_,
typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy,
bool TransposeC_ = false>
struct BlockMXGemmARegBRegCRegV1
{
private:
template <typename PipelineProblem_, typename GemmPolicy_>
struct GemmTraits_
{
using Problem = remove_cvref_t<PipelineProblem_>;
using Policy = remove_cvref_t<GemmPolicy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
static constexpr index_t MWarp = config.template at<1>();
static constexpr index_t NWarp = config.template at<2>();
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
static constexpr index_t KPackA = WarpGemm::kAKPack;
static constexpr index_t KPackB = WarpGemm::kBKPack;
};
public:
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
static constexpr bool TransposeC = TransposeC_;
using Traits = GemmTraits_<Problem, Policy>;
using WarpGemm = typename Traits::WarpGemm;
using BlockGemmShape = typename Traits::BlockGemmShape;
using ADataType = remove_cvref_t<typename Traits::ADataType>;
using BDataType = remove_cvref_t<typename Traits::BDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;
static constexpr index_t MWarp = Traits::MWarp;
static constexpr index_t NWarp = Traits::NWarp;
static constexpr bool UseDefaultScheduler = (Problem::NumWaveGroups != 1);
// Note: distribution encodings have MIterPerWarp and NIterPerWarp contiguous because of scale
// packing.
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
{
if constexpr(UseDefaultScheduler)
{
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp>, sequence<KIterPerWarp>>,
tuple<>,
tuple<>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
return a_block_dstr_encode;
}
else
{
constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<NWarp>,
tuple<sequence<MWarp, MIterPerWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<0, 0>>,
sequence<1, 2>,
sequence<1, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
return a_block_dstr_encode;
}
}
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
{
if constexpr(UseDefaultScheduler)
{
constexpr auto b_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp>, sequence<KIterPerWarp>>,
tuple<>,
tuple<>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
return b_block_dstr_encode;
}
else
{
constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<MWarp>,
tuple<sequence<NWarp, NIterPerWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 0>>,
sequence<1, 2>,
sequence<1, 0>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
return b_block_dstr_encode;
}
}
CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode()
{
using c_distr_ys_major = std::conditional_t<TransposeC, sequence<2, 1>, sequence<1, 2>>;
if constexpr(UseDefaultScheduler)
{
using c_distr_ys_minor = std::conditional_t<TransposeC, sequence<1, 0>, sequence<0, 1>>;
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<MWarp>,
tuple<sequence<MIterPerWarp>, sequence<NWarp, NIterPerWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<0, 0>>,
c_distr_ys_major,
c_distr_ys_minor>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
return c_block_dstr_encode;
}
else
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MWarp, MIterPerWarp>, sequence<NWarp, NIterPerWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<0, 0>>,
c_distr_ys_major,
sequence<1, 1>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
return c_block_dstr_encode;
}
}
// C += A * B with MX scaling and packed-in-two (XdlPack) optimization
// Scale tensors contain pre-packed int32_t: each int32_t holds MXdlPack * KXdlPack e8m0_t
// values (for A) or NXdlPack * KXdlPack (for B), packed on the host.
// Uses OpSel (0-3) to select which byte within the packed int32_t for each MFMA call.
// XdlPack template parameters default to 2; fall back to 1 when iteration count is too small.
template <typename CBlockTensor,
typename ABlockTensor,
typename BBlockTensor,
typename ScaleATensor,
typename ScaleBTensor,
index_t MXdlPack_ = 2,
index_t NXdlPack_ = 2,
index_t KXdlPack_ = 2>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ABlockTensor& a_block_tensor,
const BBlockTensor& b_block_tensor,
const ScaleATensor& scale_a_tensor,
const ScaleBTensor& scale_b_tensor) const
{
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"Datatypes do not match BlockTensor datatypes!");
// check ABC-block-distribution
static_assert(
std::is_same_v<remove_cvref_t<decltype(MakeABlockDistributionEncode())>,
remove_cvref_t<decltype(ABlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"A distribution is wrong!");
static_assert(
std::is_same_v<remove_cvref_t<decltype(MakeBBlockDistributionEncode())>,
remove_cvref_t<decltype(BBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"B distribution is wrong!");
static_assert(
std::is_same_v<remove_cvref_t<decltype(MakeCBlockDistributionEncode())>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"C distribution is wrong!");
using AWarpDstr = typename WarpGemm::AWarpDstr;
using BWarpDstr = typename WarpGemm::BWarpDstr;
using CWarpDstr = typename WarpGemm::CWarpDstr;
using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto b_warp_y_lengths =
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// Effective XdlPack: fall back to 1 when iteration count is insufficient
constexpr index_t MXdlPack =
(MIterPerWarp >= MXdlPack_ && MIterPerWarp % MXdlPack_ == 0) ? MXdlPack_ : 1;
constexpr index_t NXdlPack =
(NIterPerWarp >= NXdlPack_ && NIterPerWarp % NXdlPack_ == 0) ? NXdlPack_ : 1;
constexpr index_t KXdlPack =
(KIterPerWarp >= KXdlPack_ && KIterPerWarp % KXdlPack_ == 0) ? KXdlPack_ : 1;
constexpr index_t MPackIterPerWarp = MIterPerWarp / MXdlPack;
constexpr index_t NPackIterPerWarp = NIterPerWarp / NXdlPack;
constexpr index_t KPackIterPerWarp = KIterPerWarp / KXdlPack;
// hot loop with MX scaling and pre-packed int32_t scales:
// Outer loops iterate over pack groups (scale tile indices)
static_ford<sequence<KPackIterPerWarp, MPackIterPerWarp>>{}([&](auto ii) {
constexpr auto ikpack = number<ii[number<0>{}]>{};
constexpr auto impack = number<ii[number<1>{}]>{};
// Get pre-packed int32_t A scale (already contains MXdlPack*KXdlPack e8m0_t)
auto scale_a_slice = scale_a_tensor.get_y_sliced_thread_data(
sequence<ikpack, impack, 0>{}, sequence<1, 1, 1>{});
const int32_t a_scale_packed = bit_cast<int32_t>(scale_a_slice[number<0>{}]);
static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) {
// Get pre-packed int32_t B scale
auto scale_b_slice = scale_b_tensor.get_y_sliced_thread_data(
sequence<ikpack, inpack, 0>{}, sequence<1, 1, 1>{});
const int32_t b_scale_packed = bit_cast<int32_t>(scale_b_slice[number<0>{}]);
// Inner loops: issue MFMAs within the pack group using OpSel
static_ford<sequence<KXdlPack, MXdlPack>>{}([&](auto jj) {
constexpr auto ikxdl = number<jj[number<0>{}]>{};
constexpr auto imxdl = number<jj[number<1>{}]>{};
constexpr auto kIter = ikpack * KXdlPack + ikxdl;
constexpr auto mIter = impack * MXdlPack + imxdl;
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
// OpSel for A: selects byte within packed int32_t
constexpr index_t kOpSelA = ikxdl * MXdlPack + imxdl;
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
constexpr auto nIter = inpack * NXdlPack + inxdl;
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// OpSel for B: selects byte within packed int32_t
constexpr index_t kOpSelB = ikxdl * NXdlPack + inxdl;
// read C warp tensor from C block tensor
using c_iter_idx = std::conditional_t<TransposeC,
sequence<nIter, mIter>,
sequence<mIter, nIter>>;
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM with MX scaling using pre-packed scale and OpSel
WarpGemm{}.template operator()<OpSelA<kOpSelA>, OpSelB<kOpSelB>>(
c_warp_tensor,
a_warp_tensor,
b_warp_tensor,
a_scale_packed,
b_scale_packed);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
});
}
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
return make_static_distributed_tensor<CDataType>(
make_static_tile_distribution(MakeCBlockDistributionEncode()));
}
};
} // namespace ck_tile

View File

@@ -332,8 +332,7 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
const std::array<const BDataType*, NumBTensor>& bs_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
void* smem_ptr_ping,
void* smem_ptr_pong,
void* smem_ptr,
const KernelArgs<ScaleM, ScaleN>& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t i_m,
@@ -363,24 +362,18 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
scale_a_block_window,
scale_b_block_window,
num_loop,
smem_ptr_ping,
smem_ptr_pong);
smem_ptr);
// Run Epilogue Pipeline - create C block window directly
auto c_block_window = MakeCBlockWindows(e_ptr, kargs, i_m, i_n);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr);
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPingSize()
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return max(MXGemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPongSize()
{
return MXGemmPipeline::GetSmemSize();
}
template <class ScaleM, class ScaleN>
CK_TILE_DEVICE void operator()(KernelArgs<ScaleM, ScaleN> kargs,
int partition_idx = get_block_id()) const
@@ -389,8 +382,7 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
amd_wave_read_first_lane(TilePartitioner::GridSize(kargs.M, kargs.N));
// Allocate shared memory for ping pong buffers
__shared__ char smem_ptr_ping[GetSmemPingSize()];
__shared__ char smem_ptr_pong[GetSmemPongSize()];
__shared__ char smem_ptr[GetSmemSize()];
// Support both persistent and non-persistent modes
do
@@ -423,8 +415,7 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
bs_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr_ping,
smem_ptr_pong,
smem_ptr,
kargs,
splitk_batch_offset,
i_m,

View File

@@ -181,7 +181,8 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
constexpr index_t smem_size = Policy::template GetSmemSize<Problem>();
return 2 * smem_size;
}
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
@@ -688,9 +689,11 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
const ScaleADramBlockWindowTmp& scale_a_window,
const ScaleBDramBlockWindowTmp& scale_b_window,
index_t num_loop,
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
void* __restrict__ p_smem) const
{
constexpr index_t smem_size = Policy::template GetSmemSize<Problem>();
const auto smem = reinterpret_cast<uint8_t*>(p_smem);
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
@@ -703,8 +706,8 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
scale_a_window,
scale_b_window,
num_loop,
p_smem_0,
p_smem_1);
smem,
smem + smem_size);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
@@ -720,9 +723,11 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
const ScaleADramBlockWindowTmp& scale_a_window,
const ScaleBDramBlockWindowTmp& scale_b_window,
const index_t num_loop,
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
void* __restrict__ p_smem) const
{
constexpr index_t smem_size = Policy::template GetSmemSize<Problem>();
const auto smem = reinterpret_cast<uint8_t*>(p_smem);
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
@@ -735,8 +740,8 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
scale_a_window,
scale_b_window,
num_loop,
p_smem_0,
p_smem_1);
smem,
smem + smem_size);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);

View File

@@ -8,6 +8,7 @@
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_v1.hpp"
#include <type_traits>
namespace ck_tile {
@@ -128,7 +129,7 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
return BlockMXGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
}
// XdlPack: how many e8m0_t scale values are packed into one int32_t per dimension
@@ -170,12 +171,12 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
return make_static_tile_distribution(
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp_packed, MWarp, MPerXdl>,
tuple<sequence<MWarp, MIterPerWarp_packed, MPerXdl>,
sequence<KIterPerWarp_packed, K_Lane, KPerLane>>,
tuple<sequence<0, 1>, sequence<2, 1>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
tuple<sequence<0, 0>, sequence<1, 2>>,
sequence<2, 1, 2>,
sequence<0, 0, 2>>{});
sequence<0, 1, 2>>{});
}
template <typename Problem>
@@ -208,12 +209,12 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp_packed, NWarp, NPerXdl>,
tuple<sequence<NWarp, NIterPerWarp_packed, NPerXdl>,
sequence<KIterPerWarp_packed, K_Lane, KPerLane>>,
tuple<sequence<0, 1>, sequence<2, 1>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
tuple<sequence<0, 0>, sequence<1, 2>>,
sequence<2, 1, 2>,
sequence<0, 0, 2>>{});
sequence<0, 1, 2>>{});
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,282 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_eight_waves_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp"
namespace ck_tile {
/**
* @brief Compute optimized pipeline version async for 8 waves
*
* This pipeline introduces asynchronous load from global memory to LDS,
* skipping the intermediate loading into pipeline registers.
*/
template <typename Problem, typename Policy = MXGemmPipelineAgBgCrCompAsyncEightWavesPolicy>
struct MXGemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
using PipelineImplBase = GemmPipelineAgBgCrEightWavesImplBase<Problem, Policy>;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
static constexpr index_t APackedSize = ck_tile::numeric_traits<ADataType>::PackedSize;
static constexpr index_t BPackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
using WarpGemm = typename BlockGemm::WarpGemm;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(I0);
static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(I1);
static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(I2);
static constexpr index_t kflatKPerBlock = BlockGemmShape::flatKPerBlock;
static constexpr index_t MIterPerWarp = MPerBlock / (MWarps * WarpGemm::kM);
static constexpr index_t NIterPerWarp = NPerBlock / (NWarps * WarpGemm::kN);
static constexpr index_t KIterPerWarp = KPerBlock / (KWarps * WarpGemm::kK);
static constexpr bool Async = true;
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{
return Policy::template GetVectorSizeA<Problem>();
}
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeB()
{
return Policy::template GetVectorSizeB<Problem>();
}
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
static constexpr index_t Preshuffle = Problem::Preshuffle;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr auto Scheduler = Problem::Scheduler;
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
// clang-format off
return "COMPUTE_ASYNC_EIGHT_WAVES";
// clang-format on
}
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0);
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1);
return concat('_', "pipeline_AgBgCrCompAsyncEightWaves",
concat('x', MPerBlock, NPerBlock, KPerBlock), BlockSize,
concat('x', GetVectorSizeA(), GetVectorSizeB()),
concat('x', WaveNumM, WaveNumN),
concat('x', kPadM, kPadN, kPadK),
Problem::GetName());
// clang-format on
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
static constexpr index_t MFMA_INST = MIterPerWarp * NIterPerWarp * KIterPerWarp;
// Scales are packed so odd numbers of iterations greater than 1 are not supported
static_assert((MIterPerWarp == 1) || (MIterPerWarp % 2 == 0));
static_assert((NIterPerWarp == 1) || (NIterPerWarp % 2 == 0));
static_assert((KIterPerWarp == 1) || (KIterPerWarp % 2 == 0));
template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase
{
};
template <>
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
{
using Base = PipelineImplBase;
template <bool HasHotLoop,
TailNumber TailNum,
typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename ScaleADramBlockWindowTmp,
typename ScaleBDramBlockWindowTmp,
typename std::enable_if_t<!is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
!is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
const ScaleADramBlockWindowTmp& scale_a_window,
const ScaleBDramBlockWindowTmp& scale_b_window,
index_t num_loop,
void* __restrict__ p_smem) const
{
// TODO: A/B elementwise functions currently not supported
ignore = a_element_func;
ignore = b_element_func;
// ------
// Checks
// ------
static_assert(
std::is_same_v<ADataType,
remove_cvref_t<typename AsDramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType,
remove_cvref_t<typename BsDramBlockWindowTmp::DataType>>,
"A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!");
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>, "Wrong!");
static_assert(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>, "Wrong!");
static_assert((MPerBlock == AsDramBlockWindowTmp{}.get_window_lengths()[I0] &&
KPerBlock == AsDramBlockWindowTmp{}.get_window_lengths()[I1]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(Preshuffle //
? (NWarps == BsDramBlockWindowTmp{}.get_window_lengths()[I0] &&
kflatKPerBlock == BsDramBlockWindowTmp{}.get_window_lengths()[I1])
: (NPerBlock == BsDramBlockWindowTmp{}.get_window_lengths()[I0] &&
KPerBlock == BsDramBlockWindowTmp{}.get_window_lengths()[I1]),
"B block window has incorrect lengths for defined BLayout!");
// ------------------
// Hot loop scheduler
// ------------------
auto hot_loop_scheduler = [&]() {
__builtin_amdgcn_sched_group_barrier(0x008, MIterPerWarp, 0); // MFMA
s_waitcnt_lgkm<4>();
__builtin_amdgcn_sched_group_barrier(0x004, 1, 0); // lgkmcnt / SALU
static_for<0, MFMA_INST - MIterPerWarp, 1>{}([&](auto) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_barrier(0);
};
// -------
// Compute
// -------
return Base::template Run_<HasHotLoop, TailNum>(p_smem,
num_loop,
a_dram_block_window_tmp,
b_dram_block_window_tmp,
scale_a_window,
scale_b_window,
hot_loop_scheduler);
}
};
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename ScaleADramBlockWindowTmp,
typename ScaleBDramBlockWindowTmp,
typename std::enable_if_t<!is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
!is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
const ScaleADramBlockWindowTmp& scale_a_window,
const ScaleBDramBlockWindowTmp& scale_b_window,
index_t num_loop,
void* p_smem) const
{
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
scale_a_window,
scale_b_window,
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename ScaleADramBlockWindowTmp,
typename ScaleBDramBlockWindowTmp,
typename std::enable_if_t<!is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
!is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const ScaleADramBlockWindowTmp& scale_a_window,
const ScaleBDramBlockWindowTmp& scale_b_window,
index_t num_loop,
void* p_smem) const
{
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
identity{},
b_dram_block_window_tmp,
identity{},
scale_a_window,
scale_b_window,
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,203 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp"
#include "ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_eight_waves_v1.hpp"
namespace ck_tile {
namespace detail {
template <typename Problem>
struct MXGemmPipelineAgBgCrCompAsyncEightWavesPolicy
{
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
// MX scaling configuration: each e8m0 scale covers 32 elements in K
static constexpr int BlockScaleSize = 32;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using AComputeDataType = remove_cvref_t<typename Problem::AComputeDataType>;
using BComputeDataType = remove_cvref_t<typename Problem::BComputeDataType>;
using ComputeDataType = AComputeDataType;
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>, "Wrong!");
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::ColumnMajor>, "Wrong!");
static_assert(is_any_of<AComputeDataType, fp8_t, bf8_t, pk_fp4_t>::value);
static_assert(is_any_of<BComputeDataType, fp8_t, bf8_t, pk_fp4_t>::value);
static_assert(std::is_same_v<AComputeDataType, BComputeDataType>);
static_assert(std::is_same_v<CDataType, float>);
using BlockGemmShape = typename Problem::BlockGemmShape;
using BlockWarps = typename BlockGemmShape::BlockWarps;
using WarpTile = typename BlockGemmShape::WarpTile;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(I0);
static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(I1);
static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(I2);
static constexpr index_t WarpTileM = WarpTile::at(I0);
static constexpr index_t WarpTileN = WarpTile::at(I1);
static constexpr index_t WarpTileK = WarpTile::at(I2);
static constexpr index_t MWarpTiles = MPerBlock / WarpTileM;
static constexpr index_t NWarpTiles = NPerBlock / WarpTileN;
static constexpr index_t KWarpTiles = KPerBlock / WarpTileK;
// XdlPack: how many e8m0_t scale values are packed into one int32_t per dimension
// Host packs MXdlPack * KXdlPack e8m0_t into one int32_t for A scales
// Host packs NXdlPack * KXdlPack e8m0_t into one int32_t for B scales
static constexpr int MXdlPack = 2;
static constexpr int NXdlPack = 2;
static constexpr int KXdlPack = 2;
// Compute effective XdlPack sizes (fall back to 1 when iter count < pack)
static constexpr index_t MPerXdl = WarpTile::at(I0);
static constexpr index_t NPerXdl = WarpTile::at(I1);
static constexpr index_t KPerXdl = WarpTile::at(I2);
static constexpr index_t MIterPerWarp = MPerBlock / (MWarps * MPerXdl);
static constexpr index_t NIterPerWarp = NPerBlock / (NWarps * NPerXdl);
static constexpr index_t KIterPerWarp = KPerBlock / KPerXdl;
static constexpr index_t MXdlPackEff =
(MIterPerWarp >= MXdlPack && MIterPerWarp % MXdlPack == 0) ? MXdlPack : 1;
static constexpr index_t NXdlPackEff =
(NIterPerWarp >= NXdlPack && NIterPerWarp % NXdlPack == 0) ? NXdlPack : 1;
static constexpr index_t KXdlPackEff =
(KIterPerWarp >= KXdlPack && KIterPerWarp % KXdlPack == 0) ? KXdlPack : 1;
static constexpr index_t KPerBlockScale = KPerBlock / BlockScaleSize / KXdlPackEff;
static constexpr index_t KPerWarp = KPerBlock / KWarps;
static constexpr index_t NPerWarp = NPerBlock / NWarps;
static_assert(NWarps == 2, "NWarps == 2 for ping-pong!");
static_assert(KWarpTiles == KWarps, "Wrong!");
static constexpr index_t warp_size = get_warp_size();
static constexpr index_t warp_num = BlockSize / warp_size;
static_assert(warp_size == 64, "Wrong!");
static_assert(warp_num * warp_size == BlockSize, "Wrong!");
static_assert(sizeof(ADataType) == sizeof(BDataType), "Wrong!");
static constexpr index_t ElementSize = sizeof(ADataType);
static constexpr index_t K2 = Problem::VectorLoadSize / ElementSize; // 16
static constexpr index_t K1 = WarpTile::at(I2) / K2; // 8
static constexpr index_t K0 = KPerWarp / (K1 * K2);
static_assert(K0 * K1 * K2 == KPerWarp, "Wrong!");
static_assert(K0 == 1, "Wrong!");
CK_TILE_HOST_DEVICE static constexpr auto GetKStepAQ() { return KPerBlockScale; }
CK_TILE_HOST_DEVICE static constexpr auto GetKStepBQ() { return KPerBlockScale; }
CK_TILE_HOST_DEVICE static constexpr auto GetInstCountAQ()
{
return (MIterPerWarp / MXdlPackEff) * (KIterPerWarp / KXdlPackEff);
}
CK_TILE_HOST_DEVICE static constexpr auto GetInstCountBQ()
{
return (NIterPerWarp / NXdlPackEff) * (KIterPerWarp / KXdlPackEff);
}
CK_TILE_HOST_DEVICE static constexpr auto MakeAQBlockDistribution()
{
constexpr index_t K_Lane = get_warp_size() / WarpTileM;
constexpr index_t KPerLane = WarpTileK / BlockScaleSize / K_Lane;
constexpr index_t MIterPerWarp_packed = MIterPerWarp / MXdlPackEff;
constexpr index_t KIterPerWarp_packed = KIterPerWarp / KXdlPackEff;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<NWarps>, // repeat over MWarps
tuple<sequence<MWarps, MIterPerWarp_packed, WarpTileM>, // M dimension (first)
sequence<KIterPerWarp_packed, K_Lane, KPerLane>>, // K dimension (second)
tuple<sequence<0, 1>, sequence<2, 1>>, // <MWarps, NWarps>, <K_Lane, WarpTileM>
tuple<sequence<0, 0>, sequence<1, 2>>,
sequence<2, 1, 2>, // <KIterPerWarp, MIterPerWarp, KPerLane>
sequence<0, 1, 2>>{});
}
CK_TILE_HOST_DEVICE static constexpr auto MakeBQBlockDistribution()
{
constexpr index_t K_Lane = get_warp_size() / WarpTileN;
constexpr index_t KPerLane = WarpTileK / BlockScaleSize / K_Lane;
constexpr index_t NIterPerWarp_packed = NIterPerWarp / NXdlPackEff;
constexpr index_t KIterPerWarp_packed = KIterPerWarp / KXdlPackEff;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<MWarps>, // repeat over MWarps
tuple<sequence<2, NIterPerWarp_packed, NWarps / 2, WarpTileN>, // N dimension
// (first)
sequence<KIterPerWarp_packed, K_Lane, KPerLane>>, // K dimension (second)
tuple<sequence<1, 0, 1>, sequence<2, 1>>, // <MWarps, NWarps>, <K_Lane, MPerXdl>
tuple<sequence<0, 0, 2>, sequence<1, 3>>,
sequence<2, 1, 2>, // <KIterPerWarp, NIterPerWarp, KPerLane>
sequence<0, 1, 2>>{});
}
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
constexpr auto wg_attr_num_access =
(std::is_same_v<ADataType, fp8_t> || std::is_same_v<BDataType, fp8_t>)
? WGAttrNumAccessEnum::Double
: WGAttrNumAccessEnum::Single;
using WarpGemm = WarpGemmDispatcher<ADataType,
BDataType,
CDataType, // AccDataType
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
false,
wg_attr_num_access>;
using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<ADataType,
BDataType,
CDataType,
BlockWarps,
WarpGemm>;
return BlockMXGemmARegBRegCRegEightWavesV1<Problem, BlockGemmPolicy>{};
}
};
} // namespace detail
struct MXGemmPipelineAgBgCrCompAsyncEightWavesPolicy
: public GemmPipelineAgBgCrCompAsyncEightWavesPolicy
{
#define FORWARD_METHOD_(method) \
template <typename Problem, typename... Args> \
CK_TILE_HOST_DEVICE static constexpr auto method(Args&&... args) \
{ \
return detail::MXGemmPipelineAgBgCrCompAsyncEightWavesPolicy<Problem>::method( \
std::forward<Args>(args)...); \
}
FORWARD_METHOD_(MakeAQBlockDistribution);
FORWARD_METHOD_(MakeBQBlockDistribution);
FORWARD_METHOD_(GetBlockGemm);
FORWARD_METHOD_(GetKStepAQ);
FORWARD_METHOD_(GetKStepBQ);
FORWARD_METHOD_(GetInstCountAQ);
FORWARD_METHOD_(GetInstCountBQ);
#undef FORWARD_METHOD_
};
} // namespace ck_tile

View File

@@ -134,12 +134,7 @@ struct ABQuantGemmPipelineAgBgCrEightWaves : public BaseGemmPipelineAgBgCrCompV3
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
// We are not storing the original packed type in LDS, so we need to multiply the smem size
// by the packed size.
constexpr index_t smem_size_a = Policy::template GetSmemSizeA<Problem>() * APackedSize;
constexpr index_t smem_size_b = Policy::template GetSmemSizeB<Problem>() * BPackedSize;
return 2 * (smem_size_a + smem_size_b);
return Policy::template GetSmemSize<Problem>();
}
CK_TILE_HOST static std::string Print() { return "ABQuantGemmPipelineAgBgCrEightWaves\n"; }

View File

@@ -61,7 +61,7 @@ struct GemmABQuantPipelineAgBgCrAsyncPolicy
static constexpr index_t NIterPerWarp = NWarpTiles / NWarps;
static constexpr index_t KPerWarp = KPerBlock / KWarps;
static constexpr index_t NPerWarp = NPerBlock / NWarps;
static_assert(NWarps == 2, "KWarps == 2 for ping-pong!");
static_assert(NWarps == 2, "NWarps == 2 for ping-pong!");
static_assert(KWarpTiles == KWarps, "Wrong!");
static constexpr index_t KPerWarpAQ = KPerWarp / Problem::AQuantGroupSize::kK;
@@ -87,6 +87,11 @@ struct GemmABQuantPipelineAgBgCrAsyncPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetKStepAQ() { return KPerBlockAQ; }
CK_TILE_HOST_DEVICE static constexpr auto GetKStepBQ() { return KPerBlockBQ; }
// TODO: generalize instruction count calculation
CK_TILE_HOST_DEVICE static constexpr auto GetInstCountAQ() { return MIterPerWarp; }
CK_TILE_HOST_DEVICE static constexpr auto GetInstCountBQ() { return 1; }
CK_TILE_HOST_DEVICE static constexpr auto MakeAQBlockDistribution()
{
return make_static_tile_distribution(
@@ -156,6 +161,8 @@ struct GemmABQuantPipelineAgBgCrAsyncPolicy : public GemmPipelineAgBgCrCompAsync
FORWARD_METHOD_(GetBlockGemm);
FORWARD_METHOD_(GetKStepAQ);
FORWARD_METHOD_(GetKStepBQ);
FORWARD_METHOD_(GetInstCountAQ);
FORWARD_METHOD_(GetInstCountBQ);
#undef FORWARD_METHOD_
};

View File

@@ -7,11 +7,8 @@ if(CK_USE_OCP_FP8)
endif()
if(GPU_TARGETS MATCHES "gfx95")
add_gtest_executable(test_ck_tile_mx_gemm_fp4 test_mx_gemm_fp4.cpp)
target_compile_options(test_ck_tile_mx_gemm_fp4 PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_ck_tile_mx_gemm_fp8 test_mx_gemm_fp8.cpp)
target_compile_options(test_ck_tile_mx_gemm_fp8 PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_ck_tile_mx_gemm_async test_mx_gemm_async.cpp)
target_compile_options(test_ck_tile_mx_gemm_async PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS})
else()
message(DEBUG "Skipping ck_tile MX GEMM tests for current target")
endif()

View File

@@ -0,0 +1,33 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_mx_gemm_config.hpp"
#include "test_mx_gemm_util.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using F4 = ck_tile::pk_fp4_t;
using F8 = ck_tile::fp8_t;
using F6 = ck_tile::pk_fp6x16_t;
// clang-format off
using MxTypes = ::testing::Types<std::tuple<F4, F4, MX_GemmConfig16, Row, Col, Row>,
std::tuple<F4, F4, MX_GemmConfigEightWaves, Row, Col, Row>,
std::tuple<F8, F8, MX_GemmConfig16, Row, Col, Row>,
std::tuple<F8, F8, MX_GemmConfigEightWaves, Row, Col, Row>>;
// clang-format on
template <typename TypeParam>
class TestMxGemm : public TestMxGemmUtil<TypeParam>
{
};
TYPED_TEST_SUITE(TestMxGemm, MxTypes);
TYPED_TEST(TestMxGemm, Default)
{
// No M/N/K padding so we use 128x256x256 as smallest dimensions
this->Run(128, 256, 256);
this->Run(256, 256, 512);
this->Run(1024, 1024, 1024);
}

View File

@@ -80,16 +80,22 @@ struct MxGemmConfig
static constexpr bool TiledMMAPermuteN = false;
};
struct MXfp4_GemmConfig16 : MxGemmConfig
struct MX_GemmConfig16 : MxGemmConfig
{
static constexpr ck_tile::index_t M_Tile = 64;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 256;
};
struct MXfp8_GemmConfig16 : MxGemmConfig
struct MX_GemmConfigEightWaves : MxGemmConfig
{
static constexpr ck_tile::index_t M_Tile = 64;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 4;
static constexpr ck_tile::index_t N_Warp = 2; // NWarps == 2 for ping-pong!
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128 * N_Warp;
static constexpr ck_tile::index_t K_Tile = 128 * K_Warp;
static constexpr int kBlockPerCu = 2;
};

View File

@@ -1,30 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_mx_gemm_config.hpp"
#include "test_mx_gemm_util.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using MxFp4Types = ::testing::Types<
std::tuple<ck_tile::pk_fp4_t, ck_tile::pk_fp4_t, MXfp4_GemmConfig16, Row, Col, Row>>;
template <typename TypeParam>
class TestMxGemmFp4 : public TestMxGemmUtil<std::tuple_element_t<0, TypeParam>,
std::tuple_element_t<1, TypeParam>,
std::tuple_element_t<2, TypeParam>,
std::tuple_element_t<3, TypeParam>,
std::tuple_element_t<4, TypeParam>,
std::tuple_element_t<5, TypeParam>>
{
};
TYPED_TEST_SUITE(TestMxGemmFp4, MxFp4Types);
TYPED_TEST(TestMxGemmFp4, BasicSizes)
{
this->Run(64, 64, 256);
this->Run(128, 128, 256);
this->Run(64, 128, 512);
}

View File

@@ -1,30 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_mx_gemm_config.hpp"
#include "test_mx_gemm_util.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using MxFp8Types =
::testing::Types<std::tuple<ck_tile::fp8_t, ck_tile::fp8_t, MXfp8_GemmConfig16, Row, Col, Row>>;
template <typename TypeParam>
class TestMxGemmFp8 : public TestMxGemmUtil<std::tuple_element_t<0, TypeParam>,
std::tuple_element_t<1, TypeParam>,
std::tuple_element_t<2, TypeParam>,
std::tuple_element_t<3, TypeParam>,
std::tuple_element_t<4, TypeParam>,
std::tuple_element_t<5, TypeParam>>
{
};
TYPED_TEST_SUITE(TestMxGemmFp8, MxFp8Types);
TYPED_TEST(TestMxGemmFp8, BasicSizes)
{
this->Run(64, 64, 256);
this->Run(128, 128, 256);
this->Run(64, 128, 512);
}

View File

@@ -4,8 +4,7 @@
#pragma once
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp"
#include "ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp"
#include "ck_tile/ops/gemm_mx.hpp"
#include "test_mx_gemm_config.hpp"
template <typename GemmConfig,
@@ -48,7 +47,12 @@ float mx_gemm_calc(const MXGemmHostArgs<ScaleM, ScaleN>& args, const ck_tile::st
MXGemmTraits,
GemmConfig::Scheduler>;
using MXGemmPipeline = ck_tile::MXGemmPipelineAgBgCrCompAsync<MXPipelineProblem>;
constexpr bool IsEightWave =
(GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp) == 8;
using MXGemmPipeline =
std::conditional_t<IsEightWave,
ck_tile::MXGemmPipelineAgBgCrCompAsyncEightWaves<MXPipelineProblem>,
ck_tile::MXGemmPipelineAgBgCrCompAsync<MXPipelineProblem>>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
@@ -71,7 +75,15 @@ float mx_gemm_calc(const MXGemmHostArgs<ScaleM, ScaleN>& args, const ck_tile::st
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
MXPipelineProblem::TransposeC>>;
MXPipelineProblem::TransposeC,
1, // kNumWaveGroups_ (Default)
false, // FixedVectorSize_ (Default)
1, // VectorSizeC_ (Default)
1, // BlockedXDLN_PerWarp_ (Default)
false, // DoubleSmemBuffer_ (Default)
ADataType, // AComputeDataType
BDataType, // BComputeDataType
true>>; // TilesPacked_ (because of packed scales)
using Kernel = ck_tile::MXGemmKernel<TilePartitioner, MXGemmPipeline, GemmEpilogue>;

View File

@@ -30,15 +30,17 @@ auto calculate_rtol_atol_mx(ck_tile::index_t K, float max_accumulated_value)
return ck_tile::make_tuple(rtol, atol);
}
template <typename ADataType,
typename BDataType,
typename GemmConfig,
typename ALayout,
typename BLayout,
typename CLayout>
template <typename Tuple>
class TestMxGemmUtil : public ::testing::Test
{
protected:
using ADataType = std::tuple_element_t<0, Tuple>;
using BDataType = std::tuple_element_t<1, Tuple>;
using GemmConfig = std::tuple_element_t<2, Tuple>;
using ALayout = std::tuple_element_t<3, Tuple>;
using BLayout = std::tuple_element_t<4, Tuple>;
using CLayout = std::tuple_element_t<5, Tuple>;
using AccDataType = float;
using CDataType = ck_tile::fp16_t;
using ScaleType = ck_tile::e8m0_t;
@@ -94,7 +96,7 @@ class TestMxGemmUtil : public ::testing::Test
return packed;
}
void Run(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K, int seed = 1234)
void Run(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
{
const ck_tile::index_t scale_k_size = K / 32;
const ck_tile::index_t stride_A =
@@ -119,10 +121,23 @@ class TestMxGemmUtil : public ::testing::Test
ck_tile::HostTensor<ScaleType> scale_b_host(ck_tile::host_tensor_descriptor(
scale_k_size, N, stride_scale_b, is_row_major(BLayout{})));
ck_tile::FillUniformDistribution<ADataType>{-2.f, 2.f, seed++}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-2.f, 2.f, seed++}(b_host);
ck_tile::FillUniformDistribution<ScaleType>{0.001f, 10.f, seed++}(scale_a_host);
ck_tile::FillUniformDistribution<ScaleType>{0.001f, 10.f, seed++}(scale_b_host);
std::mt19937 gen(42);
std::uniform_int_distribution<std::uint32_t> fill_seed(0, 500);
auto gen_scales = [&](auto& scales, float range_min, float range_max) {
// e8m0_t is basically an exponent of float32
ck_tile::HostTensor<float> pow2(scales.get_lengths());
ck_tile::FillUniformDistributionIntegerValue<float>{
range_min, range_max, fill_seed(gen)}(pow2);
scales.ForEach([&](auto& self, const auto& i) {
self(i) = static_cast<ScaleType>(std::exp2(pow2(i)));
});
};
ck_tile::FillUniformDistribution<ADataType>{-2.f, 2.f, fill_seed(gen)}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-2.f, 2.f, fill_seed(gen)}(b_host);
gen_scales(scale_a_host, -2, 2);
gen_scales(scale_b_host, -2, 2);
// Compute effective XdlPack sizes based on GemmConfig tile dimensions
constexpr ck_tile::index_t MPerXdl = GemmConfig::M_Warp_Tile;

View File

@@ -20,4 +20,3 @@ TYPED_TEST(TestCkTileMxGroupedGemm, Basic)
this->Run(Ms, Ns, Ks, kbatch, group_count);
}