[CK_Tile] Support for a4w4 (fp4) in block scale gemm AB quant (#3603)

* chore: split block scale example instances in more separate files to speed up compile times

* wip: fp4 scaffolding for abquant

* feat: add fp4 decoding-while-loading to abquant pipeline

* feat: add support for fp4 CPU verification in abquant

* chore: add time tracking to reference calculation

* feat: add a4w4 test for blockscale gemm

* feat: optimize reference calculation by preconverting values to AccType

* feat: add fp4 to fp8 look-up table

* fix: reference to wrong ComputeDataType field in QuantProblem

* feat: type utilities for determining MFMA compute types

* feat: packed fp4 for abquant weight preshuffle

* feat: add separate tests for a4w4 base case, padding and preshuffleB

* fix: fp4 conversion on gfx950 attempting to use non-supported method

* fix: test case was using quant group sizes which don't work on gfx950 due to larger mfma tile size

* chore: add fp4 preshuffleb mode to block scale example

* chore: sanity check for packed types being 1 byte

* chore: clarify tensor dimension indices with constants

* chore: replace traits check with specialized check for packed types

* style: some minor refactoring and cleanup

* fix: correct conversion table for FNUZ fp8

* chore: add fp4 instances to main abquant instances again

* chore: use same initialization branch for int4 and fp4

* chore: add missing initialization for fp4 in block scale gemm example

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>

[ROCm/composable_kernel commit: 6a6177a246]
This commit is contained in:
Erwin Terpstra
2026-01-30 12:40:50 +01:00
committed by GitHub
parent 1b1dd65b83
commit a5824466fb
28 changed files with 642 additions and 175 deletions

View File

@@ -164,5 +164,35 @@ static auto _ = []() {
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"fp4", "abquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_fp4_t,
ck_tile::pk_fp4_t,
ck_tile::half_t,
float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_fp4_raw_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"fp4", "abquant", "preshuffleb", "non-preshufflequant", "1x128x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_fp4_t,
ck_tile::pk_fp4_t,
ck_tile::half_t,
float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::pk_fp4_raw_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
return 0;
}();

View File

@@ -32,7 +32,7 @@ auto create_args(int argc, char* argv[])
.insert("prec",
"fp8",
"Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, "
"or bf8i4; for ABQuant: fp8, bf8")
"or bf8i4; for ABQuant: fp8, bf8, fp4")
.insert("warmup", "50", "Number of iterations before benchmarking the kernel")
.insert("repeat", "1000", "Number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")

View File

@@ -9,6 +9,7 @@
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>
#include "ck_tile/core/config.hpp"
#include "ck_tile/ops/common/utils.hpp"
@@ -35,10 +36,9 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr bool transpose_c =
GemmConfig::TransposeC; // QuantMode == ck_tile::QuantType::ABQuantGrouped;
using ComputeDataType = std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant,
typename TypeConfig::BDataType,
typename TypeConfig::ADataType>;
// Use automatically determined compute type from
using ComputeDataType = void;
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
@@ -80,7 +80,10 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>>>>;
std::conditional_t<
QuantMode == ck_tile::QuantType::ABQuantGrouped,
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>>>>>;
const ck_tile::index_t K_split = ck_tile::integer_least_multiple(args.K, GemmConfig::K_Tile);
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
@@ -182,30 +185,28 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
printf(
"TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, BQuantGroupSize::kN);
}
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
typename TypeConfig::ADataType,
std::conditional_t<
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
typename TypeConfig::ADataType,
typename TypeConfig::BDataType>,
ck_tile::tuple<>,
typename TypeConfig::AccDataType,
typename TypeConfig::CDataType,
ck_tile::tuple<>,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
transpose_c,
1,
false,
1,
TiledPermuteN>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<typename PipelineProblem::ComputeDataType,
typename PipelineProblem::ComputeDataType,
ck_tile::tuple<>,
typename TypeConfig::AccDataType,
typename TypeConfig::CDataType,
ck_tile::tuple<>,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
transpose_c,
1,
false,
1,
TiledPermuteN>>;
using Kernel =
ck_tile::QuantGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, QuantMode>;
@@ -557,8 +558,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
{
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
b_k_n);
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
}
@@ -594,18 +594,26 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
}
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
{
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t> ||
std::is_same_v<ADataType, ck_tile::pk_fp4_t>)
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
a_m_k);
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
b_k_n);
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
}
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t> ||
std::is_same_v<BDataType, ck_tile::pk_fp4_t>)
{
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
}
else
{
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
}
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*aq_tensor_ptr);
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
@@ -723,12 +731,11 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
}
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
{
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t> ||
std::is_same_v<ADataType, ck_tile::pk_fp4_t>)
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
a_m_k);
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
b_k_n);
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
}
else
{
@@ -804,12 +811,11 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
}
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
{
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t> ||
std::is_same_v<ADataType, ck_tile::pk_fp4_t>)
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
a_m_k);
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
b_k_n);
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
}
else
{
@@ -984,10 +990,14 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
if(arg_parser.get_int("v") == 1)
{
std::cout << "Performing CPU verification..." << std::endl;
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
// Track start time for reference operation
auto start_reference_tick = std::chrono::high_resolution_clock::now();
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
ck_tile::reference_gemm_quant<ADataType,
@@ -1051,6 +1061,9 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
a_m_k, *aq_tensor_ptr, b_k_n, *bq_tensor_ptr, c_m_n_host_ref);
}
// Track where we stop reference calculation, and start verification
auto start_verification_tick = std::chrono::high_resolution_clock::now();
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
@@ -1061,6 +1074,9 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
// "Stop" our timer
auto verification_finished_tick = std::chrono::high_resolution_clock::now();
if(!pass)
{
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
@@ -1068,6 +1084,21 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
<< std::endl;
}
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
// Calculate and display reference timing
using DurationType = std::chrono::duration<double>;
double reference_sec = std::chrono::duration_cast<DurationType>(verification_finished_tick -
start_reference_tick)
.count();
double verification_sec = std::chrono::duration_cast<DurationType>(
verification_finished_tick - start_verification_tick)
.count();
float reference_msec = static_cast<float>(reference_sec * 1e3);
float verification_msec = static_cast<float>(verification_sec * 1e3);
std::cout << std::fixed << std::setprecision(1) << "CPU reference GEMM took "
<< reference_msec << "ms, verification took " << verification_msec << "ms."
<< std::endl;
}
else if(arg_parser.get_int("v") == 2)
{
@@ -1098,6 +1129,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
}
if constexpr(std::is_same_v<typename TypeConfig::ADataType, ck_tile::pk_int4_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::pk_fp4_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::fp8_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf8_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf16_t>)

