[CK TILE] Fix bugs in AQuant preshuffle (#2700)

* [CK TILE] Fix bugs in AQuant preshuffle

- Make Aquant works with block Mx64x256. `M` could be 16, 32, 64
- Make Aquant works with warp 16x16x32 and 32x32x16.

* [CK TILE] Rename Preshuffle to PreshuffleQuant

The new name, PreshuffleQuant, explicitly states the function's purpose:
to preshuffle the quantization matrix.

* [CK TILE Block Scale] Use GemmConfig to save tile properties

- Remove specialization of GemmQuantTypeConfig
- Pass GemmConfig around which contains tile properties. Stop using hard
  coded tile properties in `gemm_calc_aquant()`

* [CK TILE Block Scale] Rename GemmConfig used in block scale

    - Remove unused GemmConfig
    - Rename GemmConfig used in block scale

---------

Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
Cong Ma
2025-08-27 01:05:54 -06:00
committed by GitHub
parent 95e4a4efcb
commit 245467f359
12 changed files with 266 additions and 1069 deletions

View File

@@ -8,11 +8,10 @@
#include <string>
#include <tuple>
#include "ck_tile/core/config.hpp"
#include "ck_tile/host.hpp"
#include "gemm_utils.hpp"
template <typename ADataType,
template <typename GemmConfig,
typename ADataType,
typename AQDataType,
typename BDataType,
typename AccDataType,
@@ -21,8 +20,7 @@ template <typename ADataType,
typename ALayout,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize,
bool Preshuffle = false>
uint32_t QuantGroupSize>
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
constexpr bool kPadM = false;
@@ -33,17 +31,17 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr ck_tile::index_t M_Tile = 16;
constexpr ck_tile::index_t N_Tile = 64;
constexpr ck_tile::index_t K_Tile = 256;
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
constexpr ck_tile::index_t M_Warp = 1;
constexpr ck_tile::index_t N_Warp = 4;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
constexpr ck_tile::index_t M_Warp_Tile = 16;
constexpr ck_tile::index_t N_Warp_Tile = 16;
constexpr ck_tile::index_t K_Warp_Tile = 32;
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
@@ -52,8 +50,13 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using CodegenGemmTraits =
ck_tile::TileGemmAQuantTraits<kPadM, kPadN, kPadK, Preshuffle, ALayout, BLayout, CLayout>;
using CodegenGemmTraits = ck_tile::TileGemmAQuantTraits<kPadM,
kPadN,
kPadK,
GemmConfig::PreshuffleQuant,
ALayout,
BLayout,
CLayout>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
@@ -186,13 +189,14 @@ int run_gemm_example(int argc, char* argv[])
if(data_type == "fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>{});
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>{});
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
@@ -200,7 +204,7 @@ int run_gemm_example(int argc, char* argv[])
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::fp8_t,
float,
ck_tile::half_t,
ck_tile::fp8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
@@ -209,29 +213,15 @@ int run_gemm_example(int argc, char* argv[])
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::bf8_t,
float,
ck_tile::half_t,
ck_tile::bf8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "i4f32fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "i4f32bf8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
}
}
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigComputeV3>(argc, argv); }
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigDecode>(argc, argv); }

View File

