mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-11 16:59:10 +00:00
[rocm-libraries] ROCm/rocm-libraries#5552 (commit 369c7a2)
[CK Tile] Eight Waves pipeline for MX GEMM (#5552)

## Motivation

Integrate the Eight Waves pipeline in MX GEMM.

## Technical Details

- EightWaves pipeline:
  - Add pipeline, policy and block gemm (internally using the existing implementation used by GEMM and ABQuant)
  - Extend support of the EightWaves policy to FP4 (packed types)
- Async pipeline:
  - Fix the pipeline with packed scales (requires MRepeat and NRepeat to be contiguous)
  - A block gemm specific to MX GEMM is defined because the distribution encodings have changed
- CShuffle:
  - Add new functionality to support contiguous MRepeat and NRepeat (defined by `TilesPacked`)
- Examples:
  - Refactor examples to easily switch between different configurations (similar to GEMM universal)
  - Scale values are generated consistently with other microscale implementations in CK Tile
  - Add configuration for the EightWaves pipeline
- Tests:
  - Unify existing FP8 and FP4 tests
  - Add tests for the EightWaves pipeline
  - Scale values are generated consistently with other microscale implementations in CK Tile

Note: FP6 support for MX GEMM was added later, and the support for the Eight Waves pipeline will be done in a following PR.

## Test Plan

Add the new pipeline to the tests: `test_ck_tile_mx_gemm_async` for both FP4 and FP8.

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
@@ -14,6 +14,8 @@ endforeach()
|
||||
if(has_supported_gpu)
|
||||
add_executable(tile_example_mx_gemm mx_gemm.cpp)
|
||||
set(EXAMPLE_MX_GEMM_COMPILE_OPTIONS -Wno-undefined-func-template)
|
||||
list(APPEND EXAMPLE_MX_GEMM_COMPILE_OPTIONS "SHELL: -mllvm -enable-noalias-to-md-conversion=1")
|
||||
list(APPEND EXAMPLE_MX_GEMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1")
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_MX_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
@@ -102,9 +102,9 @@ float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf,
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "4096", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("k", "4096", "k dimension")
|
||||
arg_parser.insert("m", "1024", "m dimension")
|
||||
.insert("n", "1024", "n dimension")
|
||||
.insert("k", "2048", "k dimension")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "C", "B tensor data layout - Row by default")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default")
|
||||
@@ -125,4 +125,4 @@ auto create_args(int argc, char* argv[])
|
||||
|
||||
#include "run_mx_gemm.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return run_mx_gemm_example(argc, argv); }
|
||||
int main(int argc, char* argv[]) { return run_mx_gemm_example<MX_GemmConfig16>(argc, argv); }
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm_mx.hpp"
|
||||
#include "ck_tile/ops/gemm_mx/kernel/scale_pointer.hpp"
|
||||
|
||||
template <typename ScaleM, typename ScaleN>
|
||||
@@ -83,17 +84,23 @@ struct MxGemmConfig
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
struct MXfp4_GemmConfig16 : MxGemmConfig
|
||||
|
||||
struct MX_GemmConfigEightWaves : MxGemmConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 64;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 2; // NWarps == 2 for ping-pong!
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128 * N_Warp;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 * K_Warp;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
// GEMM config with 16x16 warp tile
|
||||
struct MXfp8_GemmConfig16 : MxGemmConfig
|
||||
struct MX_GemmConfig16 : MxGemmConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 64;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
};
|
||||
|
||||
@@ -57,7 +57,12 @@ float mx_gemm_calc(const MXGemmHostArgs<ScaleM, ScaleN>& args, const ck_tile::st
|
||||
GemmConfig::Scheduler>;
|
||||
|
||||
// Use the new MX comp_async pipeline with MX scaling support
|
||||
using MXGemmPipeline = ck_tile::MXGemmPipelineAgBgCrCompAsync<MXPipelineProblem>;
|
||||
constexpr bool IsEightWave =
|
||||
(GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp) == 8;
|
||||
using MXGemmPipeline =
|
||||
std::conditional_t<IsEightWave,
|
||||
ck_tile::MXGemmPipelineAgBgCrCompAsyncEightWaves<MXPipelineProblem>,
|
||||
ck_tile::MXGemmPipelineAgBgCrCompAsync<MXPipelineProblem>>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
@@ -80,7 +85,15 @@ float mx_gemm_calc(const MXGemmHostArgs<ScaleM, ScaleN>& args, const ck_tile::st
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
MXPipelineProblem::TransposeC>>;
|
||||
MXPipelineProblem::TransposeC,
|
||||
1, // kNumWaveGroups_ (Default)
|
||||
false, // FixedVectorSize_ (Default)
|
||||
1, // VectorSizeC_ (Default)
|
||||
1, // BlockedXDLN_PerWarp_ (Default)
|
||||
false, // DoubleSmemBuffer_ (Default)
|
||||
ComputeDataType, // AComputeDataType
|
||||
ComputeDataType, // BComputeDataType
|
||||
true>>; // TilesPacked_ (because of packed scales)
|
||||
|
||||
using Kernel = ck_tile::MXGemmKernel<TilePartitioner, MXGemmPipeline, GemmEpilogue>;
|
||||
|
||||
|
||||
@@ -124,29 +124,42 @@ int run_mx_gemm_with_layouts(int argc, char* argv[], ALayout, BLayout, CLayout)
|
||||
ck_tile::host_tensor_descriptor(M, scale_k_size, stride_scale_a, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<ScaleType> scale_b_host(
|
||||
ck_tile::host_tensor_descriptor(scale_k_size, N, stride_scale_b, is_row_major(BLayout{})));
|
||||
int seed = 1234;
|
||||
|
||||
std::mt19937 gen(42);
|
||||
std::uniform_int_distribution<std::uint32_t> fill_seed(0, 500);
|
||||
|
||||
// Helper that fills a scale tensor with powers of two.
// - scales:               host tensor to overwrite; its lengths size the temporary below.
// - range_min, range_max: bounds handed to the filler that generates the exponents.
// Each call draws a fresh seed from fill_seed(gen).
auto gen_scales = [&](auto& scales, float range_min, float range_max) {
|
||||
// e8m0_t is basically an exponent of float32
|
||||
// Temporary float tensor holding one exponent per scale element.
ck_tile::HostTensor<float> pow2(scales.get_lengths());
|
||||
// NOTE(review): presumably this filler yields integer-valued floats in
// [range_min, range_max], making every scale an exact power of two - confirm.
ck_tile::FillUniformDistributionIntegerValue<float>{range_min, range_max, fill_seed(gen)}(
|
||||
pow2);
|
||||
scales.ForEach([&](auto& self, const auto& i) {
|
||||
// scale = 2^exponent, converted to the scale storage type.
self(i) = static_cast<ScaleType>(std::exp2(pow2(i)));
|
||||
});
|
||||
};
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0:
|
||||
// Initialize A, B, and scales to random values
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.f, 2.f, seed++}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.f, 2.f, seed++}(b_host);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{0.001f, 10.f, seed++}(scale_a_host);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{0.001f, 10.f, seed++}(scale_b_host);
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.f, 2.f, fill_seed(gen)}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.f, 2.f, fill_seed(gen)}(b_host);
|
||||
gen_scales(scale_a_host, -2, 2);
|
||||
gen_scales(scale_b_host, -2, 2);
|
||||
break;
|
||||
case 1:
|
||||
// Initialize A, B, and scales to 1.0
|
||||
ck_tile::FillConstant<ADataType>{ADataType(1.f)}(a_host);
|
||||
ck_tile::FillConstant<BDataType>{BDataType(1.f)}(b_host);
|
||||
ck_tile::FillConstant<ScaleType>{ScaleType(1.f)}(scale_a_host);
|
||||
ck_tile::FillConstant<ScaleType>{ScaleType(1.f)}(scale_b_host);
|
||||
gen_scales(scale_a_host, 0, 0);
|
||||
gen_scales(scale_b_host, 0, 0);
|
||||
break;
|
||||
case 2:
|
||||
// Initialize A and B with random values but with constant 1.0 scales
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.f, 2.f, seed++}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.f, 2.f, seed++}(b_host);
|
||||
ck_tile::FillConstant<ScaleType>{ScaleType(0.1f)}(scale_a_host);
|
||||
ck_tile::FillConstant<ScaleType>{ScaleType(0.1f)}(scale_b_host);
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.f, 2.f, fill_seed(gen)}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.f, 2.f, fill_seed(gen)}(b_host);
|
||||
gen_scales(scale_a_host, 0, 0);
|
||||
gen_scales(scale_b_host, 0, 0);
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -248,6 +261,7 @@ int run_mx_gemm_with_layouts(int argc, char* argv[], ALayout, BLayout, CLayout)
|
||||
return pass ? 0 : -1;
|
||||
}
|
||||
|
||||
template <typename GemmConfig>
|
||||
int run_mx_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
@@ -268,7 +282,7 @@ int run_mx_gemm_example(int argc, char* argv[])
|
||||
return run_mx_gemm_with_layouts<ck_tile::pk_fp4_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
float,
|
||||
MXfp4_GemmConfig16,
|
||||
GemmConfig,
|
||||
true>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(mx_prec == "fp8" || mx_prec == "fp8xfp8")
|
||||
@@ -276,7 +290,7 @@ int run_mx_gemm_example(int argc, char* argv[])
|
||||
return run_mx_gemm_with_layouts<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
float,
|
||||
MXfp8_GemmConfig16,
|
||||
GemmConfig,
|
||||
true>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
|
||||
@@ -2793,13 +2793,15 @@ template <typename T,
|
||||
CK_TILE_DEVICE void amd_async_buffer_load_with_oob(CK_TILE_LDS_ADDR T* smem,
|
||||
const __amdgpu_buffer_rsrc_t rsrc,
|
||||
index_t src_thread_element_offset,
|
||||
index_t src_wave_addr_offset,
|
||||
index_t src_wave_element_offset,
|
||||
linear_offset_t,
|
||||
bool is_valid_element,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||
constexpr index_t src_linear_addr_offset = static_cast<index_t>(linear_offset_t{}) * sizeof(T);
|
||||
constexpr index_t PackedSize = numeric_traits<T>::PackedSize;
|
||||
index_t src_wave_addr_offset = src_wave_element_offset * sizeof(T) / PackedSize;
|
||||
|
||||
amd_async_buffer_load<T, N, coherence>(smem,
|
||||
rsrc,
|
||||
|
||||
@@ -36,7 +36,8 @@ template <typename AsDataType_,
|
||||
index_t BlockedXDLN_PerWarp_ = 1, // The number of continuous xdl_output per warp
|
||||
bool DoubleSmemBuffer_ = false,
|
||||
typename AComputeDataType_ = void,
|
||||
typename BComputeDataType_ = void>
|
||||
typename BComputeDataType_ = void,
|
||||
bool TilesPacked_ = false>
|
||||
struct CShuffleEpilogueProblem
|
||||
{
|
||||
using AsDataType = remove_cvref_t<AsDataType_>;
|
||||
@@ -64,7 +65,7 @@ struct CShuffleEpilogueProblem
|
||||
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
|
||||
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
static constexpr bool TilesPacked = TilesPacked_;
|
||||
static_assert(NumDTensor == DsLayout::size(),
|
||||
"The size of DsDataType and DsLayout should be the same");
|
||||
};
|
||||
@@ -140,15 +141,19 @@ struct CShuffleEpilogue
|
||||
static constexpr bool EightWave = false;
|
||||
#endif
|
||||
|
||||
// Whether the wave tiles computed by a single wave are packed.
|
||||
// When true, MRepeat and NRepeat are contiguous in the block gemm.
|
||||
static constexpr bool TilesPacked = Problem::TilesPacked;
|
||||
static constexpr index_t BlockedXDLN_PerWarp =
|
||||
EightWave ? kNPerBlock / NWave / NPerXdl : Problem::BlockedXDLN_PerWarp;
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
|
||||
static constexpr index_t MPerIteration = MPerXdl * MWave;
|
||||
static constexpr index_t NPerIteration = NPerXdl * NWave;
|
||||
static constexpr index_t NumDTensor = Problem::NumDTensor;
|
||||
static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave);
|
||||
static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave);
|
||||
(EightWave || TilesPacked) ? kNPerBlock / NWave / NPerXdl : Problem::BlockedXDLN_PerWarp;
|
||||
static constexpr index_t BlockedXDLM_PerWarp = (TilesPacked) ? kMPerBlock / MWave / MPerXdl : 1;
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
|
||||
static constexpr index_t MPerIteration = MPerXdl * MWave;
|
||||
static constexpr index_t NPerIteration = NPerXdl * NWave;
|
||||
static constexpr index_t NumDTensor = Problem::NumDTensor;
|
||||
static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave);
|
||||
static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave);
|
||||
|
||||
CDElementwise elfunc_;
|
||||
|
||||
@@ -288,7 +293,8 @@ struct CShuffleEpilogue
|
||||
}
|
||||
}
|
||||
}();
|
||||
static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple);
|
||||
static constexpr index_t NumMXdlPerWavePerShuffle =
|
||||
max(BlockedXDLM_PerWarp, std::get<0>(shuffle_tile_tuple));
|
||||
static constexpr index_t NumNXdlPerWavePerShuffle =
|
||||
max(BlockedXDLN_PerWarp, std::get<1>(shuffle_tile_tuple));
|
||||
|
||||
@@ -447,64 +453,96 @@ struct CShuffleEpilogue
|
||||
CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode()
|
||||
{
|
||||
constexpr auto block_outer_dstr_encoding = [] {
|
||||
if constexpr(BlockedXDLN_PerWarp == 1)
|
||||
if constexpr(TilesPacked)
|
||||
{
|
||||
return tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
|
||||
sequence<NumNXdlPerWavePerShuffle, NWave>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
if constexpr(EightWave)
|
||||
{
|
||||
constexpr int RakedXDLN_PerWarp =
|
||||
NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp;
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MWave, NumMXdlPerWavePerShuffle>,
|
||||
sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<1, 0, 2>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MWave, NumMXdlPerWavePerShuffle>,
|
||||
sequence<NWave, NumNXdlPerWavePerShuffle>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{};
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
#if defined(__gfx950__) || defined(__gfx12__)
|
||||
constexpr auto UseBlockedLayout = true;
|
||||
#else
|
||||
constexpr auto UseBlockedLayout = false;
|
||||
#endif
|
||||
constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp;
|
||||
// BlockedLayout
|
||||
// this branch is for original a16w4
|
||||
if constexpr(UseBlockedLayout ||
|
||||
is_any_of<ADataTypeBuf, pk_int4_t, pk_fp4_t>::value ||
|
||||
is_any_of<BDataTypeBuf, pk_int4_t, pk_fp4_t>::value)
|
||||
if constexpr(BlockedXDLN_PerWarp == 1)
|
||||
{
|
||||
if constexpr(EightWave)
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
|
||||
sequence<NumNXdlPerWavePerShuffle, NWave>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
#if defined(__gfx950__) || defined(__gfx12__)
|
||||
constexpr auto UseBlockedLayout = true;
|
||||
#else
|
||||
constexpr auto UseBlockedLayout = false;
|
||||
#endif
|
||||
constexpr int RakedXDLN_PerWarp =
|
||||
NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp;
|
||||
// BlockedLayout
|
||||
// this branch is for original a16w4
|
||||
if constexpr(UseBlockedLayout ||
|
||||
is_any_of<ADataTypeBuf, pk_int4_t, pk_fp4_t>::value ||
|
||||
is_any_of<BDataTypeBuf, pk_int4_t, pk_fp4_t>::value)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
|
||||
sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<0, 0, 2>>{};
|
||||
if constexpr(EightWave)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
|
||||
sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<0, 0, 2>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
|
||||
sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<0, 0, 2>>{};
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
|
||||
sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
|
||||
sequence<RakedXDLN_PerWarp, BlockedXDLN_PerWarp, NWave>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<0, 0, 2>>{};
|
||||
sequence<0, 0, 1>>{};
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
|
||||
sequence<RakedXDLN_PerWarp, BlockedXDLN_PerWarp, NWave>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<0, 0, 1>>{};
|
||||
}
|
||||
}
|
||||
}();
|
||||
constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
|
||||
|
||||
@@ -388,150 +388,6 @@ struct BlockGemmARegBRegCRegV1
|
||||
});
|
||||
}
|
||||
|
||||
// C += A * B with MX scaling and packed-in-two (XdlPack) optimization
|
||||
// Scale tensors contain pre-packed int32_t: each int32_t holds MXdlPack * KXdlPack e8m0_t
|
||||
// values (for A) or NXdlPack * KXdlPack (for B), packed on the host.
|
||||
// Uses OpSel (0-3) to select which byte within the packed int32_t for each MFMA call.
|
||||
// XdlPack template parameters default to 2; fall back to 1 when iteration count is too small.
|
||||
template <typename CBlockTensor,
|
||||
typename ABlockTensor,
|
||||
typename BBlockTensor,
|
||||
typename ScaleATensor,
|
||||
typename ScaleBTensor,
|
||||
index_t MXdlPack_ = 2,
|
||||
index_t NXdlPack_ = 2,
|
||||
index_t KXdlPack_ = 2>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensor& a_block_tensor,
|
||||
const BBlockTensor& b_block_tensor,
|
||||
const ScaleATensor& scale_a_tensor,
|
||||
const ScaleBTensor& scale_b_tensor) const
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
// check ABC-block-distribution
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(MakeABlockDistributionEncode())>,
|
||||
remove_cvref_t<decltype(ABlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"A distribution is wrong!");
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(MakeBBlockDistributionEncode())>,
|
||||
remove_cvref_t<decltype(BBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"B distribution is wrong!");
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(MakeCBlockDistributionEncode())>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"C distribution is wrong!");
|
||||
|
||||
using AWarpDstr = typename WarpGemm::AWarpDstr;
|
||||
using BWarpDstr = typename WarpGemm::BWarpDstr;
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WarpGemm::AWarpTensor;
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto b_warp_y_lengths =
|
||||
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// Effective XdlPack: fall back to 1 when iteration count is insufficient
|
||||
constexpr index_t MXdlPack =
|
||||
(MIterPerWarp >= MXdlPack_ && MIterPerWarp % MXdlPack_ == 0) ? MXdlPack_ : 1;
|
||||
constexpr index_t NXdlPack =
|
||||
(NIterPerWarp >= NXdlPack_ && NIterPerWarp % NXdlPack_ == 0) ? NXdlPack_ : 1;
|
||||
constexpr index_t KXdlPack =
|
||||
(KIterPerWarp >= KXdlPack_ && KIterPerWarp % KXdlPack_ == 0) ? KXdlPack_ : 1;
|
||||
|
||||
constexpr index_t MPackIterPerWarp = MIterPerWarp / MXdlPack;
|
||||
constexpr index_t NPackIterPerWarp = NIterPerWarp / NXdlPack;
|
||||
constexpr index_t KPackIterPerWarp = KIterPerWarp / KXdlPack;
|
||||
|
||||
// hot loop with MX scaling and pre-packed int32_t scales:
|
||||
// Outer loops iterate over pack groups (scale tile indices)
|
||||
static_ford<sequence<KPackIterPerWarp, MPackIterPerWarp>>{}([&](auto ii) {
|
||||
constexpr auto ikpack = number<ii[number<0>{}]>{};
|
||||
constexpr auto impack = number<ii[number<1>{}]>{};
|
||||
// Get pre-packed int32_t A scale (already contains MXdlPack*KXdlPack e8m0_t)
|
||||
auto scale_a_slice = scale_a_tensor.get_y_sliced_thread_data(
|
||||
sequence<ikpack, impack, 0>{}, sequence<1, 1, 1>{});
|
||||
const int32_t a_scale_packed = bit_cast<int32_t>(scale_a_slice[number<0>{}]);
|
||||
|
||||
static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) {
|
||||
// Get pre-packed int32_t B scale
|
||||
auto scale_b_slice = scale_b_tensor.get_y_sliced_thread_data(
|
||||
sequence<ikpack, inpack, 0>{}, sequence<1, 1, 1>{});
|
||||
const int32_t b_scale_packed = bit_cast<int32_t>(scale_b_slice[number<0>{}]);
|
||||
|
||||
// Inner loops: issue MFMAs within the pack group using OpSel
|
||||
static_ford<sequence<KXdlPack, MXdlPack>>{}([&](auto jj) {
|
||||
constexpr auto ikxdl = number<jj[number<0>{}]>{};
|
||||
constexpr auto imxdl = number<jj[number<1>{}]>{};
|
||||
constexpr auto kIter = ikpack * KXdlPack + ikxdl;
|
||||
constexpr auto mIter = impack * MXdlPack + imxdl;
|
||||
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
// OpSel for A: selects byte within packed int32_t
|
||||
constexpr index_t kOpSelA = ikxdl * MXdlPack + imxdl;
|
||||
|
||||
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
|
||||
constexpr auto nIter = inpack * NXdlPack + inxdl;
|
||||
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// OpSel for B: selects byte within packed int32_t
|
||||
constexpr index_t kOpSelB = ikxdl * NXdlPack + inxdl;
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
using c_iter_idx = std::conditional_t<TransposeC,
|
||||
sequence<nIter, mIter>,
|
||||
sequence<mIter, nIter>>;
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM with MX scaling using pre-packed scale and OpSel
|
||||
WarpGemm{}.template operator()<OpSelA<kOpSelA>, OpSelB<kOpSelB>>(
|
||||
c_warp_tensor,
|
||||
a_warp_tensor,
|
||||
b_warp_tensor,
|
||||
a_scale_packed,
|
||||
b_scale_packed);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
using c_distr_ys_major = std::conditional_t<TransposeC, sequence<2, 1>, sequence<1, 2>>;
|
||||
|
||||
@@ -118,12 +118,7 @@ struct GemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrComp
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
// We are not storing the original packed type in LDS, so we need to multiply the smem size
|
||||
// by the packed size.
|
||||
constexpr index_t smem_size_a = Policy::template GetSmemSizeA<Problem>() * APackedSize;
|
||||
constexpr index_t smem_size_b = Policy::template GetSmemSizeB<Problem>() * BPackedSize;
|
||||
|
||||
return 2 * (smem_size_a + smem_size_b);
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
static constexpr index_t MFMA_INST = MIterPerWarp * NIterPerWarp * KIterPerWarp;
|
||||
|
||||
@@ -17,10 +17,9 @@ namespace detail {
|
||||
template <typename Problem>
|
||||
struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
|
||||
{
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
static constexpr auto WGAccessDouble = WGAttrNumAccessEnum::Double;
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
@@ -29,14 +28,21 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using AComputeDataType = remove_cvref_t<typename Problem::AComputeDataType>;
|
||||
using BComputeDataType = remove_cvref_t<typename Problem::BComputeDataType>;
|
||||
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>, "Wrong!");
|
||||
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::ColumnMajor>, "Wrong!");
|
||||
static_assert(std::is_same_v<AComputeDataType, fp8_t> ||
|
||||
std::is_same_v<AComputeDataType, bf8_t>);
|
||||
static_assert(std::is_same_v<BComputeDataType, fp8_t> ||
|
||||
std::is_same_v<BComputeDataType, bf8_t>);
|
||||
using ComputeDataType = AComputeDataType;
|
||||
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>,
|
||||
"ALayout must be RowMajor!");
|
||||
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::ColumnMajor>,
|
||||
"BLayout must be ColumnMajor!");
|
||||
static_assert(is_any_of<AComputeDataType, fp8_t, bf8_t, pk_fp4_t>::value);
|
||||
static_assert(is_any_of<BComputeDataType, fp8_t, bf8_t, pk_fp4_t>::value);
|
||||
static_assert(std::is_same_v<AComputeDataType, BComputeDataType>);
|
||||
static_assert(std::is_same_v<CDataType, float>);
|
||||
|
||||
static constexpr auto WGAccess = std::is_same_v<ComputeDataType, fp8_t>
|
||||
? WGAttrNumAccessEnum::Double
|
||||
: WGAttrNumAccessEnum::Single;
|
||||
static constexpr auto PackedSize = numeric_traits<ComputeDataType>::PackedSize;
|
||||
|
||||
using BlockGemmShape = typename Problem::BlockGemmShape;
|
||||
using BlockWarps = typename BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename BlockGemmShape::WarpTile;
|
||||
@@ -88,7 +94,7 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
|
||||
static constexpr index_t NIterPerWarp = NWarpTiles / NWarps;
|
||||
static constexpr index_t KPerWarp = KPerBlock / KWarps;
|
||||
static constexpr index_t NPerWarp = NPerBlock / NWarps;
|
||||
static_assert(NWarps == 2, "KWarps == 2 for ping-pong!");
|
||||
static_assert(NWarps == 2, "NWarps == 2 for ping-pong!");
|
||||
static_assert(KWarpTiles == KWarps, "Wrong!");
|
||||
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
@@ -98,8 +104,8 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
|
||||
|
||||
static_assert(sizeof(ADataType) == sizeof(BDataType), "Wrong!");
|
||||
static constexpr index_t ElementSize = sizeof(ADataType);
|
||||
static constexpr index_t K2 = Problem::VectorLoadSize / ElementSize; // 16
|
||||
static constexpr index_t K1 = WarpTile::at(I2) / K2; // 8
|
||||
static constexpr index_t K2 = Problem::VectorLoadSize / ElementSize * PackedSize; // 16
|
||||
static constexpr index_t K1 = WarpTile::at(I2) / K2; // 8
|
||||
static constexpr index_t K0 = KPerWarp / (K1 * K2);
|
||||
static_assert(K0 * K1 * K2 == KPerWarp, "Wrong!");
|
||||
static_assert(K0 == 1, "Wrong!");
|
||||
@@ -176,7 +182,7 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
|
||||
const index_t k_tiles = cols / (KWarps * K1 * K2);
|
||||
const auto col_lens = make_tuple(k_tiles, number<KWarps>{}, number<K1>{}, number<K2>{});
|
||||
|
||||
constexpr index_t M1 = warp_size / static_cast<index_t>(WGAccessDouble) / K1; // 4
|
||||
constexpr index_t M1 = warp_size / static_cast<index_t>(WGAccess) / K1; // 4
|
||||
const index_t M0 = integer_divide_ceil(rows, M1);
|
||||
const auto row_lens = make_tuple(M0, number<M1>{});
|
||||
|
||||
@@ -227,9 +233,9 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
|
||||
template <index_t MNPerBlock, index_t warp_groups_>
|
||||
CK_TILE_DEVICE static constexpr auto MakeABLdsBlockDescriptor_()
|
||||
{
|
||||
constexpr index_t M4 = warp_size / static_cast<index_t>(WGAccessDouble) / K1; // 4
|
||||
constexpr index_t M3 = static_cast<index_t>(WGAccessDouble); // 2
|
||||
constexpr index_t M2 = WarpTileM / M4 / M3; // 2
|
||||
constexpr index_t M4 = warp_size / static_cast<index_t>(WGAccess) / K1; // 4
|
||||
constexpr index_t M3 = static_cast<index_t>(WGAccess); // 2
|
||||
constexpr index_t M2 = WarpTileM / M4 / M3; // 2
|
||||
constexpr index_t M1 = (warp_num / warp_groups_) / M2;
|
||||
constexpr index_t M0 = MNPerBlock / M1 / M2 / M3 / M4;
|
||||
|
||||
@@ -337,12 +343,14 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
|
||||
CK_TILE_DEVICE static constexpr index_t GetSmemSizeA()
|
||||
{
|
||||
constexpr index_t desc_size = MakeALdsBlockDescriptor().get_element_space_size();
|
||||
return integer_least_multiple(sizeof(typename Problem::ADataType) * desc_size, 16);
|
||||
return integer_least_multiple(sizeof(typename Problem::ADataType) * desc_size / PackedSize,
|
||||
16);
|
||||
}
|
||||
CK_TILE_DEVICE static constexpr index_t GetSmemSizeB()
|
||||
{
|
||||
constexpr index_t desc_size = MakeBLdsBlockDescriptor().get_element_space_size();
|
||||
return integer_least_multiple(sizeof(typename Problem::BDataType) * desc_size, 16);
|
||||
return integer_least_multiple(sizeof(typename Problem::BDataType) * desc_size / PackedSize,
|
||||
16);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr index_t GetSmemSize()
|
||||
@@ -361,7 +369,7 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
// TODO: Fix for transpose
|
||||
constexpr auto wg_attr_num_access = WGAttrNumAccessEnum::Double;
|
||||
constexpr auto wg_attr_num_access = WGAccess;
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
|
||||
@@ -199,6 +199,22 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase<
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Overload enabled only when AQDramBlockWindowTmp is NullTileWindowType (no AQ window):
// the reported AQ instruction count is the constant 0. The window argument is used purely
// for overload selection and is never read.
template <typename AQDramBlockWindowTmp,
|
||||
typename std::enable_if_t<std::is_same_v<AQDramBlockWindowTmp, NullTileWindowType>,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE static constexpr auto GetInstCountAQ(const AQDramBlockWindowTmp&)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Overload enabled only when BQDramBlockWindowTmp is NullTileWindowType (no BQ window):
// the reported BQ instruction count is the constant 0. The window argument is used purely
// for overload selection and is never read.
template <typename BQDramBlockWindowTmp,
|
||||
typename std::enable_if_t<std::is_same_v<BQDramBlockWindowTmp, NullTileWindowType>,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE static constexpr auto GetInstCountBQ(const BQDramBlockWindowTmp&)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
// A/B Quant
|
||||
template <typename AQDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!std::is_same_v<AQDramBlockWindowTmp, NullTileWindowType>,
|
||||
@@ -234,6 +250,22 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase<
|
||||
return Policy::template GetKStepBQ<Problem>();
|
||||
}
|
||||
|
||||
// Overload enabled when AQDramBlockWindowTmp is NOT NullTileWindowType: forwards to
// Policy::GetInstCountAQ<Problem>(). The window argument is used purely for overload
// selection and is never read.
template <typename AQDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!std::is_same_v<AQDramBlockWindowTmp, NullTileWindowType>,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE static constexpr auto GetInstCountAQ(const AQDramBlockWindowTmp&)
|
||||
{
|
||||
return Policy::template GetInstCountAQ<Problem>();
|
||||
}
|
||||
|
||||
// Overload enabled when BQDramBlockWindowTmp is NOT NullTileWindowType: forwards to
// Policy::GetInstCountBQ<Problem>(). The window argument is used purely for overload
// selection and is never read.
template <typename BQDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!std::is_same_v<BQDramBlockWindowTmp, NullTileWindowType>,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE static constexpr auto GetInstCountBQ(const BQDramBlockWindowTmp&)
|
||||
{
|
||||
return Policy::template GetInstCountBQ<Problem>();
|
||||
}
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename ADramBlockWindowTmp,
|
||||
@@ -258,14 +290,6 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase<
|
||||
: 0;
|
||||
static_assert(N_LOOP >= 1, "wrong!");
|
||||
|
||||
// Instructions Count
|
||||
constexpr index_t VectorSizeB = Policy::template GetVectorSizeB<Problem>();
|
||||
constexpr index_t B_LOAD_INST = NPerBlock * KPerBlock / BlockSize / VectorSizeB;
|
||||
constexpr index_t AQ_LOAD_INST =
|
||||
std::is_same_v<AQDramBlockWindowTmp, NullTileWindowType> ? 0 : MIterPerWarp;
|
||||
constexpr index_t BQ_LOAD_INST =
|
||||
std::is_same_v<BQDramBlockWindowTmp, NullTileWindowType> ? 0 : 1;
|
||||
|
||||
// -----
|
||||
// Setup
|
||||
// -----
|
||||
@@ -314,6 +338,12 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase<
|
||||
constexpr AQDramTileWindowStep aq_move_step = {0, GetKStepAQ(aq_copy_dram_window)};
|
||||
constexpr BQDramTileWindowStep bq_move_step = {0, GetKStepBQ(bq_copy_dram_window)};
|
||||
|
||||
// Instructions Count
|
||||
constexpr index_t VectorSizeB = Policy::template GetVectorSizeB<Problem>();
|
||||
constexpr index_t B_LOAD_INST = NPerBlock * KPerBlock / BlockSize / VectorSizeB;
|
||||
constexpr index_t AQ_LOAD_INST = GetInstCountAQ(aq_copy_dram_window);
|
||||
constexpr index_t BQ_LOAD_INST = GetInstCountBQ(bq_copy_dram_window);
|
||||
|
||||
// -------
|
||||
// Lambdas
|
||||
// -------
|
||||
|
||||
@@ -2,10 +2,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_eight_waves_v1.hpp"
|
||||
#include "ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm_mx/kernel/scale_pointer.hpp"
|
||||
#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp"
|
||||
#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp"
|
||||
#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
|
||||
@@ -0,0 +1,310 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A is block distributed tensor
|
||||
// B is block distributed tensor
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy>
|
||||
struct BlockMXGemmARegBRegCRegEightWavesV1
|
||||
{
|
||||
private:
|
||||
// Compile-time traits derived from the pipeline problem and the block-gemm policy:
// data types, block tile sizes, the warp gemm chosen by the policy, the warp layout
// (MWarp / NWarp / KWarp) and the per-warp iteration counts along M, N and K.
template <typename PipelineProblem_, typename GemmPolicy_>
|
||||
struct GemmTraits_
|
||||
{
|
||||
using Problem = remove_cvref_t<PipelineProblem_>;
|
||||
using Policy = remove_cvref_t<GemmPolicy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using AComputeDataType = remove_cvref_t<typename Problem::AComputeDataType>;
|
||||
using BComputeDataType = remove_cvref_t<typename Problem::BComputeDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
// Block tile extents taken from the block gemm shape.
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
// The policy returns a tuple-like config: element 0 is the warp gemm, elements 1 and 2
// are the warp counts along M and N.
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
static constexpr index_t MWarp = config.template at<1>();
|
||||
static constexpr index_t NWarp = config.template at<2>();
|
||||
// KWarp is not part of the policy config; it is read from BlockWarps index 2.
static constexpr index_t KWarp = Problem::BlockGemmShape::BlockWarps::at(number<2>{});
|
||||
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
|
||||
// The policy-provided warp layout and warp tile must agree with the block gemm shape.
static_assert(MWarp == BlockGemmShape::BlockWarps::at(I0{}),
|
||||
"Error! WarpGemm's MWarp is not consistent with BlockGemmShape!");
|
||||
static_assert(NWarp == BlockGemmShape::BlockWarps::at(I1{}),
|
||||
"Error! WarpGemm's NWarp is not consistent with BlockGemmShape!");
|
||||
static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(I0{}),
|
||||
"Error! WarpGemm's M is not consistent with BlockGemmShape!");
|
||||
static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(I1{}),
|
||||
"Error! WarpGemm's N is not consistent with BlockGemmShape!");
|
||||
|
||||
// Number of warp-gemm iterations each warp performs along M, N and K.
// Unlike M and N, the K count also divides by KWarp.
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
static constexpr index_t KIterPerWarp = KPerBlock / (KWarp * WarpGemm::kK);
|
||||
|
||||
// Controls how many MAC clusters (MFMA blocks) we have per wave
|
||||
// If InterWaveSchedulingMacClusters = 1;
|
||||
// Then we group all WarpGemms into single MAC cluster.
|
||||
// But if InterWaveSchedulingMacClusters = 2, then we
|
||||
// split the warp gemms into two groups.
|
||||
static constexpr index_t InterWaveSchedulingMacClusters = 1;
|
||||
|
||||
static constexpr index_t KPackA = WarpGemm::kAKPack;
|
||||
static constexpr index_t KPackB = WarpGemm::kBKPack;
|
||||
static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
|
||||
static constexpr bool TransposeC = Problem::TransposeC;
|
||||
};
|
||||
|
||||
public:
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using Traits = GemmTraits_<Problem, Policy>;
|
||||
|
||||
using WarpGemm = typename Traits::WarpGemm;
|
||||
using BlockGemmShape = typename Traits::BlockGemmShape;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Traits::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Traits::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Traits::CDataType>;
|
||||
using AComputeDataType = remove_cvref_t<typename Traits::AComputeDataType>;
|
||||
using BComputeDataType = remove_cvref_t<typename Traits::BComputeDataType>;
|
||||
|
||||
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
|
||||
static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
|
||||
static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;
|
||||
|
||||
static constexpr index_t MWarp = Traits::MWarp;
|
||||
static constexpr index_t NWarp = Traits::NWarp;
|
||||
static constexpr index_t KWarp = Traits::KWarp;
|
||||
|
||||
static constexpr auto Scheduler = Traits::Scheduler;
|
||||
static constexpr bool TransposeC = Traits::TransposeC;
|
||||
|
||||
using AWarpDstr = typename WarpGemm::AWarpDstr;
|
||||
using BWarpDstr = typename WarpGemm::BWarpDstr;
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WarpGemm::AWarpTensor;
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
static_assert(std::is_same_v<typename WarpGemm::CDataType, float>);
|
||||
|
||||
static constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
static constexpr auto b_warp_y_lengths =
|
||||
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
static constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
|
||||
// Note: distribution encodings have MIterPerWarp and NIterPerWarp contiguous because of scale
|
||||
// packing.
|
||||
|
||||
// Builds the block-level tile distribution encoding for the A operand: an outer encoding
// over (MWarp, MIterPerWarp) x KIterSeq with replication lengths (2, NWarp / 2), embedded
// with the warp gemm's AWarpDstrEncoding.
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
|
||||
{
|
||||
constexpr index_t KPerThread = Traits::KPerThread;
|
||||
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
|
||||
|
||||
// K handled per MAC cluster, but never less than one warp gemm's per-thread K.
constexpr index_t KPerInnerLoop =
|
||||
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
|
||||
|
||||
constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
|
||||
|
||||
// The Interwave scheduler iterates KIterInterwave times along K; every other scheduler
// iterates KIterPerWarp times.
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
|
||||
sequence<KWarp, KIterInterwave>,
|
||||
sequence<KWarp, KIterPerWarp>>;
|
||||
|
||||
// Template arguments, in order: replication lengths, per-dimension hierarchical lengths,
// P-to-(R/H) major and minor maps, Y-to-(R/H) major and minor maps.
// (NOTE(review): argument meaning per the CK Tile tile_distribution_encoding convention -
// confirm.)
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<2, NWarp / 2>,
|
||||
tuple<sequence<MWarp, MIterPerWarp>, KIterSeq>,
|
||||
tuple<sequence<0, 2, 1, 0>>,
|
||||
tuple<sequence<0, 0, 0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{};
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
return a_block_dstr_encode;
|
||||
}
|
||||
|
||||
// Builds the block-level tile distribution encoding for the B operand: an outer encoding
// over (2, NIterPerWarp, NWarp / 2) x KIterSeq with replication length MWarp, embedded with
// the warp gemm's BWarpDstrEncoding. The outer encoding declares no Y dimensions (both Y
// maps are empty sequences).
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
|
||||
{
|
||||
constexpr index_t KPerThread = Traits::KPerThread;
|
||||
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
|
||||
// K handled per MAC cluster, but never less than one warp gemm's per-thread K.
constexpr index_t KPerInnerLoop =
|
||||
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
|
||||
constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
|
||||
|
||||
// The Interwave scheduler iterates KIterInterwave times along K; every other scheduler
// iterates KIterPerWarp times.
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
|
||||
sequence<KWarp, KIterInterwave>,
|
||||
sequence<KWarp, KIterPerWarp>>;
|
||||
|
||||
constexpr auto b_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<2, NIterPerWarp, NWarp / 2>, KIterSeq>,
|
||||
tuple<sequence<2, 1, 0, 1>>,
|
||||
tuple<sequence<0, 0, 0, 2>>,
|
||||
sequence<>,
|
||||
sequence<>>{};
|
||||
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
|
||||
// Builds the block-level tile distribution encoding for the C accumulator: an outer
// encoding over (MWarp, MIterPerWarp) x (2, NIterPerWarp, NWarp / 2) with replication length
// KWarp, embedded with the warp gemm's CWarpDstrEncoding. operator() static_asserts that the
// C block tensor it receives uses exactly this encoding.
CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode()
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<KWarp>,
|
||||
tuple<sequence<MWarp, MIterPerWarp>, sequence<2, NIterPerWarp, NWarp / 2>>,
|
||||
tuple<sequence<2, 0, 1, 2>>,
|
||||
tuple<sequence<0, 0, 0, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{};
|
||||
constexpr auto c_block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
return c_block_dstr_encoding;
|
||||
}
|
||||
|
||||
// Creates a static distributed tensor of CDataType laid out with the C block distribution
// produced by MakeCBlockDistributionEncode().
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
return make_static_distributed_tensor<CDataType>(
|
||||
make_static_tile_distribution(MakeCBlockDistributionEncode()));
|
||||
}
|
||||
|
||||
// Type of the A operand passed to operator(): one distributed tensor of AComputeDataType
// using the A block distribution.
using ALdsTile = decltype(make_static_distributed_tensor<AComputeDataType>(
|
||||
make_static_tile_distribution(MakeABlockDistributionEncode())));
|
||||
// Type of the B operand passed to operator(): a statically indexed 2-D array of distributed
// tensors of BComputeDataType, outer extent NIterPerWarp and inner extent KIterPerWarp
// (indexed as [nIter][kIter]).
using BLdsTiles = statically_indexed_array<
|
||||
statically_indexed_array<decltype(make_static_distributed_tensor<BComputeDataType>(
|
||||
make_static_tile_distribution(
|
||||
MakeBBlockDistributionEncode()))),
|
||||
KIterPerWarp>,
|
||||
NIterPerWarp>;
|
||||
|
||||
// C += A * B
|
||||
// Accumulates the MX-scaled product of A and B into c_block_tensor, one warp gemm per
// (mIter, nIter, kIter).
//
// c_block_tensor : accumulator; must have CDataType and the distribution returned by
//                  MakeCBlockDistributionEncode() (both enforced by static_assert).
// a_warp_tile_   : A operand (ALdsTile); sliced per (mIter, kIter).
// b_warp_tiles_  : B operand tiles (BLdsTiles); indexed as [nIter][kIter].
// scale_a_tensor : A scales; one element is read per (ikpack, impack) pack group and
//                  bit_cast to int32_t.
// scale_b_tensor : B scales; one element is read per (ikpack, inpack) pack group and
//                  bit_cast to int32_t.
// MXdlPack_ / NXdlPack_ / KXdlPack_ : requested pack factors (default 2); each falls back to
//                  1 when the matching iteration count is smaller than, or not divisible by,
//                  the requested factor.
template <typename CBlockTensor,
|
||||
typename ScaleATensor,
|
||||
typename ScaleBTensor,
|
||||
index_t MXdlPack_ = 2,
|
||||
index_t NXdlPack_ = 2,
|
||||
index_t KXdlPack_ = 2>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ALdsTile& a_warp_tile_,
|
||||
const BLdsTiles& b_warp_tiles_,
|
||||
const ScaleATensor& scale_a_tensor,
|
||||
const ScaleBTensor& scale_b_tensor) const
|
||||
{
|
||||
// checks
|
||||
static_assert(std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"CDataType must be same as CBlockTensor::DataType!");
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(MakeCBlockDistributionEncode())>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"C distribution is wrong!");
|
||||
|
||||
// Effective XdlPack: fall back to 1 when iteration count is insufficient
|
||||
constexpr index_t MXdlPack =
|
||||
(MIterPerWarp >= MXdlPack_ && MIterPerWarp % MXdlPack_ == 0) ? MXdlPack_ : 1;
|
||||
constexpr index_t NXdlPack =
|
||||
(NIterPerWarp >= NXdlPack_ && NIterPerWarp % NXdlPack_ == 0) ? NXdlPack_ : 1;
|
||||
constexpr index_t KXdlPack =
|
||||
(KIterPerWarp >= KXdlPack_ && KIterPerWarp % KXdlPack_ == 0) ? KXdlPack_ : 1;
|
||||
|
||||
// Number of pack groups along each dimension.
constexpr index_t MPackIterPerWarp = MIterPerWarp / MXdlPack;
|
||||
constexpr index_t NPackIterPerWarp = NIterPerWarp / NXdlPack;
|
||||
constexpr index_t KPackIterPerWarp = KIterPerWarp / KXdlPack;
|
||||
|
||||
// hot loop:
|
||||
// Outer loop: every (k, n, m) pack-group index triple. Note that both scales are
// re-read for every triple, including the one that does not depend on the third index.
static_for_product<number<KPackIterPerWarp>,
|
||||
number<NPackIterPerWarp>,
|
||||
number<MPackIterPerWarp>>{}([&](auto ikpack, auto inpack, auto impack) {
|
||||
// get A scale for this M-K tile using get_y_sliced_thread_data
|
||||
auto scale_a_slice = scale_a_tensor.get_y_sliced_thread_data(
|
||||
sequence<ikpack, impack, 0>{}, sequence<1, 1, 1>{});
|
||||
const int32_t a_scale_packed = bit_cast<int32_t>(scale_a_slice[number<0>{}]);
|
||||
|
||||
// get B scale for this N-K tile using get_y_sliced_thread_data
|
||||
auto scale_b_slice = scale_b_tensor.get_y_sliced_thread_data(
|
||||
sequence<ikpack, inpack, 0>{}, sequence<1, 1, 1>{});
|
||||
const int32_t b_scale_packed = bit_cast<int32_t>(scale_b_slice[number<0>{}]);
|
||||
|
||||
// Inner loops: issue MFMAs within the pack group using OpSel
|
||||
static_for_product<number<KXdlPack>, number<NXdlPack>, number<MXdlPack>>{}(
|
||||
[&](auto ikxdl, auto inxdl, auto imxdl) {
|
||||
// Absolute iteration indices = pack-group index * pack size + index within group.
constexpr auto kIter = ikpack * KXdlPack + ikxdl;
|
||||
constexpr auto mIter = impack * MXdlPack + imxdl;
|
||||
constexpr auto nIter = inpack * NXdlPack + inxdl;
|
||||
|
||||
// OpSel for A: selects byte within packed int32_t
|
||||
constexpr index_t kOpSelA = ikxdl * MXdlPack + imxdl;
|
||||
|
||||
// OpSel for B: selects byte within packed int32_t
|
||||
constexpr index_t kOpSelB = ikxdl * NXdlPack + inxdl;
|
||||
|
||||
// read A warp tensor from A Block window
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
// read B warp tensor from B block tensor
|
||||
// (B is taken whole from the [nIter][kIter] tile rather than sliced.)
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() =
|
||||
b_warp_tiles_[number<nIter>{}][number<kIter>{}].get_thread_buffer();
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
using c_iter_idx = sequence<mIter, nIter>;
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM with MX scaling
|
||||
WarpGemm{}.template operator()<OpSelA<kOpSelA>, OpSelB<kOpSelB>>(
|
||||
c_warp_tensor,
|
||||
a_warp_tensor,
|
||||
b_warp_tensor,
|
||||
a_scale_packed,
|
||||
b_scale_packed);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,324 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A is block distributed tensor
|
||||
// B is block distributed tensor
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_,
|
||||
typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy,
|
||||
bool TransposeC_ = false>
|
||||
struct BlockMXGemmARegBRegCRegV1
|
||||
{
|
||||
private:
|
||||
// Compile-time traits derived from the pipeline problem and the block-gemm policy:
// data types, block tile sizes, the warp gemm chosen by the policy, the warp layout
// (MWarp / NWarp) and the per-warp iteration counts along M, N and K.
template <typename PipelineProblem_, typename GemmPolicy_>
|
||||
struct GemmTraits_
|
||||
{
|
||||
using Problem = remove_cvref_t<PipelineProblem_>;
|
||||
using Policy = remove_cvref_t<GemmPolicy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
// Block tile extents taken from the block gemm shape.
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
// The policy returns a tuple-like config: element 0 is the warp gemm, elements 1 and 2
// are the warp counts along M and N.
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
static constexpr index_t MWarp = config.template at<1>();
|
||||
static constexpr index_t NWarp = config.template at<2>();
|
||||
// Number of warp-gemm iterations per warp. The K count divides by the warp gemm K only;
// no warp count along K enters the computation here.
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
static constexpr index_t KPackA = WarpGemm::kAKPack;
|
||||
static constexpr index_t KPackB = WarpGemm::kBKPack;
|
||||
};
|
||||
|
||||
public:
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
static constexpr bool TransposeC = TransposeC_;
|
||||
|
||||
using Traits = GemmTraits_<Problem, Policy>;
|
||||
|
||||
using WarpGemm = typename Traits::WarpGemm;
|
||||
using BlockGemmShape = typename Traits::BlockGemmShape;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Traits::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Traits::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Traits::CDataType>;
|
||||
|
||||
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
|
||||
static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
|
||||
static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;
|
||||
|
||||
static constexpr index_t MWarp = Traits::MWarp;
|
||||
static constexpr index_t NWarp = Traits::NWarp;
|
||||
static constexpr bool UseDefaultScheduler = (Problem::NumWaveGroups != 1);
|
||||
|
||||
// Note: distribution encodings have MIterPerWarp and NIterPerWarp contiguous because of scale
|
||||
// packing.
|
||||
|
||||
// Builds the block-level tile distribution encoding for the A operand by embedding the warp
// gemm's AWarpDstrEncoding into an outer encoding. Two outer layouts exist, selected at
// compile time by UseDefaultScheduler.
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
|
||||
{
|
||||
if constexpr(UseDefaultScheduler)
|
||||
{
|
||||
// M carries only MIterPerWarp (no MWarp level) and no P dimension is mapped
// (both P maps are empty tuples).
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<>,
|
||||
tuple<>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
return a_block_dstr_encode;
|
||||
}
|
||||
else
|
||||
{
|
||||
// M is split as (MWarp, MIterPerWarp); one P dimension is mapped.
constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<NWarp>,
|
||||
tuple<sequence<MWarp, MIterPerWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 0>>{};
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
return a_block_dstr_encode;
|
||||
}
|
||||
}
|
||||
|
||||
// Builds the block-level tile distribution encoding for the B operand by embedding the warp
// gemm's BWarpDstrEncoding into an outer encoding. Two outer layouts exist, selected at
// compile time by UseDefaultScheduler.
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
|
||||
{
|
||||
if constexpr(UseDefaultScheduler)
|
||||
{
|
||||
// N carries only NIterPerWarp (no NWarp level) and no P dimension is mapped
// (both P maps are empty tuples).
constexpr auto b_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<>,
|
||||
tuple<>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
else
|
||||
{
|
||||
// N is split as (NWarp, NIterPerWarp); one P dimension is mapped.
constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<MWarp>,
|
||||
tuple<sequence<NWarp, NIterPerWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 0>>{};
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
}
|
||||
|
||||
// Builds the block-level tile distribution encoding for the C accumulator by embedding the
// warp gemm's CWarpDstrEncoding into an outer encoding. TransposeC swaps the Y major order
// (2, 1 instead of 1, 2); UseDefaultScheduler selects between two outer layouts.
CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode()
|
||||
{
|
||||
using c_distr_ys_major = std::conditional_t<TransposeC, sequence<2, 1>, sequence<1, 2>>;
|
||||
if constexpr(UseDefaultScheduler)
|
||||
{
|
||||
// Here the Y minor order also follows TransposeC, M carries only MIterPerWarp, and
// MWarp appears as the replication length.
using c_distr_ys_minor = std::conditional_t<TransposeC, sequence<1, 0>, sequence<0, 1>>;
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<MWarp>,
|
||||
tuple<sequence<MIterPerWarp>, sequence<NWarp, NIterPerWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
c_distr_ys_major,
|
||||
c_distr_ys_minor>{};
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
|
||||
return c_block_dstr_encode;
|
||||
}
|
||||
else
|
||||
{
|
||||
// No replication dimension; M and N are split as (MWarp, MIterPerWarp) and
// (NWarp, NIterPerWarp), and the Y minor order is fixed to (1, 1).
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MWarp, MIterPerWarp>, sequence<NWarp, NIterPerWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
c_distr_ys_major,
|
||||
sequence<1, 1>>{};
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
|
||||
return c_block_dstr_encode;
|
||||
}
|
||||
}
|
||||
|
||||
// C += A * B with MX scaling and packed-in-two (XdlPack) optimization
|
||||
// Scale tensors contain pre-packed int32_t: each int32_t holds MXdlPack * KXdlPack e8m0_t
|
||||
// values (for A) or NXdlPack * KXdlPack (for B), packed on the host.
|
||||
// Uses OpSel (0-3) to select which byte within the packed int32_t for each MFMA call.
|
||||
// XdlPack template parameters default to 2; fall back to 1 when iteration count is too small.
|
||||
// Parameters:
//   c_block_tensor : accumulator, updated in place; data type and distribution are checked
//                    by static_assert against CDataType / MakeCBlockDistributionEncode().
//   a_block_tensor : A operand; checked against ADataType / MakeABlockDistributionEncode().
//   b_block_tensor : B operand; checked against BDataType / MakeBBlockDistributionEncode().
//   scale_a_tensor : A scales; one element is read per (ikpack, impack) pack group.
//   scale_b_tensor : B scales; one element is read per (ikpack, inpack) pack group.
//   MXdlPack_ / NXdlPack_ / KXdlPack_ : requested pack factors (default 2).
template <typename CBlockTensor,
|
||||
typename ABlockTensor,
|
||||
typename BBlockTensor,
|
||||
typename ScaleATensor,
|
||||
typename ScaleBTensor,
|
||||
index_t MXdlPack_ = 2,
|
||||
index_t NXdlPack_ = 2,
|
||||
index_t KXdlPack_ = 2>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensor& a_block_tensor,
|
||||
const BBlockTensor& b_block_tensor,
|
||||
const ScaleATensor& scale_a_tensor,
|
||||
const ScaleBTensor& scale_b_tensor) const
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"Datatypes do not match BlockTensor datatypes!");
|
||||
|
||||
// check ABC-block-distribution
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(MakeABlockDistributionEncode())>,
|
||||
remove_cvref_t<decltype(ABlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"A distribution is wrong!");
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(MakeBBlockDistributionEncode())>,
|
||||
remove_cvref_t<decltype(BBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"B distribution is wrong!");
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(MakeCBlockDistributionEncode())>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"C distribution is wrong!");
|
||||
|
||||
using AWarpDstr = typename WarpGemm::AWarpDstr;
|
||||
using BWarpDstr = typename WarpGemm::BWarpDstr;
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WarpGemm::AWarpTensor;
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
// Y-dimension lengths of one warp tile; used as the slice lengths when extracting a
// single warp tile from a block tensor.
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto b_warp_y_lengths =
|
||||
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
// All-zero origins for the warp-level Y dimensions of each slice.
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// Effective XdlPack: fall back to 1 when iteration count is insufficient
|
||||
constexpr index_t MXdlPack =
|
||||
(MIterPerWarp >= MXdlPack_ && MIterPerWarp % MXdlPack_ == 0) ? MXdlPack_ : 1;
|
||||
constexpr index_t NXdlPack =
|
||||
(NIterPerWarp >= NXdlPack_ && NIterPerWarp % NXdlPack_ == 0) ? NXdlPack_ : 1;
|
||||
constexpr index_t KXdlPack =
|
||||
(KIterPerWarp >= KXdlPack_ && KIterPerWarp % KXdlPack_ == 0) ? KXdlPack_ : 1;
|
||||
|
||||
// Number of pack groups along each dimension.
constexpr index_t MPackIterPerWarp = MIterPerWarp / MXdlPack;
|
||||
constexpr index_t NPackIterPerWarp = NIterPerWarp / NXdlPack;
|
||||
constexpr index_t KPackIterPerWarp = KIterPerWarp / KXdlPack;
|
||||
|
||||
// hot loop with MX scaling and pre-packed int32_t scales:
|
||||
// Outer loops iterate over pack groups (scale tile indices)
|
||||
static_ford<sequence<KPackIterPerWarp, MPackIterPerWarp>>{}([&](auto ii) {
|
||||
constexpr auto ikpack = number<ii[number<0>{}]>{};
|
||||
constexpr auto impack = number<ii[number<1>{}]>{};
|
||||
// Get pre-packed int32_t A scale (already contains MXdlPack*KXdlPack e8m0_t)
|
||||
auto scale_a_slice = scale_a_tensor.get_y_sliced_thread_data(
|
||||
sequence<ikpack, impack, 0>{}, sequence<1, 1, 1>{});
|
||||
const int32_t a_scale_packed = bit_cast<int32_t>(scale_a_slice[number<0>{}]);
|
||||
|
||||
static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) {
|
||||
// Get pre-packed int32_t B scale
|
||||
auto scale_b_slice = scale_b_tensor.get_y_sliced_thread_data(
|
||||
sequence<ikpack, inpack, 0>{}, sequence<1, 1, 1>{});
|
||||
const int32_t b_scale_packed = bit_cast<int32_t>(scale_b_slice[number<0>{}]);
|
||||
|
||||
// Inner loops: issue MFMAs within the pack group using OpSel
|
||||
static_ford<sequence<KXdlPack, MXdlPack>>{}([&](auto jj) {
|
||||
constexpr auto ikxdl = number<jj[number<0>{}]>{};
|
||||
constexpr auto imxdl = number<jj[number<1>{}]>{};
|
||||
// Absolute iteration indices = pack-group index * pack size + index within group.
constexpr auto kIter = ikpack * KXdlPack + ikxdl;
|
||||
constexpr auto mIter = impack * MXdlPack + imxdl;
|
||||
|
||||
// read A warp tensor from A block tensor
|
||||
// (done once per (kIter, mIter), outside the innermost N loop)
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
// OpSel for A: selects byte within packed int32_t
|
||||
constexpr index_t kOpSelA = ikxdl * MXdlPack + imxdl;
|
||||
|
||||
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
|
||||
constexpr auto nIter = inpack * NXdlPack + inxdl;
|
||||
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// OpSel for B: selects byte within packed int32_t
|
||||
constexpr index_t kOpSelB = ikxdl * NXdlPack + inxdl;
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
// (the C tile index order is swapped when TransposeC is set)
using c_iter_idx = std::conditional_t<TransposeC,
|
||||
sequence<nIter, mIter>,
|
||||
sequence<mIter, nIter>>;
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM with MX scaling using pre-packed scale and OpSel
|
||||
WarpGemm{}.template operator()<OpSelA<kOpSelA>, OpSelB<kOpSelB>>(
|
||||
c_warp_tensor,
|
||||
a_warp_tensor,
|
||||
b_warp_tensor,
|
||||
a_scale_packed,
|
||||
b_scale_packed);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// Creates the C block tile: a static distributed tensor of CDataType whose
// distribution is built from MakeCBlockDistributionEncode().
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
    const auto c_block_dstr = make_static_tile_distribution(MakeCBlockDistributionEncode());
    return make_static_distributed_tensor<CDataType>(c_block_dstr);
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -332,8 +332,7 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
const std::array<const BDataType*, NumBTensor>& bs_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
void* smem_ptr_ping,
|
||||
void* smem_ptr_pong,
|
||||
void* smem_ptr,
|
||||
const KernelArgs<ScaleM, ScaleN>& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t i_m,
|
||||
@@ -363,24 +362,18 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
scale_a_block_window,
|
||||
scale_b_block_window,
|
||||
num_loop,
|
||||
smem_ptr_ping,
|
||||
smem_ptr_pong);
|
||||
smem_ptr);
|
||||
|
||||
// Run Epilogue Pipeline - create C block window directly
|
||||
auto c_block_window = MakeCBlockWindows(e_ptr, kargs, i_m, i_n);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPingSize()
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return max(MXGemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPongSize()
|
||||
{
|
||||
return MXGemmPipeline::GetSmemSize();
|
||||
}
|
||||
|
||||
template <class ScaleM, class ScaleN>
|
||||
CK_TILE_DEVICE void operator()(KernelArgs<ScaleM, ScaleN> kargs,
|
||||
int partition_idx = get_block_id()) const
|
||||
@@ -389,8 +382,7 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
amd_wave_read_first_lane(TilePartitioner::GridSize(kargs.M, kargs.N));
|
||||
|
||||
// Allocate shared memory for ping pong buffers
|
||||
__shared__ char smem_ptr_ping[GetSmemPingSize()];
|
||||
__shared__ char smem_ptr_pong[GetSmemPongSize()];
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
// Support both persistent and non-persistent modes
|
||||
do
|
||||
@@ -423,8 +415,7 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
bs_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_ping,
|
||||
smem_ptr_pong,
|
||||
smem_ptr,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
|
||||
@@ -181,7 +181,8 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
constexpr index_t smem_size = Policy::template GetSmemSize<Problem>();
|
||||
return 2 * smem_size;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
|
||||
@@ -688,9 +689,11 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
const ScaleADramBlockWindowTmp& scale_a_window,
|
||||
const ScaleBDramBlockWindowTmp& scale_b_window,
|
||||
index_t num_loop,
|
||||
void* __restrict__ p_smem_0,
|
||||
void* __restrict__ p_smem_1) const
|
||||
void* __restrict__ p_smem) const
|
||||
{
|
||||
constexpr index_t smem_size = Policy::template GetSmemSize<Problem>();
|
||||
const auto smem = reinterpret_cast<uint8_t*>(p_smem);
|
||||
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
@@ -703,8 +706,8 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
scale_a_window,
|
||||
scale_b_window,
|
||||
num_loop,
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
smem,
|
||||
smem + smem_size);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
@@ -720,9 +723,11 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
const ScaleADramBlockWindowTmp& scale_a_window,
|
||||
const ScaleBDramBlockWindowTmp& scale_b_window,
|
||||
const index_t num_loop,
|
||||
void* __restrict__ p_smem_0,
|
||||
void* __restrict__ p_smem_1) const
|
||||
void* __restrict__ p_smem) const
|
||||
{
|
||||
constexpr index_t smem_size = Policy::template GetSmemSize<Problem>();
|
||||
const auto smem = reinterpret_cast<uint8_t*>(p_smem);
|
||||
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
@@ -735,8 +740,8 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
scale_a_window,
|
||||
scale_b_window,
|
||||
num_loop,
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
smem,
|
||||
smem + smem_size);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_v1.hpp"
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -128,7 +129,7 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
|
||||
return BlockGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
|
||||
return BlockMXGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
// XdlPack: how many e8m0_t scale values are packed into one int32_t per dimension
|
||||
@@ -170,12 +171,12 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp_packed, MWarp, MPerXdl>,
|
||||
tuple<sequence<MWarp, MIterPerWarp_packed, MPerXdl>,
|
||||
sequence<KIterPerWarp_packed, K_Lane, KPerLane>>,
|
||||
tuple<sequence<0, 1>, sequence<2, 1>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>, sequence<1, 2>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<0, 0, 2>>{});
|
||||
sequence<0, 1, 2>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -208,12 +209,12 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp_packed, NWarp, NPerXdl>,
|
||||
tuple<sequence<NWarp, NIterPerWarp_packed, NPerXdl>,
|
||||
sequence<KIterPerWarp_packed, K_Lane, KPerLane>>,
|
||||
tuple<sequence<0, 1>, sequence<2, 1>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>, sequence<1, 2>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<0, 0, 2>>{});
|
||||
sequence<0, 1, 2>>{});
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -0,0 +1,282 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_eight_waves_base.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
|
||||
#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief Compute optimized pipeline version async for 8 waves
|
||||
*
|
||||
* This pipeline introduces asynchronous load from global memory to LDS,
|
||||
* skipping the intermediate loading into pipeline registers.
|
||||
*/
|
||||
template <typename Problem, typename Policy = MXGemmPipelineAgBgCrCompAsyncEightWavesPolicy>
|
||||
struct MXGemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
{
|
||||
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
|
||||
using PipelineImplBase = GemmPipelineAgBgCrEightWavesImplBase<Problem, Policy>;
|
||||
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
|
||||
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
|
||||
|
||||
static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
|
||||
|
||||
static constexpr index_t APackedSize = ck_tile::numeric_traits<ADataType>::PackedSize;
|
||||
static constexpr index_t BPackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
using WarpGemm = typename BlockGemm::WarpGemm;
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(I0);
|
||||
static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(I1);
|
||||
static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(I2);
|
||||
|
||||
static constexpr index_t kflatKPerBlock = BlockGemmShape::flatKPerBlock;
|
||||
|
||||
static constexpr index_t MIterPerWarp = MPerBlock / (MWarps * WarpGemm::kM);
|
||||
static constexpr index_t NIterPerWarp = NPerBlock / (NWarps * WarpGemm::kN);
|
||||
static constexpr index_t KIterPerWarp = KPerBlock / (KWarps * WarpGemm::kK);
|
||||
|
||||
static constexpr bool Async = true;
|
||||
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
return Policy::template GetVectorSizeA<Problem>();
|
||||
}
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeB()
|
||||
{
|
||||
return Policy::template GetVectorSizeB<Problem>();
|
||||
}
|
||||
|
||||
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
|
||||
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
|
||||
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
static constexpr index_t Preshuffle = Problem::Preshuffle;
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
|
||||
{
|
||||
// clang-format off
|
||||
return "COMPUTE_ASYNC_EIGHT_WAVES";
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0);
|
||||
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1);
|
||||
return concat('_', "pipeline_AgBgCrCompAsyncEightWaves",
|
||||
concat('x', MPerBlock, NPerBlock, KPerBlock), BlockSize,
|
||||
concat('x', GetVectorSizeA(), GetVectorSizeB()),
|
||||
concat('x', WaveNumM, WaveNumN),
|
||||
concat('x', kPadM, kPadN, kPadK),
|
||||
Problem::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
static constexpr index_t MFMA_INST = MIterPerWarp * NIterPerWarp * KIterPerWarp;
|
||||
|
||||
// Scales are packed so odd numbers of iterations greater than 1 are not supported
|
||||
static_assert((MIterPerWarp == 1) || (MIterPerWarp % 2 == 0));
|
||||
static_assert((NIterPerWarp == 1) || (NIterPerWarp % 2 == 0));
|
||||
static_assert((KIterPerWarp == 1) || (KIterPerWarp % 2 == 0));
|
||||
|
||||
template <GemmPipelineScheduler Scheduler>
|
||||
struct PipelineImpl : public PipelineImplBase
|
||||
{
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
|
||||
{
|
||||
using Base = PipelineImplBase;
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename ScaleADramBlockWindowTmp,
|
||||
typename ScaleBDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
const ScaleADramBlockWindowTmp& scale_a_window,
|
||||
const ScaleBDramBlockWindowTmp& scale_b_window,
|
||||
index_t num_loop,
|
||||
void* __restrict__ p_smem) const
|
||||
{
|
||||
// TODO: A/B elementwise functions currently not supported
|
||||
ignore = a_element_func;
|
||||
ignore = b_element_func;
|
||||
|
||||
// ------
|
||||
// Checks
|
||||
// ------
|
||||
static_assert(
|
||||
std::is_same_v<ADataType,
|
||||
remove_cvref_t<typename AsDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
remove_cvref_t<typename BsDramBlockWindowTmp::DataType>>,
|
||||
"A/B Dram block window should have the same data type as appropriate "
|
||||
"([A|B]DataType) defined in Problem definition!");
|
||||
|
||||
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>, "Wrong!");
|
||||
static_assert(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>, "Wrong!");
|
||||
|
||||
static_assert((MPerBlock == AsDramBlockWindowTmp{}.get_window_lengths()[I0] &&
|
||||
KPerBlock == AsDramBlockWindowTmp{}.get_window_lengths()[I1]),
|
||||
"A block window has incorrect lengths for defined ALayout!");
|
||||
static_assert(Preshuffle //
|
||||
? (NWarps == BsDramBlockWindowTmp{}.get_window_lengths()[I0] &&
|
||||
kflatKPerBlock == BsDramBlockWindowTmp{}.get_window_lengths()[I1])
|
||||
: (NPerBlock == BsDramBlockWindowTmp{}.get_window_lengths()[I0] &&
|
||||
KPerBlock == BsDramBlockWindowTmp{}.get_window_lengths()[I1]),
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
|
||||
// ------------------
|
||||
// Hot loop scheduler
|
||||
// ------------------
|
||||
auto hot_loop_scheduler = [&]() {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, MIterPerWarp, 0); // MFMA
|
||||
s_waitcnt_lgkm<4>();
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 1, 0); // lgkmcnt / SALU
|
||||
static_for<0, MFMA_INST - MIterPerWarp, 1>{}([&](auto) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
};
|
||||
|
||||
// -------
|
||||
// Compute
|
||||
// -------
|
||||
return Base::template Run_<HasHotLoop, TailNum>(p_smem,
|
||||
num_loop,
|
||||
a_dram_block_window_tmp,
|
||||
b_dram_block_window_tmp,
|
||||
scale_a_window,
|
||||
scale_b_window,
|
||||
hot_loop_scheduler);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename ScaleADramBlockWindowTmp,
|
||||
typename ScaleBDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
const ScaleADramBlockWindowTmp& scale_a_window,
|
||||
const ScaleBDramBlockWindowTmp& scale_b_window,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
scale_a_window,
|
||||
scale_b_window,
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename ScaleADramBlockWindowTmp,
|
||||
typename ScaleBDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const ScaleADramBlockWindowTmp& scale_a_window,
|
||||
const ScaleBDramBlockWindowTmp& scale_b_window,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp,
|
||||
identity{},
|
||||
b_dram_block_window_tmp,
|
||||
identity{},
|
||||
scale_a_window,
|
||||
scale_b_window,
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,203 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_eight_waves_v1.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
namespace detail {
|
||||
|
||||
template <typename Problem>
|
||||
struct MXGemmPipelineAgBgCrCompAsyncEightWavesPolicy
|
||||
{
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
|
||||
// MX scaling configuration: each e8m0 scale covers 32 elements in K
|
||||
static constexpr int BlockScaleSize = 32;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using AComputeDataType = remove_cvref_t<typename Problem::AComputeDataType>;
|
||||
using BComputeDataType = remove_cvref_t<typename Problem::BComputeDataType>;
|
||||
using ComputeDataType = AComputeDataType;
|
||||
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>, "Wrong!");
|
||||
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::ColumnMajor>, "Wrong!");
|
||||
static_assert(is_any_of<AComputeDataType, fp8_t, bf8_t, pk_fp4_t>::value);
|
||||
static_assert(is_any_of<BComputeDataType, fp8_t, bf8_t, pk_fp4_t>::value);
|
||||
static_assert(std::is_same_v<AComputeDataType, BComputeDataType>);
|
||||
static_assert(std::is_same_v<CDataType, float>);
|
||||
|
||||
using BlockGemmShape = typename Problem::BlockGemmShape;
|
||||
using BlockWarps = typename BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename BlockGemmShape::WarpTile;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(I0);
|
||||
static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(I1);
|
||||
static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(I2);
|
||||
static constexpr index_t WarpTileM = WarpTile::at(I0);
|
||||
static constexpr index_t WarpTileN = WarpTile::at(I1);
|
||||
static constexpr index_t WarpTileK = WarpTile::at(I2);
|
||||
static constexpr index_t MWarpTiles = MPerBlock / WarpTileM;
|
||||
static constexpr index_t NWarpTiles = NPerBlock / WarpTileN;
|
||||
static constexpr index_t KWarpTiles = KPerBlock / WarpTileK;
|
||||
|
||||
// XdlPack: how many e8m0_t scale values are packed into one int32_t per dimension
|
||||
// Host packs MXdlPack * KXdlPack e8m0_t into one int32_t for A scales
|
||||
// Host packs NXdlPack * KXdlPack e8m0_t into one int32_t for B scales
|
||||
static constexpr int MXdlPack = 2;
|
||||
static constexpr int NXdlPack = 2;
|
||||
static constexpr int KXdlPack = 2;
|
||||
|
||||
// Compute effective XdlPack sizes (fall back to 1 when iter count < pack)
|
||||
static constexpr index_t MPerXdl = WarpTile::at(I0);
|
||||
static constexpr index_t NPerXdl = WarpTile::at(I1);
|
||||
static constexpr index_t KPerXdl = WarpTile::at(I2);
|
||||
static constexpr index_t MIterPerWarp = MPerBlock / (MWarps * MPerXdl);
|
||||
static constexpr index_t NIterPerWarp = NPerBlock / (NWarps * NPerXdl);
|
||||
static constexpr index_t KIterPerWarp = KPerBlock / KPerXdl;
|
||||
|
||||
static constexpr index_t MXdlPackEff =
|
||||
(MIterPerWarp >= MXdlPack && MIterPerWarp % MXdlPack == 0) ? MXdlPack : 1;
|
||||
static constexpr index_t NXdlPackEff =
|
||||
(NIterPerWarp >= NXdlPack && NIterPerWarp % NXdlPack == 0) ? NXdlPack : 1;
|
||||
static constexpr index_t KXdlPackEff =
|
||||
(KIterPerWarp >= KXdlPack && KIterPerWarp % KXdlPack == 0) ? KXdlPack : 1;
|
||||
|
||||
static constexpr index_t KPerBlockScale = KPerBlock / BlockScaleSize / KXdlPackEff;
|
||||
|
||||
static constexpr index_t KPerWarp = KPerBlock / KWarps;
|
||||
static constexpr index_t NPerWarp = NPerBlock / NWarps;
|
||||
static_assert(NWarps == 2, "NWarps == 2 for ping-pong!");
|
||||
static_assert(KWarpTiles == KWarps, "Wrong!");
|
||||
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
static constexpr index_t warp_num = BlockSize / warp_size;
|
||||
static_assert(warp_size == 64, "Wrong!");
|
||||
static_assert(warp_num * warp_size == BlockSize, "Wrong!");
|
||||
|
||||
static_assert(sizeof(ADataType) == sizeof(BDataType), "Wrong!");
|
||||
static constexpr index_t ElementSize = sizeof(ADataType);
|
||||
static constexpr index_t K2 = Problem::VectorLoadSize / ElementSize; // 16
|
||||
static constexpr index_t K1 = WarpTile::at(I2) / K2; // 8
|
||||
static constexpr index_t K0 = KPerWarp / (K1 * K2);
|
||||
static_assert(K0 * K1 * K2 == KPerWarp, "Wrong!");
|
||||
static_assert(K0 == 1, "Wrong!");
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKStepAQ() { return KPerBlockScale; }
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKStepBQ() { return KPerBlockScale; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetInstCountAQ()
|
||||
{
|
||||
return (MIterPerWarp / MXdlPackEff) * (KIterPerWarp / KXdlPackEff);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetInstCountBQ()
|
||||
{
|
||||
return (NIterPerWarp / NXdlPackEff) * (KIterPerWarp / KXdlPackEff);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeAQBlockDistribution()
|
||||
{
|
||||
constexpr index_t K_Lane = get_warp_size() / WarpTileM;
|
||||
|
||||
constexpr index_t KPerLane = WarpTileK / BlockScaleSize / K_Lane;
|
||||
|
||||
constexpr index_t MIterPerWarp_packed = MIterPerWarp / MXdlPackEff;
|
||||
constexpr index_t KIterPerWarp_packed = KIterPerWarp / KXdlPackEff;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<NWarps>, // repeat over MWarps
|
||||
tuple<sequence<MWarps, MIterPerWarp_packed, WarpTileM>, // M dimension (first)
|
||||
sequence<KIterPerWarp_packed, K_Lane, KPerLane>>, // K dimension (second)
|
||||
tuple<sequence<0, 1>, sequence<2, 1>>, // <MWarps, NWarps>, <K_Lane, WarpTileM>
|
||||
tuple<sequence<0, 0>, sequence<1, 2>>,
|
||||
sequence<2, 1, 2>, // <KIterPerWarp, MIterPerWarp, KPerLane>
|
||||
sequence<0, 1, 2>>{});
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBQBlockDistribution()
|
||||
{
|
||||
constexpr index_t K_Lane = get_warp_size() / WarpTileN;
|
||||
|
||||
constexpr index_t KPerLane = WarpTileK / BlockScaleSize / K_Lane;
|
||||
|
||||
constexpr index_t NIterPerWarp_packed = NIterPerWarp / NXdlPackEff;
|
||||
constexpr index_t KIterPerWarp_packed = KIterPerWarp / KXdlPackEff;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<MWarps>, // repeat over MWarps
|
||||
tuple<sequence<2, NIterPerWarp_packed, NWarps / 2, WarpTileN>, // N dimension
|
||||
// (first)
|
||||
sequence<KIterPerWarp_packed, K_Lane, KPerLane>>, // K dimension (second)
|
||||
tuple<sequence<1, 0, 1>, sequence<2, 1>>, // <MWarps, NWarps>, <K_Lane, MPerXdl>
|
||||
tuple<sequence<0, 0, 2>, sequence<1, 3>>,
|
||||
sequence<2, 1, 2>, // <KIterPerWarp, NIterPerWarp, KPerLane>
|
||||
sequence<0, 1, 2>>{});
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
constexpr auto wg_attr_num_access =
|
||||
(std::is_same_v<ADataType, fp8_t> || std::is_same_v<BDataType, fp8_t>)
|
||||
? WGAttrNumAccessEnum::Double
|
||||
: WGAttrNumAccessEnum::Single;
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<ADataType,
|
||||
BDataType,
|
||||
CDataType, // AccDataType
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC,
|
||||
false,
|
||||
false,
|
||||
wg_attr_num_access>;
|
||||
|
||||
using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
|
||||
return BlockMXGemmARegBRegCRegEightWavesV1<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
struct MXGemmPipelineAgBgCrCompAsyncEightWavesPolicy
|
||||
: public GemmPipelineAgBgCrCompAsyncEightWavesPolicy
|
||||
{
|
||||
|
||||
#define FORWARD_METHOD_(method) \
|
||||
template <typename Problem, typename... Args> \
|
||||
CK_TILE_HOST_DEVICE static constexpr auto method(Args&&... args) \
|
||||
{ \
|
||||
return detail::MXGemmPipelineAgBgCrCompAsyncEightWavesPolicy<Problem>::method( \
|
||||
std::forward<Args>(args)...); \
|
||||
}
|
||||
|
||||
FORWARD_METHOD_(MakeAQBlockDistribution);
|
||||
FORWARD_METHOD_(MakeBQBlockDistribution);
|
||||
FORWARD_METHOD_(GetBlockGemm);
|
||||
FORWARD_METHOD_(GetKStepAQ);
|
||||
FORWARD_METHOD_(GetKStepBQ);
|
||||
FORWARD_METHOD_(GetInstCountAQ);
|
||||
FORWARD_METHOD_(GetInstCountBQ);
|
||||
|
||||
#undef FORWARD_METHOD_
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -134,12 +134,7 @@ struct ABQuantGemmPipelineAgBgCrEightWaves : public BaseGemmPipelineAgBgCrCompV3
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
// We are not storing the original packed type in LDS, so we need to multiply the smem size
|
||||
// by the packed size.
|
||||
constexpr index_t smem_size_a = Policy::template GetSmemSizeA<Problem>() * APackedSize;
|
||||
constexpr index_t smem_size_b = Policy::template GetSmemSizeB<Problem>() * BPackedSize;
|
||||
|
||||
return 2 * (smem_size_a + smem_size_b);
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST static std::string Print() { return "ABQuantGemmPipelineAgBgCrEightWaves\n"; }
|
||||
|
||||
@@ -61,7 +61,7 @@ struct GemmABQuantPipelineAgBgCrAsyncPolicy
|
||||
static constexpr index_t NIterPerWarp = NWarpTiles / NWarps;
|
||||
static constexpr index_t KPerWarp = KPerBlock / KWarps;
|
||||
static constexpr index_t NPerWarp = NPerBlock / NWarps;
|
||||
static_assert(NWarps == 2, "KWarps == 2 for ping-pong!");
|
||||
static_assert(NWarps == 2, "NWarps == 2 for ping-pong!");
|
||||
static_assert(KWarpTiles == KWarps, "Wrong!");
|
||||
|
||||
static constexpr index_t KPerWarpAQ = KPerWarp / Problem::AQuantGroupSize::kK;
|
||||
@@ -87,6 +87,11 @@ struct GemmABQuantPipelineAgBgCrAsyncPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKStepAQ() { return KPerBlockAQ; }
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKStepBQ() { return KPerBlockBQ; }
|
||||
|
||||
// TODO: generalize instruction count calculation
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetInstCountAQ() { return MIterPerWarp; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetInstCountBQ() { return 1; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeAQBlockDistribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
@@ -156,6 +161,8 @@ struct GemmABQuantPipelineAgBgCrAsyncPolicy : public GemmPipelineAgBgCrCompAsync
|
||||
FORWARD_METHOD_(GetBlockGemm);
|
||||
FORWARD_METHOD_(GetKStepAQ);
|
||||
FORWARD_METHOD_(GetKStepBQ);
|
||||
FORWARD_METHOD_(GetInstCountAQ);
|
||||
FORWARD_METHOD_(GetInstCountBQ);
|
||||
|
||||
#undef FORWARD_METHOD_
|
||||
};
|
||||
|
||||
@@ -7,11 +7,8 @@ if(CK_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx95")
|
||||
add_gtest_executable(test_ck_tile_mx_gemm_fp4 test_mx_gemm_fp4.cpp)
|
||||
target_compile_options(test_ck_tile_mx_gemm_fp4 PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_ck_tile_mx_gemm_fp8 test_mx_gemm_fp8.cpp)
|
||||
target_compile_options(test_ck_tile_mx_gemm_fp8 PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS})
|
||||
add_gtest_executable(test_ck_tile_mx_gemm_async test_mx_gemm_async.cpp)
|
||||
target_compile_options(test_ck_tile_mx_gemm_async PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS})
|
||||
else()
|
||||
message(DEBUG "Skipping ck_tile MX GEMM tests for current target")
|
||||
endif()
|
||||
|
||||
33
test/ck_tile/gemm_mx/test_mx_gemm_async.cpp
Normal file
33
test/ck_tile/gemm_mx/test_mx_gemm_async.cpp
Normal file
@@ -0,0 +1,33 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_mx_gemm_config.hpp"
|
||||
#include "test_mx_gemm_util.hpp"
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using F4 = ck_tile::pk_fp4_t;
|
||||
using F8 = ck_tile::fp8_t;
|
||||
using F6 = ck_tile::pk_fp6x16_t;
|
||||
|
||||
// clang-format off
|
||||
// Typed-test matrix. Each tuple is (A type, B type, GEMM config, A layout, B layout, C layout).
// Covers FP4 and FP8 inputs, each with MX_GemmConfig16 and MX_GemmConfigEightWaves.
using MxTypes = ::testing::Types<std::tuple<F4, F4, MX_GemmConfig16, Row, Col, Row>,
|
||||
std::tuple<F4, F4, MX_GemmConfigEightWaves, Row, Col, Row>,
|
||||
std::tuple<F8, F8, MX_GemmConfig16, Row, Col, Row>,
|
||||
std::tuple<F8, F8, MX_GemmConfigEightWaves, Row, Col, Row>>;
|
||||
// clang-format on
|
||||
|
||||
// Typed-test fixture: passes the whole parameter tuple straight through to TestMxGemmUtil.
template <typename TypeParam>
|
||||
class TestMxGemm : public TestMxGemmUtil<TypeParam>
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestMxGemm, MxTypes);
|
||||
|
||||
// Calls the fixture's Run(M, N, K) on three problem sizes.
TYPED_TEST(TestMxGemm, Default)
|
||||
{
|
||||
// No M/N/K padding so we use 128x256x256 as smallest dimensions
|
||||
this->Run(128, 256, 256);
|
||||
this->Run(256, 256, 512);
|
||||
this->Run(1024, 1024, 1024);
|
||||
}
|
||||
@@ -80,16 +80,22 @@ struct MxGemmConfig
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
|
||||
struct MXfp4_GemmConfig16 : MxGemmConfig
|
||||
struct MX_GemmConfig16 : MxGemmConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 64;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
};
|
||||
|
||||
struct MXfp8_GemmConfig16 : MxGemmConfig
|
||||
struct MX_GemmConfigEightWaves : MxGemmConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 64;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 2; // NWarps == 2 for ping-pong!
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128 * N_Warp;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 * K_Warp;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_mx_gemm_config.hpp"
|
||||
#include "test_mx_gemm_util.hpp"
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using MxFp4Types = ::testing::Types<
|
||||
std::tuple<ck_tile::pk_fp4_t, ck_tile::pk_fp4_t, MXfp4_GemmConfig16, Row, Col, Row>>;
|
||||
|
||||
template <typename TypeParam>
|
||||
class TestMxGemmFp4 : public TestMxGemmUtil<std::tuple_element_t<0, TypeParam>,
|
||||
std::tuple_element_t<1, TypeParam>,
|
||||
std::tuple_element_t<2, TypeParam>,
|
||||
std::tuple_element_t<3, TypeParam>,
|
||||
std::tuple_element_t<4, TypeParam>,
|
||||
std::tuple_element_t<5, TypeParam>>
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestMxGemmFp4, MxFp4Types);
|
||||
|
||||
TYPED_TEST(TestMxGemmFp4, BasicSizes)
|
||||
{
|
||||
this->Run(64, 64, 256);
|
||||
this->Run(128, 128, 256);
|
||||
this->Run(64, 128, 512);
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_mx_gemm_config.hpp"
|
||||
#include "test_mx_gemm_util.hpp"
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using MxFp8Types =
|
||||
::testing::Types<std::tuple<ck_tile::fp8_t, ck_tile::fp8_t, MXfp8_GemmConfig16, Row, Col, Row>>;
|
||||
|
||||
template <typename TypeParam>
|
||||
class TestMxGemmFp8 : public TestMxGemmUtil<std::tuple_element_t<0, TypeParam>,
|
||||
std::tuple_element_t<1, TypeParam>,
|
||||
std::tuple_element_t<2, TypeParam>,
|
||||
std::tuple_element_t<3, TypeParam>,
|
||||
std::tuple_element_t<4, TypeParam>,
|
||||
std::tuple_element_t<5, TypeParam>>
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestMxGemmFp8, MxFp8Types);
|
||||
|
||||
TYPED_TEST(TestMxGemmFp8, BasicSizes)
|
||||
{
|
||||
this->Run(64, 64, 256);
|
||||
this->Run(128, 128, 256);
|
||||
this->Run(64, 128, 512);
|
||||
}
|
||||
@@ -4,8 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp"
|
||||
#include "ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm_mx.hpp"
|
||||
#include "test_mx_gemm_config.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
@@ -48,7 +47,12 @@ float mx_gemm_calc(const MXGemmHostArgs<ScaleM, ScaleN>& args, const ck_tile::st
|
||||
MXGemmTraits,
|
||||
GemmConfig::Scheduler>;
|
||||
|
||||
using MXGemmPipeline = ck_tile::MXGemmPipelineAgBgCrCompAsync<MXPipelineProblem>;
|
||||
constexpr bool IsEightWave =
|
||||
(GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp) == 8;
|
||||
using MXGemmPipeline =
|
||||
std::conditional_t<IsEightWave,
|
||||
ck_tile::MXGemmPipelineAgBgCrCompAsyncEightWaves<MXPipelineProblem>,
|
||||
ck_tile::MXGemmPipelineAgBgCrCompAsync<MXPipelineProblem>>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
@@ -71,7 +75,15 @@ float mx_gemm_calc(const MXGemmHostArgs<ScaleM, ScaleN>& args, const ck_tile::st
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
MXPipelineProblem::TransposeC>>;
|
||||
MXPipelineProblem::TransposeC,
|
||||
1, // kNumWaveGroups_ (Default)
|
||||
false, // FixedVectorSize_ (Default)
|
||||
1, // VectorSizeC_ (Default)
|
||||
1, // BlockedXDLN_PerWarp_ (Default)
|
||||
false, // DoubleSmemBuffer_ (Default)
|
||||
ADataType, // AComputeDataType
|
||||
BDataType, // BComputeDataType
|
||||
true>>; // TilesPacked_ (because of packed scales)
|
||||
|
||||
using Kernel = ck_tile::MXGemmKernel<TilePartitioner, MXGemmPipeline, GemmEpilogue>;
|
||||
|
||||
|
||||
@@ -30,15 +30,17 @@ auto calculate_rtol_atol_mx(ck_tile::index_t K, float max_accumulated_value)
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
template <typename Tuple>
|
||||
class TestMxGemmUtil : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ADataType = std::tuple_element_t<0, Tuple>;
|
||||
using BDataType = std::tuple_element_t<1, Tuple>;
|
||||
using GemmConfig = std::tuple_element_t<2, Tuple>;
|
||||
using ALayout = std::tuple_element_t<3, Tuple>;
|
||||
using BLayout = std::tuple_element_t<4, Tuple>;
|
||||
using CLayout = std::tuple_element_t<5, Tuple>;
|
||||
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::fp16_t;
|
||||
using ScaleType = ck_tile::e8m0_t;
|
||||
@@ -94,7 +96,7 @@ class TestMxGemmUtil : public ::testing::Test
|
||||
return packed;
|
||||
}
|
||||
|
||||
void Run(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K, int seed = 1234)
|
||||
void Run(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
|
||||
{
|
||||
const ck_tile::index_t scale_k_size = K / 32;
|
||||
const ck_tile::index_t stride_A =
|
||||
@@ -119,10 +121,23 @@ class TestMxGemmUtil : public ::testing::Test
|
||||
ck_tile::HostTensor<ScaleType> scale_b_host(ck_tile::host_tensor_descriptor(
|
||||
scale_k_size, N, stride_scale_b, is_row_major(BLayout{})));
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.f, 2.f, seed++}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.f, 2.f, seed++}(b_host);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{0.001f, 10.f, seed++}(scale_a_host);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{0.001f, 10.f, seed++}(scale_b_host);
|
||||
std::mt19937 gen(42);
|
||||
std::uniform_int_distribution<std::uint32_t> fill_seed(0, 500);
|
||||
|
||||
auto gen_scales = [&](auto& scales, float range_min, float range_max) {
|
||||
// e8m0_t is essentially the exponent of a float32, so draw integer exponents and store 2^exponent
|
||||
ck_tile::HostTensor<float> pow2(scales.get_lengths());
|
||||
ck_tile::FillUniformDistributionIntegerValue<float>{
|
||||
range_min, range_max, fill_seed(gen)}(pow2);
|
||||
scales.ForEach([&](auto& self, const auto& i) {
|
||||
self(i) = static_cast<ScaleType>(std::exp2(pow2(i)));
|
||||
});
|
||||
};
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.f, 2.f, fill_seed(gen)}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.f, 2.f, fill_seed(gen)}(b_host);
|
||||
gen_scales(scale_a_host, -2, 2);
|
||||
gen_scales(scale_b_host, -2, 2);
|
||||
|
||||
// Compute effective XdlPack sizes based on GemmConfig tile dimensions
|
||||
constexpr ck_tile::index_t MPerXdl = GemmConfig::M_Warp_Tile;
|
||||
|
||||
@@ -20,4 +20,3 @@ TYPED_TEST(TestCkTileMxGroupedGemm, Basic)
|
||||
|
||||
this->Run(Ms, Ns, Ks, kbatch, group_count);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user