View File

@@ -91,6 +91,7 @@
#include "ck_tile/core/utility/ignore.hpp"
#include "ck_tile/core/utility/literals.hpp"
#include "ck_tile/core/utility/magic_div.hpp"
#include "ck_tile/core/utility/mixed_prec_compute_type.hpp"
#include "ck_tile/core/utility/persistent_async_input_scheduler.hpp"
#include "ck_tile/core/utility/philox_rand.hpp"
#include "ck_tile/core/utility/print.hpp"

View File

@@ -1544,7 +1544,8 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
(std::is_same<T, pk_fp4_raw_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, pk_fp4_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
(std::is_same<T, pk_fp4_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)),
"wrong! not implemented");
using rtn_type = thread_buffer<T, N>;

View File

@@ -1414,7 +1414,7 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
(std::is_same<T, pk_int4_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) ||
(std::is_same<T, pk_fp4_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16))),
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32))),
"wrong! not implemented");
using rtn_type = thread_buffer<T, N>;

View File

@@ -6,6 +6,7 @@
#include <cmath>
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/mxfp_convert.hpp"
#if defined(__gfx950__)
@@ -23,6 +24,12 @@ using fp32x2_t = float __attribute__((ext_vector_type(2)));
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
#if CK_TILE_USE_CUSTOM_DATA_TYPE
using fp8x2_t = fp8_raw_t __attribute__((ext_vector_type(2)));
#else
using fp8x2_t = fp8_t __attribute__((ext_vector_type(2)));
#endif
// Helpers: constexpr-safe access to elements of ext_vector_type(2)
// Some compilers don't allow operator[] in constant expressions for vector types.
// We use bit_cast to a trivially copyable representation to extract lanes.
@@ -98,6 +105,8 @@ struct pk_float4_e2m1_t
CK_TILE_HOST_DEVICE constexpr fp16x2_t to_fp16x2(float scale = 1.f) const;
CK_TILE_HOST_DEVICE constexpr bf16_t to_bf16(float scale = 1.f) const;
CK_TILE_HOST_DEVICE constexpr bf16x2_t to_bf16x2(float scale = 1.f) const;
CK_TILE_HOST_DEVICE constexpr fp8_t to_fp8(float scale = 1.f) const;
CK_TILE_HOST_DEVICE constexpr fp8x2_t to_fp8x2(float scale = 1.f) const;
CK_TILE_HOST_DEVICE constexpr operator float() const { return to_float(); }
CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const { return to_fp32x2(); }
@@ -105,6 +114,8 @@ struct pk_float4_e2m1_t
CK_TILE_HOST_DEVICE constexpr operator fp16x2_t() const { return to_fp16x2(); }
CK_TILE_HOST_DEVICE constexpr operator bf16_t() const { return to_bf16(); }
CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const { return to_bf16x2(); }
CK_TILE_HOST_DEVICE constexpr operator fp8_t() const { return to_fp8(); }
CK_TILE_HOST_DEVICE constexpr operator fp8x2_t() const { return to_fp8x2(); }
template <index_t I>
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t unpack(number<I>) const
@@ -145,6 +156,49 @@ struct pk_float4_e2m1_t
bit_cast<fp16_t>(static_cast<uint16_t>(0xC400)), // -4
bit_cast<fp16_t>(static_cast<uint16_t>(0xC600)) // -6
};
#if CK_TILE_USE_OCP_FP8
// FP8 EM4E3 (OCP) representation
static constexpr fp8_t e2m1_to_fp8_table[16] = {
fp8_t(static_cast<uint8_t>(0x00)), // 0
fp8_t(static_cast<uint8_t>(0x30)), // 0.5
fp8_t(static_cast<uint8_t>(0x38)), // 1
fp8_t(static_cast<uint8_t>(0x3C)), // 1.5
fp8_t(static_cast<uint8_t>(0x40)), // 2
fp8_t(static_cast<uint8_t>(0x44)), // 3
fp8_t(static_cast<uint8_t>(0x48)), // 4
fp8_t(static_cast<uint8_t>(0x4C)), // 6
fp8_t(static_cast<uint8_t>(0x00)), // -0
fp8_t(static_cast<uint8_t>(0xB0)), // -0.5
fp8_t(static_cast<uint8_t>(0xB8)), // -1
fp8_t(static_cast<uint8_t>(0xBC)), // -1.5
fp8_t(static_cast<uint8_t>(0xC0)), // -2
fp8_t(static_cast<uint8_t>(0xC4)), // -3
fp8_t(static_cast<uint8_t>(0xC8)), // -4
fp8_t(static_cast<uint8_t>(0xCC)) // -6
};
#else // CK_TILE_USE_FNUZ_FP8
// FP8 E4M3 FNUZ (bias 8; no negative zero, 0x80 is NaN, so -0 maps to 0x00)
static constexpr fp8_t e2m1_to_fp8_table[16] = {
    fp8_t(static_cast<uint8_t>(0x00)), // 0
    fp8_t(static_cast<uint8_t>(0x38)), // 0.5
    fp8_t(static_cast<uint8_t>(0x40)), // 1
    fp8_t(static_cast<uint8_t>(0x44)), // 1.5
    fp8_t(static_cast<uint8_t>(0x48)), // 2
    fp8_t(static_cast<uint8_t>(0x4C)), // 3
    fp8_t(static_cast<uint8_t>(0x50)), // 4
    fp8_t(static_cast<uint8_t>(0x54)), // 6
    fp8_t(static_cast<uint8_t>(0x00)), // -0
    fp8_t(static_cast<uint8_t>(0xB8)), // -0.5
    fp8_t(static_cast<uint8_t>(0xC0)), // -1
    fp8_t(static_cast<uint8_t>(0xC4)), // -1.5
    fp8_t(static_cast<uint8_t>(0xC8)), // -2 (0xC4 would decode to -1.5)
    fp8_t(static_cast<uint8_t>(0xCC)), // -3
    fp8_t(static_cast<uint8_t>(0xD0)), // -4
    fp8_t(static_cast<uint8_t>(0xD4))  // -6
};
#endif
#endif
};
@@ -408,6 +462,27 @@ CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const
type_convert<fp16_t>(convert_to_float<pk_fp4_t>(_unpack(number<1>{}), scale))};
#endif
}
// Converts the LOW nibble of the packed pair to fp8, scaled by `scale`.
// Goes through fp32 since no direct fp4->fp8 hardware conversion exists.
CK_TILE_HOST_DEVICE constexpr fp8_t pk_fp4_t::to_fp8(float scale) const
{
    // NOTE: No specialized fp4 to fp8 instructions are available. Unsure whether fp4 to fp16 to fp8
    // would be better than the naive implementation below
    // #if CK_TILE_FP4_CVT_DEVICE
    //    return impl::_from_f4<fp8_t>(data, scale);
    // #else
    return fp8_t{type_convert<fp8_t>(convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale))};
    // #endif
}
// Converts both packed nibbles to a 2-lane fp8 vector, each scaled by `scale`.
// Goes through fp32 per lane since no direct fp4->fp8 hardware conversion exists.
CK_TILE_HOST_DEVICE constexpr fp8x2_t pk_fp4_t::to_fp8x2(float scale) const
{
    // NOTE: No specialized fp4 to fp8 instructions are available. Unsure whether fp4 to fp16 to fp8
    // would be better than the naive implementation below
    // #if CK_TILE_FP4_CVT_DEVICE
    //    return impl::_from_f4<fp8x2_t>(data, scale);
    // #else
    return fp8x2_t{type_convert<fp8_t>(convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale)),
                   type_convert<fp8_t>(convert_to_float<pk_fp4_t>(_unpack(number<1>{}), scale))};
    // #endif
}
#else
CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const
{
@@ -415,7 +490,8 @@ CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const
}
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const
{
return fp32x2_t{e2m1_to_fp32_table[_unpack(number<0>{})] * scale, e2m1_to_fp32_table[_unpack(number<1>{}] * scale};
return fp32x2_t{e2m1_to_fp32_table[_unpack(number<0>{})] * scale,
e2m1_to_fp32_table[_unpack(number<1>{})] * scale};
}
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const
{
@@ -428,6 +504,16 @@ CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const
type_convert<fp16_t>(type_convert<float>(e2m1_to_fp16_table[_unpack(number<1>{})]) *
scale)};
}
// Converts the LOW nibble of the packed pair to fp8, scaled by `scale`,
// using the e2m1 look-up table (host/table fallback path).
CK_TILE_HOST_DEVICE constexpr fp8_t pk_fp4_t::to_fp8(float scale) const
{
    // Convert explicitly via type_convert, matching the sibling to_fp16/to_fp8x2
    // implementations instead of relying on an implicit float -> fp8_t conversion.
    return type_convert<fp8_t>(type_convert<float>(e2m1_to_fp8_table[_unpack(number<0>{})]) *
                               scale);
}
// Converts both packed nibbles to a 2-lane fp8 vector via the e2m1 look-up
// table, applying the scale in fp32 before converting each lane back to fp8.
CK_TILE_HOST_DEVICE constexpr fp8x2_t pk_fp4_t::to_fp8x2(float scale) const
{
    return fp8x2_t{
        type_convert<fp8_t>(type_convert<float>(e2m1_to_fp8_table[_unpack(number<0>{})]) * scale),
        type_convert<fp8_t>(type_convert<float>(e2m1_to_fp8_table[_unpack(number<1>{})]) * scale)};
}
#endif
} // namespace ck_tile