@@ -8,11 +8,10 @@
#include <string>
#include <tuple>
#include "ck_tile/core/config.hpp"
#include "ck_tile/host.hpp"
#include "gemm_utils.hpp"
template <typename ADataType,
template <typename GemmConfig,
typename ADataType,
typename AQDataType,
typename BDataType,
typename AccDataType,
@@ -21,8 +20,7 @@ template <typename ADataType,
typename ALayout,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize,
bool Preshuffle = false>
uint32_t QuantGroupSize>
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
constexpr bool kPadM = false;
@@ -33,17 +31,17 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr ck_tile::index_t M_Tile = 16;
constexpr ck_tile::index_t N_Tile = 64;
constexpr ck_tile::index_t K_Tile = 256;
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
constexpr ck_tile::index_t M_Warp = 1;
constexpr ck_tile::index_t N_Warp = 4;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
constexpr ck_tile::index_t M_Warp_Tile = 16;
constexpr ck_tile::index_t N_Warp_Tile = 16;
constexpr ck_tile::index_t K_Warp_Tile = 32;
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
@@ -52,8 +50,13 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using CodegenGemmTraits =
ck_tile::TileGemmAQuantTraits<kPadM, kPadN, kPadK, Preshuffle, ALayout, BLayout, CLayout>;
using CodegenGemmTraits = ck_tile::TileGemmAQuantTraits<kPadM,
kPadN,
kPadK,
GemmConfig::PreshuffleQuant,
ALayout,
BLayout,
CLayout>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
@@ -186,13 +189,14 @@ int run_gemm_example(int argc, char* argv[])
if(data_type == "fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>{});
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>{});
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
@@ -200,7 +204,7 @@ int run_gemm_example(int argc, char* argv[])
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::fp8_t,
float,
ck_tile::half_t,
ck_tile::fp8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
@@ -209,29 +213,18 @@ int run_gemm_example(int argc, char* argv[])
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::bf8_t,
float,
ck_tile::half_t,
ck_tile::bf8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "i4f32fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "i4f32bf8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
}
}
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigPreshufle_AQ>(argc, argv); }
int main(int argc, char* argv[])
{
return !run_gemm_example<GemmConfigPreshuffleQuant>(argc, argv);
}

View File