View File

@@ -6,6 +6,7 @@
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/numeric/pk_fp4.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/random.hpp"
#include <stdint.h>
@@ -23,6 +24,11 @@ struct pk_int4_t
type data;
CK_TILE_HOST_DEVICE constexpr pk_int4_t() : data{type{}} {}
CK_TILE_HOST_DEVICE constexpr pk_int4_t(type init) : data{init} {}
// NOTE: added for interface compatibility with pk_fp4_t
// Other data types could be added for greater similarity
CK_TILE_HOST_DEVICE constexpr fp32x2_t to_fp32x2() const;
CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const { return to_fp32x2(); }
};
// limits
@@ -186,4 +192,9 @@ CK_TILE_HOST_DEVICE int8x2_t pk_int4_t_to_int8x2_t(const pk_int4_t& x)
return res;
}
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_int4_t::to_fp32x2() const
{
return pk_int4_t_to_fp32x2_t(*this);
}
} // namespace ck_tile

View File

@@ -11,6 +11,7 @@
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/numeric/pk_fp4.hpp"
#include "ck_tile/core/numeric/e8m0.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

View File

@@ -0,0 +1,54 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include <type_traits>
namespace ck_tile {
namespace detail {
// Helper metafunction to automatically determine the compute type.
// Selects the larger of the two input types; if the larger type is a packed
// (sub-byte) type, defaults to fp8.
template <typename ADataType, typename BDataType>
struct auto_compute_type
{
    using LargestInputType = largest_type_t<ADataType, BDataType>;

    // Sanity check: there are no packed types larger than 1 byte yet, but if we add them
    // this logic should change
    static_assert(!is_packed_type_v<LargestInputType> || sizeof(LargestInputType) == sizeof(fp8_t));

    using type = std::conditional_t<is_packed_type_v<LargestInputType>, fp8_t, LargestInputType>;
};
// Helper metafunction to determine the compute type. An explicitly passed-in
// ComputeDataType wins; pass void to fall back to automatic selection from the
// A/B input types.
template <typename ComputeDataType, typename ADataType, typename BDataType>
struct mixed_prec_compute_type
{
    using type = std::conditional_t<std::is_void_v<ComputeDataType>,
                                    typename auto_compute_type<ADataType, BDataType>::type,
                                    ComputeDataType>;
};
} // namespace detail
// Convenience alias for detail::mixed_prec_compute_type: explicit
// ComputeDataType, or automatic selection when ComputeDataType is void.
template <typename ComputeDataType, typename ADataType, typename BDataType>
using mixed_prec_compute_type_t =
    typename detail::mixed_prec_compute_type<ComputeDataType, ADataType, BDataType>::type;

// Helper alias to determine the compute type, defaulting to the input data type.
// If "ThisDataType" is packed (4-bit), will default to "OtherDataType". If both are packed,
// ComputeDataType is used.
template <typename ThisDataType, typename OtherDataType, typename ComputeDataType>
using mixed_prec_compute_type_from_input_t = std::conditional_t<
    is_packed_type_v<ThisDataType>,
    std::conditional_t<is_packed_type_v<OtherDataType>, ComputeDataType, OtherDataType>,
    ThisDataType>;
} // namespace ck_tile

View File

@@ -4,6 +4,8 @@
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include <tuple>
#include <type_traits>
#include <stdint.h>
@@ -187,4 +189,19 @@ template <typename Tuple_, std::size_t Idx, typename DefaultType>
using tuple_element_or_default_t =
typename tuple_element_or_default<Tuple_, Idx, DefaultType>::type;
// Helper trait to determine if a type is packed (more than 1 element per byte),
// e.g. pk_int4_t / pk_fp4_t, which store two 4-bit values per byte.
template <typename T>
struct is_packed_type
{
    static constexpr bool value = numeric_traits<T>::PackedSize > 1;
};

template <typename T>
static constexpr bool is_packed_type_v = is_packed_type<T>::value;

// Helper alias yielding the type with the larger size of the two; ties resolve
// to ADataType.
template <typename ADataType, typename BDataType>
using largest_type_t =
    std::conditional_t<sizeof(ADataType) >= sizeof(BDataType), ADataType, BDataType>;
} // namespace ck_tile

View File