@@ -11,11 +11,9 @@
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm_group_quant.hpp"
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#define CK_TILE_PIPELINE_COMPUTE_V5 4
#define CK_TILE_PIPELINE_PRESHUFFLE 5
#define CK_TILE_PIPELINE_PREFILL 1
#define CK_TILE_PIPELINE_DECODE 2
#define CK_TILE_PIPELINE_PRESHUFFLEQUANT 3
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
@@ -87,196 +85,32 @@ struct GemmConfigBase
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool Preshuffle = false;
static constexpr bool PreshuffleQuant = false;
static constexpr bool DoubleSmemBuffer = false;
};
template <typename PrecType>
struct GemmConfigMemoryInterwave : public GemmConfigBase
struct GemmConfigDecode : public GemmConfigBase
{
// Memory friendly for Interwave scheduler
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 32;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 4;
static constexpr ck_tile::index_t N_Warp = 1;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
};
template <typename PrecType>
struct GemmConfigMemoryIntrawave : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 32;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 4;
static constexpr ck_tile::index_t N_Warp = 1;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
};
template <typename PrecType>
struct GemmConfigComputeV3 : public GemmConfigBase
{
// Compute V3 only support Intrawave scheduler
static constexpr ck_tile::index_t M_Tile = 32;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
};
template <typename PrecType>
struct GemmConfigComputeV3_1 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
};
template <typename PrecType>
struct GemmConfigComputeV3_2 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr int kBlockPerCu = 2;
};
template <typename PrecType>
struct GemmConfigComputeV4 : public GemmConfigBase
{
// Compute V4 only support Intrawave scheduler
// Using the ping pong reader in the lds level
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
};
template <typename PrecType>
struct GemmConfigComputeV4_1 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
};
template <typename PrecType>
struct GemmConfigComputeV5 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 1;
static constexpr ck_tile::index_t K_Warp = 2;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
};
template <typename PrecType>
struct GemmConfigPreshufle_1 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile =
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_DECODE;
};
template <typename PrecType>
struct GemmConfigPreshufle_2 : public GemmConfigBase
struct GemmConfigPrefill : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
@@ -288,18 +122,15 @@ struct GemmConfigPreshufle_2 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PREFILL;
};
template <typename PrecType>
struct GemmConfigPreshufle_AQ : public GemmConfigBase
struct GemmConfigPreshuffleQuant : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
@@ -314,9 +145,9 @@ struct GemmConfigPreshufle_AQ : public GemmConfigBase
static constexpr ck_tile::index_t K_Warp_Tile =
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = false;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLEQUANT;
static constexpr bool PreshuffleQuant = true;
};
template <typename ADataType_,
@@ -332,176 +163,6 @@ struct GemmQuantTypeConfig
using CDataType = CDataType_;
};
template <>
struct GemmQuantTypeConfig<ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using QDataType = float;
using BDataType = ck_tile::half_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t>
{
using ADataType = ck_tile::bf16_t;
using QDataType = float;
using BDataType = ck_tile::bf16_t;
using AccDataType = float;
using CDataType = ck_tile::bf16_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>
{
using ADataType = ck_tile::fp8_t;
using QDataType = float;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>
{
using ADataType = ck_tile::bf8_t;
using QDataType = float;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using QDataType = float;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, float>
{
using ADataType = ck_tile::fp8_t;
using QDataType = float;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>
{
using ADataType = ck_tile::bf8_t;
using QDataType = float;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, ck_tile::fp8_t>
{
using ADataType = ck_tile::pk_int4_t;
using QDataType = ck_tile::fp8_t;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, float, ck_tile::fp8_t>
{
using ADataType = ck_tile::fp8_t;
using QDataType = ck_tile::fp8_t;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float, ck_tile::bf8_t>
{
using ADataType = ck_tile::bf8_t;
using QDataType = ck_tile::bf8_t;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, ck_tile::bf8_t>
{
using ADataType = ck_tile::pk_int4_t;
using QDataType = ck_tile::bf8_t;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>
{
using ADataType = ck_tile::pk_int4_t;
using QDataType = float;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>
{
using ADataType = ck_tile::pk_int4_t;
using QDataType = float;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, float, ck_tile::fp8_t>
{
using ADataType = ck_tile::fp8_t;
using QDataType = ck_tile::fp8_t;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, float, ck_tile::bf8_t>
{
using ADataType = ck_tile::bf8_t;
using QDataType = ck_tile::bf8_t;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, float, float>
{
using ADataType = ck_tile::fp8_t;
using QDataType = float;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, float, float>
{
using ADataType = ck_tile::bf8_t;
using QDataType = float;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <typename T>
struct DataTypeTraits;
@@ -559,55 +220,6 @@ struct DataTypeTraits<ck_tile::int8_t>
static constexpr const char* name = "int8";
};
template <ck_tile::index_t PipelineId>
struct PipelineTypeTraits;
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline =
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;

View File

@@ -31,7 +31,8 @@ auto shuffle_aq(const ck_tile::HostTensor<T>& t, int block_aq_k)
return ck_tile::reference_permute(t_view, {1, 0, 2});
}
template <typename ADataType,
template <typename GemmConfig,
typename ADataType,
typename AQDataType,
typename BDataType,
typename AccDataType,
@@ -40,8 +41,7 @@ template <typename ADataType,
typename AQLayout,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize,
bool Preshuffle = false>
uint32_t QuantGroupSize>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& aq_m_aqk_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
@@ -73,7 +73,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args.stride_C = stride_C;
args.stride_AQ = stride_AQ;
float ave_time = gemm_calc_aquant<ADataType,
float ave_time = gemm_calc_aquant<GemmConfig,
ADataType,
AQDataType,
BDataType,
AccDataType,
@@ -82,8 +83,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ALayout,
BLayout,
CLayout,
QuantGroupSize,
Preshuffle>(
QuantGroupSize>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = std::size_t(2) * M * N * K;
@@ -206,7 +206,7 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
if constexpr(GemmConfig::Preshuffle)
if constexpr(GemmConfig::PreshuffleQuant)
{
ck_tile::HostTensor<AQDataType> aq_shuffle_host =
shuffle_aq(aq_m_aqk, GemmConfig::K_Tile / QuantGroupSize);
@@ -222,7 +222,8 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
invoke_gemm<ADataType,
invoke_gemm<GemmConfig,
ADataType,
AQDataType,
BDataType,
AccDataType,
@@ -231,22 +232,21 @@ int run_gemm_example_with_layouts(int argc,
AQLayout,
BLayout,
CLayout,
QuantGroupSize,
GemmConfig::Preshuffle>(a_m_k_dev_buf,
aq_m_aqk_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
AQK,
stride_A,
stride_AQ,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat);
QuantGroupSize>(a_m_k_dev_buf,
aq_m_aqk_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
AQK,
stride_A,
stride_AQ,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;

View File

@@ -157,7 +157,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
static constexpr index_t KPack = WarpGemm::kKPerThread;
static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
static constexpr bool Preshuffle = Problem::Traits::Preshuffle;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
};
public:
@@ -357,7 +357,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
}
});
if constexpr(Traits::Preshuffle)
if constexpr(Traits::PreshuffleQuant)
{
// A view is created on top of the preshuffled AQ, where each row of the
// view is composed of a row from a warp tile within an AQ block tile.
@@ -392,12 +392,27 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
// Thread 0 can read AQ_tile[0, 0] from itself, AQ_tile[1, 0]
// from thread 1, ..., and AQ_tile[3, 0] from thread 3.
auto pull_from_lane =
((threadIdx.x & (warp_size - 1)) / Traits::WarpGemm::kN *
kTileRowsOfCPerThread +
c_row) *
Traits::QScalesPerBlockRow +
kQScale;
decltype(threadIdx.x) pull_from_lane = 0;
if constexpr(WarpGemm::kM == 16)
{
pull_from_lane = (__lane_id() / Traits::WarpGemm::kN *
kTileRowsOfCPerThread +
c_row) *
Traits::QScalesPerBlockRow +
kQScale;
}
else if constexpr(WarpGemm::kM == 32)
{
pull_from_lane = (__lane_id() / Traits::WarpGemm::kN *
kTileRowsOfCPerThread +
((c_row >> 2) << 3) + (c_row & 0b11)) *
Traits::QScalesPerBlockRow +
kQScale;
}
else
{
static_assert(false, "WarpGemm::kM is not 16 nor 32.");
}
auto& scale_reg = aq_block_tensor.get_thread_buffer()[mIter];
// cross lane ops

View File

@@ -99,15 +99,15 @@ struct AQuantGemmKernelArgs
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct AQuantGemmKernel
{
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using AQLayout = remove_cvref_t<typename GemmPipeline::AQLayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
static constexpr bool Preshuffle = GemmPipeline::Preshuffle;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using AQLayout = remove_cvref_t<typename GemmPipeline::AQLayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
static constexpr bool PreshuffleQuant = GemmPipeline::PreshuffleQuant;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using AQDataType = remove_cvref_t<typename GemmPipeline::AQDataType>;
@@ -422,9 +422,9 @@ struct AQuantGemmKernel
ck_tile::integer_least_multiple(wave_tile_size, get_warp_size());
const auto aq_merge_pad1_desc = transform_tensor_descriptor(
aq_pad1_desc,
make_tuple(make_merge_transform(make_tuple(wave_tile_count_x, aq_y)),
make_tuple(make_merge_transform(make_tuple(aq_y, wave_tile_count_x)),
make_pass_through_transform(pad_wave_size)),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(aq_ptr, aq_merge_pad1_desc);
@@ -432,7 +432,7 @@ struct AQuantGemmKernel
const auto& aq_tensor_view = [&]() {
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
if constexpr(Preshuffle)
if constexpr(PreshuffleQuant)
{
return make_preshuffled_aq_tensor_view();
}
@@ -599,10 +599,8 @@ struct AQuantGemmKernel
}
template <typename PadView>
CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
const AQuantGemmKernelArgs& kargs,
const index_t i_m,
const index_t i_n)
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
const auto& a_pad_view = views.at(I0);
const auto& aq_pad_view = views.at(I1);
@@ -628,24 +626,27 @@ struct AQuantGemmKernel
const auto& aq_block_window = [&]() {
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
if constexpr(Preshuffle)
constexpr auto block_m = TilePartitioner::MPerBlock;
constexpr auto block_k = TilePartitioner::KPerBlock;
constexpr auto warp_m = TilePartitioner::BlockGemmShape::WarpTile::at(I0);
constexpr auto aqk_per_block =
TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize;
if constexpr(PreshuffleQuant)
{
constexpr auto tile_window_width = get_warp_size();
constexpr auto tile_window_height =
TilePartitioner::MPerBlock / TilePartitioner::BlockGemmShape::WarpTile::at(I0);
auto block_m_idx = i_m / TilePartitioner::MPerBlock;
constexpr auto tile_window_width =
ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size());
constexpr auto tile_window_height = block_m / warp_m;
auto block_m_idx = i_m / block_m;
return make_tile_window(
aq_pad_view,
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
{block_m_idx * kargs.K / TilePartitioner::BlockGemmShape::BlockTile::at(I2),
0});
{block_m_idx * tile_window_height, 0});
}
else
{
return make_tile_window(
aq_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
make_tuple(number<block_m>{}, number<block_k / GemmPipeline::QuantGroupSize>{}),
{i_m, 0});
}
}();
@@ -706,8 +707,7 @@ struct AQuantGemmKernel
a_ptr, b_ptr, aq_ptr, c_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows =
MakeGemmTileWindows(gemm_pad_views, kargs, block_idx_m, block_idx_n);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
const index_t num_loop = __builtin_amdgcn_readfirstlane(
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
@@ -718,7 +718,7 @@ struct AQuantGemmKernel
const auto& b_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr_0);
a_block_window, b_block_window, aq_block_window, kargs.M, num_loop, smem_ptr_0);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);

View File

@@ -37,23 +37,23 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
using AQLayout = remove_cvref_t<typename Problem::AQLayout>;
using BlockGemmShape = typename Problem::BlockGemmShape;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize;
constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
constexpr bool Preshuffle = Problem::Traits::Preshuffle;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
false>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize;
constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
false>;
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
if constexpr(Preshuffle)
if constexpr(PreshuffleQuant)
{
using TileEncodingPattern =
TileDistributionEncodingPatternAQ<BlockGemmShape,
@@ -64,7 +64,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
WarpGemm::kM * KPerBlockAQ, get_warp_size()),
KPerBlockAQ,
VecLoadSize,
Preshuffle>;
PreshuffleQuant>;
return TileEncodingPattern::Make2DStaticTileDistribution();
}
@@ -77,7 +77,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
KPerBlockAQ,
KPerBlockAQ,
VecLoadSize,
Preshuffle>;
PreshuffleQuant>;
return TileEncodingPattern::Make2DStaticTileDistribution();
}

View File

@@ -7,6 +7,7 @@
#include <sstream>
#include "ck_tile/core.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/host/concat.hpp"
@@ -133,7 +134,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr bool Preshuffle = Problem::Traits::Preshuffle;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
@@ -235,6 +236,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
index_t m,
index_t num_loop,
void* p_smem) const
{
@@ -311,9 +313,11 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
// only row_major for AQ
constexpr AQDramTileWindowStep aq_dram_tile_window_step =
Preshuffle ? make_array(MPerBlock / BlockGemm::WarpGemm::kM, 0)
: make_array(0, KPerBlockAQ);
const AQDramTileWindowStep aq_dram_tile_window_step =
PreshuffleQuant ? make_array(ck_tile::integer_least_multiple(m, MPerBlock) /
BlockGemm::WarpGemm::kM,
0)
: make_array(0, KPerBlockAQ);
// DRAM prefetch (global read 0)
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
@@ -458,6 +462,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
index_t m,
index_t num_loop,
void* p_smem) const
{
@@ -467,6 +472,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
aq_dram_block_window_tmp,
m,
num_loop,
p_smem);
}

View File

@@ -52,7 +52,7 @@ template <typename BlockGemmShape,
index_t XPerTile,
index_t KPerBlockAQ,
index_t VecSize,
bool Preshuffle>
bool PreshuffleQuant>
struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPattern
{
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
@@ -72,20 +72,20 @@ struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPatter
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
{
if constexpr(Preshuffle)
if constexpr(PreshuffleQuant)
{
// # of elements per thread
constexpr index_t X2 = KPerBlockAQ;
constexpr index_t X1 = warp_size / X2;
static_assert(XPerTile >= warp_size && XPerTile % warp_size == 0);
constexpr index_t X1 = warp_size;
constexpr index_t X0 = XPerTile / warp_size;
constexpr index_t Y1 = MWarps;
constexpr index_t Y0 = YPerTile / Y1;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<NWarps>,
tuple<sequence<Y0, Y1>, sequence<X0, X1, X2>>,
tuple<sequence<1, 0>, sequence<2, 2>>,
tuple<sequence<1, 0>, sequence<1, 2>>,
tuple<sequence<Y0, Y1>, sequence<X0, X1>>,
tuple<sequence<1, 0>, sequence<2>>,
tuple<sequence<1, 0>, sequence<1>>,
sequence<1, 2>,
sequence<0, 0>>{});
}

View File

@@ -10,7 +10,7 @@ namespace ck_tile {
template <bool kPadM_,
bool kPadN_,
bool kPadK_,
bool Preshuffle_,
bool PreshuffleQuant_,
typename ALayout_,
typename BLayout_,
typename CLayout_,
@@ -30,7 +30,7 @@ struct TileGemmAQuantTraits
static constexpr bool UseStructuredSparsity = false;
static constexpr index_t NumWaveGroups = 1;
static constexpr bool Preshuffle = Preshuffle_;
static constexpr bool PreshuffleQuant = PreshuffleQuant_;
};
} // namespace ck_tile

View File

@@ -11,11 +11,9 @@
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm_group_quant.hpp"
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#define CK_TILE_PIPELINE_COMPUTE_V5 4
#define CK_TILE_PIPELINE_PRESHUFFLE 5
#define CK_TILE_PIPELINE_PREFILL 1
#define CK_TILE_PIPELINE_DECODE 2
#define CK_TILE_PIPELINE_PRESHUFFLEQUANT 3
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
@@ -34,21 +32,6 @@ constexpr ck_tile::index_t get_k_warp_tile()
return 32;
#endif
}
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
{
#if defined(__gfx950__)
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 64;
else
return sizeof(PrecType) == 2 ? 32 : 128;
#else
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 32;
else
return sizeof(PrecType) == 2 ? 32 : 64;
#endif
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
@@ -93,195 +76,32 @@ struct GemmConfigBase
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool Preshuffle = false;
static constexpr bool PreshuffleQuant = false;
static constexpr bool DoubleSmemBuffer = true;
};
template <typename PrecType>
struct GemmConfigMemoryInterwave : public GemmConfigBase
struct GemmConfigDecode : public GemmConfigBase
{
// Memory friendly for Interwave scheduler
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 32;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 4;
static constexpr ck_tile::index_t N_Warp = 1;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
};
template <typename PrecType>
struct GemmConfigMemoryIntrawave : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 32;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 4;
static constexpr ck_tile::index_t N_Warp = 1;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
};
template <typename PrecType>
struct GemmConfigComputeV3 : public GemmConfigBase
{
// Compute V3 only support Intrawave scheduler
static constexpr ck_tile::index_t M_Tile = 32;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
};
template <typename PrecType>
struct GemmConfigComputeV3_1 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
};
template <typename PrecType>
struct GemmConfigComputeV3_2 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr int kBlockPerCu = 2;
};
template <typename PrecType>
struct GemmConfigComputeV4 : public GemmConfigBase
{
// Compute V4 only support Intrawave scheduler
// Using the ping pong reader in the lds level
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
};
template <typename PrecType>
struct GemmConfigComputeV4_1 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
};
template <typename PrecType>
struct GemmConfigComputeV5 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 1;
static constexpr ck_tile::index_t K_Warp = 2;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
};
template <typename PrecType>
struct GemmConfigPreshufle_1 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_DECODE;
};
template <typename PrecType>
struct GemmConfigPreshufle_2 : public GemmConfigBase
struct GemmConfigPrefill : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
@@ -293,71 +113,32 @@ struct GemmConfigPreshufle_2 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PREFILL;
};
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
struct GemmTypeConfig;
template <>
struct GemmTypeConfig<ck_tile::half_t>
template <typename PrecType>
struct GemmConfigPreshuffleQuant : public GemmConfigBase
{
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
// ToDo: Add more bias config to support different categories of GEMM.
};
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
template <>
struct GemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t>
{
using ADataType = ck_tile::bf16_t;
using BDataType = ck_tile::bf16_t;
using AccDataType = float;
using CDataType = ck_tile::bf16_t;
};
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
template <>
struct GemmTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>
{
using ADataType = ck_tile::fp8_t;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
template <>
struct GemmTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>
{
using ADataType = ck_tile::bf8_t;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, int32_t>
{
using ADataType = ck_tile::int8_t;
using BDataType = ck_tile::int8_t;
using AccDataType = int32_t;
using CDataType = int32_t;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLEQUANT;
static constexpr bool PreshuffleQuant = true;
};
template <typename ADataType_,
@@ -373,176 +154,6 @@ struct GemmQuantTypeConfig
using CDataType = CDataType_;
};
template <>
struct GemmQuantTypeConfig<ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using QDataType = float;
using BDataType = ck_tile::half_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t>
{
using ADataType = ck_tile::bf16_t;
using QDataType = float;
using BDataType = ck_tile::bf16_t;
using AccDataType = float;
using CDataType = ck_tile::bf16_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>
{
using ADataType = ck_tile::fp8_t;
using QDataType = float;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>
{
using ADataType = ck_tile::bf8_t;
using QDataType = float;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using QDataType = float;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, float>
{
using ADataType = ck_tile::fp8_t;
using QDataType = float;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>
{
using ADataType = ck_tile::bf8_t;
using QDataType = float;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, ck_tile::fp8_t>
{
using ADataType = ck_tile::pk_int4_t;
using QDataType = ck_tile::fp8_t;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, float, ck_tile::fp8_t>
{
using ADataType = ck_tile::fp8_t;
using QDataType = ck_tile::fp8_t;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float, ck_tile::bf8_t>
{
using ADataType = ck_tile::bf8_t;
using QDataType = ck_tile::bf8_t;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, ck_tile::bf8_t>
{
using ADataType = ck_tile::pk_int4_t;
using QDataType = ck_tile::bf8_t;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>
{
using ADataType = ck_tile::pk_int4_t;
using QDataType = float;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>
{
using ADataType = ck_tile::pk_int4_t;
using QDataType = float;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, float, ck_tile::fp8_t>
{
using ADataType = ck_tile::fp8_t;
using QDataType = ck_tile::fp8_t;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, float, ck_tile::bf8_t>
{
using ADataType = ck_tile::bf8_t;
using QDataType = ck_tile::bf8_t;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, float, float>
{
using ADataType = ck_tile::fp8_t;
using QDataType = float;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, float, float>
{
using ADataType = ck_tile::bf8_t;
using QDataType = float;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = float;
};
template <typename T>
struct DataTypeTraits;
@@ -600,55 +211,6 @@ struct DataTypeTraits<ck_tile::int8_t>
static constexpr const char* name = "int8";
};
template <ck_tile::index_t PipelineId>
struct PipelineTypeTraits;
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline =
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;

View File

@@ -15,7 +15,8 @@
#include "ck_tile/host.hpp"
#include "test_gemm_aquant_utils.hpp"
template <typename ADataType,
template <typename GemmConfig,
typename ADataType,
typename AQDataType,
typename BDataType,
typename AccDataType,
@@ -24,8 +25,7 @@ template <typename ADataType,
typename ALayout,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize,
bool Preshuffle = false>
uint32_t QuantGroupSize>
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
constexpr bool kPadM = false;
@@ -36,17 +36,17 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr ck_tile::index_t M_Tile = 16;
constexpr ck_tile::index_t N_Tile = 64;
constexpr ck_tile::index_t K_Tile = 256;
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
constexpr ck_tile::index_t M_Warp = 1;
constexpr ck_tile::index_t N_Warp = 4;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
constexpr ck_tile::index_t M_Warp_Tile = 16;
constexpr ck_tile::index_t N_Warp_Tile = 16;
constexpr ck_tile::index_t K_Warp_Tile = 32;
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
@@ -55,8 +55,13 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using CodegenGemmTraits =
ck_tile::TileGemmAQuantTraits<kPadM, kPadN, kPadK, Preshuffle, ALayout, BLayout, CLayout>;
using CodegenGemmTraits = ck_tile::TileGemmAQuantTraits<kPadM,
kPadN,
kPadK,
GemmConfig::PreshuffleQuant,
ALayout,
BLayout,
CLayout>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
@@ -152,7 +157,8 @@ static constexpr inline auto is_row_major(Layout layout_)
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
template <typename ADataType,
template <typename GemmConfig,
typename ADataType,
typename AQDataType,
typename BDataType,
typename AccDataType,
@@ -161,8 +167,7 @@ template <typename ADataType,
typename AQLayout,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize,
bool Preshuffle = false>
uint32_t QuantGroupSize>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& aq_m_aqk_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
@@ -194,7 +199,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args.stride_C = stride_C;
args.stride_AQ = stride_AQ;
float ave_time = gemm_calc_aquant<ADataType,
float ave_time = gemm_calc_aquant<GemmConfig,
ADataType,
AQDataType,
BDataType,
AccDataType,
@@ -203,8 +209,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ALayout,
BLayout,
CLayout,
QuantGroupSize,
Preshuffle>(
QuantGroupSize>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = std::size_t(2) * M * N * K;
@@ -227,7 +232,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
return ave_time;
}
template <typename TypeConfig,
template <typename GemmConfig,
typename TypeConfig,
uint32_t QuantGroupSize,
typename ALayout,
typename AQLayout,
@@ -332,7 +338,8 @@ bool run_gemm_test_with_layouts(int argc,
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
invoke_gemm<ADataType,
invoke_gemm<GemmConfig,
ADataType,
AQDataType,
BDataType,
AccDataType,
@@ -400,7 +407,7 @@ bool run_gemm_test_with_layouts(int argc,
return pass;
}
template <typename TypeConfig, uint32_t QuantGroupSize>
template <typename GemmConfig, typename TypeConfig, uint32_t QuantGroupSize>
bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
@@ -412,7 +419,7 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg
{
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_test_with_layouts<TypeConfig, QuantGroupSize>(
return run_gemm_test_with_layouts<GemmConfig, TypeConfig, QuantGroupSize>(
argc, argv, Row{}, Row{}, Col{}, Row{});
}
else
@@ -428,6 +435,7 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg
return true;
}
template <template <typename PreType> typename GemmConfig>
bool run_gemm_test(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
@@ -441,41 +449,52 @@ bool run_gemm_test(int argc, char* argv[])
if(data_type == "fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>{});
return run_gemm_test_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_test_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>{});
return run_gemm_test_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_test_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "i4fp8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::fp8_t,
float,
ck_tile::half_t,
ck_tile::fp8_t>{});
return run_gemm_test_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
return run_gemm_test_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "i4bf8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::bf8_t,
float,
ck_tile::half_t,
ck_tile::bf8_t>{});
return run_gemm_test_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
return run_gemm_test_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "i4f32fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>{});
return run_gemm_test_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::fp8_t,
ck_tile::half_t,
float>{});
return run_gemm_test_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "i4f32bf8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>{});
return run_gemm_test_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::bf8_t,
ck_tile::half_t,
float>{});
return run_gemm_test_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else
{
@@ -564,7 +583,7 @@ int run_gemm_combinations(std::string const& data_type)
// Call the function with the current configuration
try
{
is_success = run_gemm_test(ARG_COUNT, argv) && is_success;
is_success = run_gemm_test<GemmConfigDecode>(ARG_COUNT, argv) && is_success;
}
catch(const ArgumentsNotSupportedException& e)
{