@@ -137,47 +137,55 @@ CK_TILE_HOST void reference_gemm_abquant(const HostTensor<ADataType>& a_m_k,
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
const std::size_t M = a_m_k.get_length(0);
const std::size_t N = b_k_n.get_length(1);
const std::size_t K = a_m_k.get_length(1);
constexpr auto A_TENSOR_M_DIM = 0;
constexpr auto A_TENSOR_K_DIM = 1;
constexpr auto B_TENSOR_K_DIM = 0;
constexpr auto B_TENSOR_N_DIM = 1;
const std::size_t M = a_m_k.get_length(A_TENSOR_M_DIM);
const std::size_t N = b_k_n.get_length(B_TENSOR_N_DIM);
const std::size_t K = a_m_k.get_length(A_TENSOR_K_DIM);
// Pre-convert A/B tensors to AccData type
// This prevents doing slow reconversions for each row/column
HostTensor<AccDataType> a_acc(a_m_k.mDesc);
HostTensor<AccDataType> b_acc(b_k_n.mDesc);
a_acc.ForEach([&](auto& self, auto index) {
if constexpr(std::is_same_v<ADataType, pk_int4_t> || std::is_same_v<ADataType, pk_fp4_t>)
{
const ADataType pk_val = a_element_op(a_m_k(index));
const fp32x2_t fp32_val = pk_val.to_fp32x2();
self(index) = (index[A_TENSOR_K_DIM] & 1) ? fp32_val.hi : fp32_val.lo;
}
else
{
self(index) = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(index)));
}
});
b_acc.ForEach([&](auto& self, auto index) {
if constexpr(std::is_same_v<BDataType, pk_int4_t> || std::is_same_v<BDataType, pk_fp4_t>)
{
const BDataType pk_val = b_element_op(b_k_n(index));
const fp32x2_t fp32_val = pk_val.to_fp32x2();
self(index) = (index[B_TENSOR_K_DIM] & 1) ? fp32_val.hi : fp32_val.lo;
}
else if constexpr(std::is_same_v<BDataType, fp8_t>)
{
self(index) = fp8_to_float_raw(b_element_op(b_k_n(index)));
}
else
{
self(index) = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(index)));
}
});
auto f_mn = [&](auto m, auto n) {
AccDataType v_acc = 0;
constexpr std::size_t kGroupK = BQuantGroupSize::kK;
// ---- A loader: dequant A(m,k) into AccDataType ----
auto load_a = [&](std::size_t k) -> AccDataType {
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
return (k & 1) ? fp32_val.hi : fp32_val.lo;
}
else
{
return ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
}
};
// ---- B loader: dequant B(k,n) into AccDataType ----
auto load_b = [&](std::size_t k) -> AccDataType {
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
return (k & 1) ? fp32_val.hi : fp32_val.lo;
}
else if constexpr(std::is_same_v<BDataType, fp8_t>)
{
return fp8_to_float_raw(b_element_op(b_k_n(k, n)));
}
else
{
return ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
}
};
// ---- a scale loader for a given K-group index ----
auto load_scale_a = [&](ck_tile::index_t k_group) -> float {
const ck_tile::index_t outer_dim = m / AQuantGroupSize::kM;
@@ -224,8 +232,8 @@ CK_TILE_HOST void reference_gemm_abquant(const HostTensor<ADataType>& a_m_k,
// unscaled accumulation within this K-group
for(std::size_t k = k_begin; k < k_end; ++k)
{
const AccDataType v_a = load_a(k);
const AccDataType v_b = load_b(k);
const AccDataType v_a = a_acc(m, k);
const AccDataType v_b = b_acc(k, n);
v_block_acc += v_a * v_b;
}

View File

@@ -4,11 +4,12 @@
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
namespace ck_tile {
template <typename DstDataType, index_t UnaryOpSize>
template <typename SrcDataType, typename DstDataType, index_t UnaryOpSize>
struct InterleavedPKTypeLoader
{
template <typename WarpWindow, typename WarpTile>
@@ -21,10 +22,15 @@ struct InterleavedPKTypeLoader
constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize;
const auto in_dstr_tensors = load_tile(warp_window);
using DstVectorType = DstDataType __attribute__((ext_vector_type(UnaryOpSize)));
// NOTE: we rely on types packing neatly here
using RawSrcType = typename SrcDataType::type;
constexpr auto PackedSize = numeric_traits<SrcDataType>::PackedSize;
using SrcVectorType = ext_vector_t<RawSrcType, UnaryOpSize / PackedSize>;
using DstVectorType = ext_vector_t<DstDataType, UnaryOpSize>;
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
elementwise_op(warp_tile.get_thread_buffer().template get_as<DstVectorType>()(i),
in_dstr_tensors.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
in_dstr_tensors.get_thread_buffer().template get_as<SrcVectorType>()[i]);
});
}
};
@@ -37,10 +43,11 @@ template <typename SrcDataType,
typename WarpWindow>
CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src)
{
if constexpr(std::is_same_v<SrcDataType, pk_int4_t>)
if constexpr(is_packed_type_v<SrcDataType>)
{
static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t");
InterleavedPKTypeLoader<DstDataType, UnaryOpSize>::load_interleaved_pk_type(dst, src);
static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t or pk_fp4_t");
InterleavedPKTypeLoader<SrcDataType, DstDataType, UnaryOpSize>::load_interleaved_pk_type(
dst, src);
}
else if constexpr(LoadTranspose)
{

View File

@@ -397,6 +397,29 @@ struct PassThroughPack8
y.hi = i4_to_bf8x4(bit_cast<int>(x) >> 8);
#endif
}
// Unpacks four packed-fp4 bytes (8 fp4 values) into eight fp8 values.
// Each pk_fp4_t holds two nibbles; to_fp8x2() decodes them (via the e2m1
// look-up table or fp32 path) and the resulting lanes are written out in order.
CK_TILE_HOST_DEVICE constexpr void operator()(fp8x8_t& y, const pk_fp4x4_t& x) const
{
    pk_fp4_t f0 = pk_fp4_t{x[0]};
    pk_fp4_t f1 = pk_fp4_t{x[1]};
    pk_fp4_t f2 = pk_fp4_t{x[2]};
    pk_fp4_t f3 = pk_fp4_t{x[3]};

    fp8x2_t x0 = f0.to_fp8x2();
    fp8x2_t x1 = f1.to_fp8x2();
    fp8x2_t x2 = f2.to_fp8x2();
    fp8x2_t x3 = f3.to_fp8x2();

    // Lane layout: output pairs follow input nibble order (low nibble first).
    y[0] = x0[0];
    y[1] = x0[1];
    y[2] = x1[0];
    y[3] = x1[1];
    y[4] = x2[0];
    y[5] = x2[1];
    y[6] = x3[0];
    y[7] = x3[1];
}
constexpr const static bool is_pack8_invocable = true;
};

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
@@ -255,17 +256,26 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
{
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using BTypeToUse =
std::conditional_t<std::is_same_v<typename Problem::BDataType, ck_tile::pk_int4_t>,
typename Problem::ADataType,
typename Problem::BDataType>;
// Determine compute types to use
// This logic defaults to A/B DataType, but if one of them is packed falls back to the other
// If both are packed, it falls back to the explicitly defined ComputeDataType in the
// problem It might be a good idea to use ComputeDataType anyway, but that would break how
// this behaviour used to work
using ATypeToUse = mixed_prec_compute_type_from_input_t<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::ComputeDataType>;
using BTypeToUse = mixed_prec_compute_type_from_input_t<typename Problem::BDataType,
typename Problem::ADataType,
typename Problem::ComputeDataType>;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t KLane = WarpTile::at(I2) * WarpTile::at(I0) / WaveSize;
using BDataType = typename Problem::BDataType;
constexpr index_t KLaneBytes =
KLane / numeric_traits<BDataType>::PackedSize * sizeof(BDataType);
constexpr auto NumAccess = static_cast<WGAttrNumAccessEnum>(max(1, KLaneBytes / 16));
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
using WarpGemm = WarpGemmDispatcher<ATypeToUse,
BTypeToUse,
typename Problem::CDataType,
WarpTile::at(I0),

View File

@@ -101,9 +101,11 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
// 4. i4, bf8, (fp8/fp32) -> f32
static_assert(
(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t> ||
std::is_same_v<ADataType, ck_tile::pk_int4_t>) &&
std::is_same_v<ADataType, ck_tile::pk_int4_t> ||
std::is_same_v<ADataType, ck_tile::pk_fp4_t>) &&
(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
std::is_same_v<BDataType, ck_tile::pk_int4_t>) &&
std::is_same_v<BDataType, ck_tile::pk_int4_t> ||
std::is_same_v<BDataType, ck_tile::pk_fp4_t>) &&
(std::is_same_v<AQDataType, float> || std::is_same_v<AQDataType, ck_tile::fp8_t> ||
std::is_same_v<AQDataType, ck_tile::bf8_t>) &&
(std::is_same_v<BQDataType, float> || std::is_same_v<BQDataType, ck_tile::fp8_t> ||
@@ -189,7 +191,8 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
typename BFlatBlockTensor,
typename AQBlockTensor,
typename BQBlockTensor,
typename ABlockWindow>
typename ABlockWindow,
index_t UnaryOpSize = 8>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
ABlockTensor& a_warp_tensor,
BFlatBlockTensor& b_warp_tensor,
@@ -249,8 +252,10 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows(number<AmIter>{})(number<AkIter>{}));
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize>(
a_warp_tensor(number<AwarpIter>{}),
a_warp_windows(number<AmIter>{})(number<AkIter>{}));
}
// barrier
// Could be deleted

View File

@@ -108,9 +108,11 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
// 4. i4, bf8, (fp8/fp32) -> f32
static_assert(
(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t> ||
std::is_same_v<ADataType, ck_tile::pk_int4_t>) &&
std::is_same_v<ADataType, ck_tile::pk_int4_t> ||
std::is_same_v<ADataType, ck_tile::pk_fp4_t>) &&
(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
std::is_same_v<BDataType, ck_tile::pk_int4_t>) &&
std::is_same_v<BDataType, ck_tile::pk_int4_t> ||
std::is_same_v<BDataType, ck_tile::pk_fp4_t>) &&
(std::is_same_v<AQDataType, float> || std::is_same_v<AQDataType, ck_tile::fp8_t> ||
std::is_same_v<AQDataType, ck_tile::bf8_t>) &&
(std::is_same_v<BQDataType, float> || std::is_same_v<BQDataType, ck_tile::fp8_t> ||
@@ -135,12 +137,9 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;
// BDataType gets converted from PkInt4 during loading
using OverrideBDataType = std::conditional_t<
std::is_same_v<BDataType, pk_int4_t> &&
std::is_same_v<typename Traits::BLayout, tensor_layout::gemm::RowMajor>,
ADataType,
BDataType>;
// A/B DataType get converted from PkInt4/PkFp4 during loading
using OverrideADataType = ComputeDataType;
using OverrideBDataType = ComputeDataType;
using Base = BlockGemmQuantBase;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
@@ -268,9 +267,9 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
// If A/B datatype were pkint4/pkfp4 it would be converted prior to storing in LDS
load_int4_tile<OverrideADataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
a_warp_tile_, a_block_window);
// If B datatype were pkint4 it would be converted prior to storing in LDS
load_int4_tile<OverrideBDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(
b_warp_tile_, b_block_window);
}

View File

@@ -10,9 +10,10 @@
namespace ck_tile {
struct GemmABQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgCrPolicy
struct GemmABQuantPipelineAgBgCrDefaultPolicy
: public UniversalGemmBasePolicy<GemmABQuantPipelineAgBgCrDefaultPolicy>
{
using Base = UniversalGemmPipelineAgBgCrPolicy;
using Base = UniversalGemmBasePolicy<GemmABQuantPipelineAgBgCrDefaultPolicy>;
using Base::I0;
using Base::I1;
using Base::I2;

View File

@@ -34,9 +34,6 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using AQuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
// BDataType gets converted from PkInt4 during loading
using OverrideBDataType =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
static_assert(AQuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
@@ -67,6 +64,10 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
// A/B DataType gets converted from PkInt4/PkFp4 during loading
using OverrideADataType = BlockGemm::OverrideADataType;
using OverrideBDataType = BlockGemm::OverrideBDataType;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
@@ -281,9 +282,9 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
using AQDramTileWindowStep = typename AQDramBlockWindowTmp::BottomTensorIndex;
using BQDramTileWindowStep = typename BQDramBlockWindowTmp::BottomTensorIndex;
// Note: BDataType PkInt4 gets converted during loading, before going to LDS
// Note: A/B DataType PkInt4/PkFp4 gets converted during loading, before going to LDS
auto&& [a_lds_block, b_lds_block] =
Base::template GetABLdsTensorViews<ADataType, OverrideBDataType>(p_smem);
Base::template GetABLdsTensorViews<OverrideADataType, OverrideBDataType>(p_smem);
constexpr auto a_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
@@ -303,9 +304,9 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
using BQBlockTileDistr = decltype(bq_copy_dram_window.get_tile_distribution());
using ABlockTile =
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
decltype(make_static_distributed_tensor<OverrideADataType>(ABlockTileDistr{}));
using BBlockTile =
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
decltype(make_static_distributed_tensor<OverrideBDataType>(BBlockTileDistr{}));
using AQBlockTile =
decltype(make_static_distributed_tensor<AQDataType>(AQBlockTileDistr{}));
using BQBlockTile =
@@ -361,7 +362,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
auto a_shuffle_tmp = make_static_distributed_tensor<OverrideADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
@@ -373,7 +374,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
auto b_shuffle_tmp = make_static_distributed_tensor<OverrideBDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
@@ -409,7 +410,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
// Note: ABDataType PkInt4/PkFp4 gets converted during loading earlier
auto a_shuffle_tmp = make_static_distributed_tensor<OverrideADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
@@ -420,7 +422,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
}
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
// Note: BDataType PkInt4 gets converted during loading earlier
// Note: BDataType PkInt4/PkFp4 gets converted during loading earlier
auto b_shuffle_tmp = make_static_distributed_tensor<OverrideBDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
@@ -493,7 +495,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
// Note: ADataType gets converted during loading from PkInt4/PkFp4
auto a_shuffle_tmp = make_static_distributed_tensor<OverrideADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
@@ -543,9 +546,9 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
[](const OverrideADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
[](const OverrideBDataType& b) { return b; },
aq_dram_block_window_tmp,
bq_dram_block_window_tmp,
m,
@@ -593,9 +596,10 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
// Note: ADataType PkInt4/PkFp4 gets converted during loading
[](const OverrideADataType& a) { return a; },
b_dram_block_window_tmp,
// Note: BDataType PkInt4 gets converted during loading
// Note: BDataType PkInt4/PkFp4 gets converted during loading
[](const OverrideBDataType& b) { return b; },
aq_dram_block_window_tmp,
bq_dram_block_window_tmp,

View File

@@ -21,23 +21,27 @@ template <typename ADataType_,
typename AQuantGroupSize_,
typename BQuantGroupSize_,
bool TransposeC_,
typename ComputeDataType_ = BDataType_,
typename ComputeDataType_ = void,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full>
struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
ComputeDataType_>
struct GemmQuantPipelineProblemBase
: public GemmPipelineProblemBase<
ADataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
mixed_prec_compute_type_t<ComputeDataType_, ADataType_, BDataType_>>
{
using Base = GemmPipelineProblemBase<ADataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
ComputeDataType_>;
using Base = GemmPipelineProblemBase<
ADataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
mixed_prec_compute_type_t<ComputeDataType_, ADataType_, BDataType_>>;
using Traits = typename Base::Traits;

View File

@@ -95,11 +95,6 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using BTypeToUse =
std::conditional_t<std::is_same_v<typename Problem::BDataType, ck_tile::pk_int4_t>,
typename Problem::ADataType,
typename Problem::BDataType>;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t KLane = WarpTile::at(I2) * WarpTile::at(I0) / WaveSize;
using BDataType = typename Problem::BDataType;
@@ -107,8 +102,8 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel
KLane / numeric_traits<BDataType>::PackedSize * sizeof(BDataType);
constexpr auto NumAccess = static_cast<WGAttrNumAccessEnum>(max(1, KLaneBytes / 16));
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
BTypeToUse,
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),

View File

@@ -7,6 +7,7 @@
#include <sstream>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp"
@@ -239,36 +240,42 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
// A DRAM tile window for load
auto a_dram_tile_distribution =
PipelinePolicy::template MakeADramTileDistribution<Problem>();
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
PipelinePolicy::template MakeADramTileDistribution<Problem>());
a_dram_tile_distribution);
auto a_copy_lds_window_ping =
make_tile_window(a_lds_block_ping,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
PipelinePolicy::template MakeADramTileDistribution<Problem>());
a_dram_tile_distribution);
auto a_copy_lds_window_pong =
make_tile_window(a_lds_block_pong,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
PipelinePolicy::template MakeADramTileDistribution<Problem>());
a_dram_tile_distribution);
// ping-pong window for A LDS
auto a_warp_tile_distribution =
make_static_tile_distribution(typename WG::AWarpDstrEncoding{});
auto a_warp_window_ping_tmp =
make_tile_window(a_lds_block_ping,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
a_warp_tile_distribution);
auto a_warp_window_pong_tmp =
make_tile_window(a_lds_block_pong,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
a_warp_tile_distribution);
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_ping_tmp), KIterPerWarp>,
@@ -314,7 +321,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
b_flat_distribution);
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
mixed_prec_compute_type_from_input_t<BDataType, ADataType, ComputeDataType>;
using BTileType = decltype(make_static_distributed_tensor<BTypeToUse>(b_flat_distribution));
// pingpong buffer for B
@@ -354,7 +361,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_>(
b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});
@@ -393,15 +400,17 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
block_sync_lds();
// preload A00,A10 from lds
statically_indexed_array<decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{}))),
m_preload>
a_warp_tensor;
using ATypeToUse =
mixed_prec_compute_type_from_input_t<ADataType, BDataType, ComputeDataType>;
using ATileType =
decltype(make_static_distributed_tensor<BTypeToUse>(a_warp_tile_distribution));
statically_indexed_array<ATileType, m_preload> a_warp_tensor;
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_>(
a_warp_tensor(loadIter), a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
});
__builtin_amdgcn_sched_barrier(0);
@@ -434,7 +443,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_>(
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});
@@ -450,8 +459,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_>(
a_warp_tensor(loadIter), a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
});
// Next K
@@ -463,7 +472,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_>(
b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});
@@ -495,8 +504,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_>(
a_warp_tensor(loadIter), a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
});
iCounter--;
HotLoopScheduler<loop_count>();
@@ -513,7 +522,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_>(
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});
@@ -535,8 +544,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_>(
a_warp_tensor(loadIter), a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
});
// GEMM loopK

View File

@@ -76,6 +76,22 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
)
target_compile_options(test_tile_gemm_quant_abquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_abquant_a4w4_base
test_gemm_quant_abquant_a4w4_base.cpp
)
target_compile_options(test_tile_gemm_quant_abquant_a4w4_base PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_abquant_a4w4_padding
test_gemm_quant_abquant_a4w4_padding.cpp
)
target_compile_options(test_tile_gemm_quant_abquant_a4w4_padding PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_abquant_a4w4_preshuffle
test_gemm_quant_abquant_a4w4_preshuffle.cpp
)
target_compile_options(test_tile_gemm_quant_abquant_a4w4_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_abquant_preshuffleQuant
test_gemm_quant_abquant_preshuffleQuant.cpp
)

View File

@@ -0,0 +1,44 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using Half = ck_tile::half_t;
using PkFP4 = ck_tile::pk_fp4_t;
using ABQuantGrouped =
std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;
// 1d block sizes for AQuant
using GroupSize1D = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// 2d block sizes for BQuant
using GroupSize2D = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
// Type combinations for ABQuant tests
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
// clang-format off
using ABQuantTypes = ::testing::Types<
// PreshuffleQuant = false && TransposeC = false
// RCR layout with RowMajor AQ, ColumnMajor BQ
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkFP4, PkFP4, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize1D, GroupSize2D, ColumnMajor>
>;
// clang-format on
// Test suite for ABQuant
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantTypes);
// AQuant tests
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}

View File

@@ -0,0 +1,65 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using Half = ck_tile::half_t;
using PkFP4 = ck_tile::pk_fp4_t;
using ABQuantGrouped =
std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;
// 1d block sizes for AQuant
using GroupSize1D = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// 2d block sizes for BQuant
using GroupSize2D = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
// Type combinations for ABQuant tests
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
// clang-format off
using ABQuantTypes = ::testing::Types<
// PreshuffleQuant = false && TransposeC = false
// RCR layout with RowMajor AQ, ColumnMajor BQ
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkFP4, PkFP4, float, Half, ABQuantGrouped, GemmConfigPadding, GroupSize1D, GroupSize2D, ColumnMajor>
>;
// clang-format on
// Test suite for ABQuant
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantTypes);
// AQuant tests
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadK)
{
this->run_test_with_validation(1024, 1024, 832);
}
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadN)
{
this->run_test_with_validation(1024, 832, 1024);
}
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadM)
{
this->run_test_with_validation(832, 1024, 1024);
}
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadMNK)
{
this->run_test_with_validation(832, 832, 832);
}
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadNK)
{
this->run_test_with_validation(1024, 832, 832);
}

View File

@@ -0,0 +1,44 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using Half = ck_tile::half_t;
using PkFP4 = ck_tile::pk_fp4_t;
using ABQuantGrouped =
std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;
// 1d block sizes for AQuant
using GroupSize1D = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// 2d block sizes for BQuant
using GroupSize2D = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
// Type combinations for ABQuant tests
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
// clang-format off
using ABQuantTypes = ::testing::Types<
// RCR layout with RowMajor AQ, ColumnMajor BQ
// PreshuffleB = true && TransposeC = false
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkFP4, PkFP4, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize1D, GroupSize2D, ColumnMajor>
>;
// clang-format on
// Test suite for ABQuant
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantTypes);
// AQuant tests
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}

View File

@@ -209,7 +209,7 @@ template <>
struct QuantTypeTraits<ck_tile::QuantType::ABQuantGrouped>
{
template <typename ADataType, typename BDataType>
using ComputeDataType = BDataType; // For AQuant, compute type is BDataType
using ComputeDataType = void; // Use automatically determined compute type
static constexpr const char* name = "abquant";
};

View File

@@ -1174,8 +1174,8 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::CShuffleEpilogueProblem<typename PipelineProblem::ComputeDataType,
typename PipelineProblem::ComputeDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,