mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
[CK TILE] Fix bugs in AQuant preshuffle (#2700)
* [CK TILE] Fix bugs in AQuant preshuffle
- Make AQuant work with block Mx64x256, where `M` can be 16, 32, or 64.
- Make AQuant work with warp 16x16x32 and 32x32x16.
* [CK TILE] Rename Preshuffle to PreshuffleQuant
The new name, PreshuffleQuant, explicitly states the function's purpose:
to preshuffle the quantization matrix.
* [CK TILE Block Scale] Use GemmConfig to save tile properties
- Remove specialization of GemmQuantTypeConfig
- Pass around GemmConfig, which contains the tile properties. Stop using
hard-coded tile properties in `gemm_calc_aquant()`.
* [CK TILE Block Scale] Rename GemmConfig used in block scale
- Remove unused GemmConfig
- Rename GemmConfig used in block scale
---------
Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
@@ -8,11 +8,10 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
|
|
||||||
#include "ck_tile/core/config.hpp"
|
|
||||||
#include "ck_tile/host.hpp"
|
|
||||||
#include "gemm_utils.hpp"
|
#include "gemm_utils.hpp"
|
||||||
|
|
||||||
template <typename ADataType,
|
template <typename GemmConfig,
|
||||||
|
typename ADataType,
|
||||||
typename AQDataType,
|
typename AQDataType,
|
||||||
typename BDataType,
|
typename BDataType,
|
||||||
typename AccDataType,
|
typename AccDataType,
|
||||||
@@ -21,8 +20,7 @@ template <typename ADataType,
|
|||||||
typename ALayout,
|
typename ALayout,
|
||||||
typename BLayout,
|
typename BLayout,
|
||||||
typename CLayout,
|
typename CLayout,
|
||||||
uint32_t QuantGroupSize,
|
uint32_t QuantGroupSize>
|
||||||
bool Preshuffle = false>
|
|
||||||
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||||
{
|
{
|
||||||
constexpr bool kPadM = false;
|
constexpr bool kPadM = false;
|
||||||
@@ -33,17 +31,17 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
|
|||||||
|
|
||||||
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||||
|
|
||||||
constexpr ck_tile::index_t M_Tile = 16;
|
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
|
||||||
constexpr ck_tile::index_t N_Tile = 64;
|
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
|
||||||
constexpr ck_tile::index_t K_Tile = 256;
|
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
|
||||||
|
|
||||||
constexpr ck_tile::index_t M_Warp = 1;
|
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
|
||||||
constexpr ck_tile::index_t N_Warp = 4;
|
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
|
||||||
constexpr ck_tile::index_t K_Warp = 1;
|
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
|
||||||
|
|
||||||
constexpr ck_tile::index_t M_Warp_Tile = 16;
|
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
|
||||||
constexpr ck_tile::index_t N_Warp_Tile = 16;
|
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
|
||||||
constexpr ck_tile::index_t K_Warp_Tile = 32;
|
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
|
||||||
|
|
||||||
using CodegenGemmShape =
|
using CodegenGemmShape =
|
||||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||||
@@ -52,8 +50,13 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
|
|||||||
|
|
||||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||||
|
|
||||||
using CodegenGemmTraits =
|
using CodegenGemmTraits = ck_tile::TileGemmAQuantTraits<kPadM,
|
||||||
ck_tile::TileGemmAQuantTraits<kPadM, kPadN, kPadK, Preshuffle, ALayout, BLayout, CLayout>;
|
kPadN,
|
||||||
|
kPadK,
|
||||||
|
GemmConfig::PreshuffleQuant,
|
||||||
|
ALayout,
|
||||||
|
BLayout,
|
||||||
|
CLayout>;
|
||||||
|
|
||||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
|
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
|
||||||
BDataType,
|
BDataType,
|
||||||
@@ -186,13 +189,14 @@ int run_gemm_example(int argc, char* argv[])
|
|||||||
if(data_type == "fp8")
|
if(data_type == "fp8")
|
||||||
{
|
{
|
||||||
using TypeConfig =
|
using TypeConfig =
|
||||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>{});
|
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
|
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
|
||||||
a_layout, b_layout, argc, argv);
|
a_layout, b_layout, argc, argv);
|
||||||
}
|
}
|
||||||
else if(data_type == "bf8")
|
else if(data_type == "bf8")
|
||||||
{
|
{
|
||||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>{});
|
using TypeConfig =
|
||||||
|
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig, 128>(
|
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig, 128>(
|
||||||
a_layout, b_layout, argc, argv);
|
a_layout, b_layout, argc, argv);
|
||||||
}
|
}
|
||||||
@@ -200,7 +204,7 @@ int run_gemm_example(int argc, char* argv[])
|
|||||||
{
|
{
|
||||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||||
ck_tile::fp8_t,
|
ck_tile::fp8_t,
|
||||||
float,
|
ck_tile::half_t,
|
||||||
ck_tile::fp8_t>{});
|
ck_tile::fp8_t>{});
|
||||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||||
a_layout, b_layout, argc, argv);
|
a_layout, b_layout, argc, argv);
|
||||||
@@ -209,29 +213,15 @@ int run_gemm_example(int argc, char* argv[])
|
|||||||
{
|
{
|
||||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||||
ck_tile::bf8_t,
|
ck_tile::bf8_t,
|
||||||
float,
|
ck_tile::half_t,
|
||||||
ck_tile::bf8_t>{});
|
ck_tile::bf8_t>{});
|
||||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||||
a_layout, b_layout, argc, argv);
|
a_layout, b_layout, argc, argv);
|
||||||
}
|
}
|
||||||
else if(data_type == "i4f32fp8")
|
|
||||||
{
|
|
||||||
using TypeConfig =
|
|
||||||
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>{});
|
|
||||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
|
||||||
a_layout, b_layout, argc, argv);
|
|
||||||
}
|
|
||||||
else if(data_type == "i4f32bf8")
|
|
||||||
{
|
|
||||||
using TypeConfig =
|
|
||||||
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>{});
|
|
||||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
|
||||||
a_layout, b_layout, argc, argv);
|
|
||||||
}
|
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigComputeV3>(argc, argv); }
|
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigDecode>(argc, argv); }
|
||||||
|
|||||||
@@ -8,11 +8,10 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
|
|
||||||
#include "ck_tile/core/config.hpp"
|
|
||||||
#include "ck_tile/host.hpp"
|
|
||||||
#include "gemm_utils.hpp"
|
#include "gemm_utils.hpp"
|
||||||
|
|
||||||
template <typename ADataType,
|
template <typename GemmConfig,
|
||||||
|
typename ADataType,
|
||||||
typename AQDataType,
|
typename AQDataType,
|
||||||
typename BDataType,
|
typename BDataType,
|
||||||
typename AccDataType,
|
typename AccDataType,
|
||||||
@@ -21,8 +20,7 @@ template <typename ADataType,
|
|||||||
typename ALayout,
|
typename ALayout,
|
||||||
typename BLayout,
|
typename BLayout,
|
||||||
typename CLayout,
|
typename CLayout,
|
||||||
uint32_t QuantGroupSize,
|
uint32_t QuantGroupSize>
|
||||||
bool Preshuffle = false>
|
|
||||||
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||||
{
|
{
|
||||||
constexpr bool kPadM = false;
|
constexpr bool kPadM = false;
|
||||||
@@ -33,17 +31,17 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
|
|||||||
|
|
||||||
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||||
|
|
||||||
constexpr ck_tile::index_t M_Tile = 16;
|
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
|
||||||
constexpr ck_tile::index_t N_Tile = 64;
|
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
|
||||||
constexpr ck_tile::index_t K_Tile = 256;
|
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
|
||||||
|
|
||||||
constexpr ck_tile::index_t M_Warp = 1;
|
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
|
||||||
constexpr ck_tile::index_t N_Warp = 4;
|
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
|
||||||
constexpr ck_tile::index_t K_Warp = 1;
|
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
|
||||||
|
|
||||||
constexpr ck_tile::index_t M_Warp_Tile = 16;
|
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
|
||||||
constexpr ck_tile::index_t N_Warp_Tile = 16;
|
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
|
||||||
constexpr ck_tile::index_t K_Warp_Tile = 32;
|
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
|
||||||
|
|
||||||
using CodegenGemmShape =
|
using CodegenGemmShape =
|
||||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||||
@@ -52,8 +50,13 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
|
|||||||
|
|
||||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||||
|
|
||||||
using CodegenGemmTraits =
|
using CodegenGemmTraits = ck_tile::TileGemmAQuantTraits<kPadM,
|
||||||
ck_tile::TileGemmAQuantTraits<kPadM, kPadN, kPadK, Preshuffle, ALayout, BLayout, CLayout>;
|
kPadN,
|
||||||
|
kPadK,
|
||||||
|
GemmConfig::PreshuffleQuant,
|
||||||
|
ALayout,
|
||||||
|
BLayout,
|
||||||
|
CLayout>;
|
||||||
|
|
||||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
|
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
|
||||||
BDataType,
|
BDataType,
|
||||||
@@ -186,13 +189,14 @@ int run_gemm_example(int argc, char* argv[])
|
|||||||
if(data_type == "fp8")
|
if(data_type == "fp8")
|
||||||
{
|
{
|
||||||
using TypeConfig =
|
using TypeConfig =
|
||||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>{});
|
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
|
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
|
||||||
a_layout, b_layout, argc, argv);
|
a_layout, b_layout, argc, argv);
|
||||||
}
|
}
|
||||||
else if(data_type == "bf8")
|
else if(data_type == "bf8")
|
||||||
{
|
{
|
||||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>{});
|
using TypeConfig =
|
||||||
|
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig, 128>(
|
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig, 128>(
|
||||||
a_layout, b_layout, argc, argv);
|
a_layout, b_layout, argc, argv);
|
||||||
}
|
}
|
||||||
@@ -200,7 +204,7 @@ int run_gemm_example(int argc, char* argv[])
|
|||||||
{
|
{
|
||||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||||
ck_tile::fp8_t,
|
ck_tile::fp8_t,
|
||||||
float,
|
ck_tile::half_t,
|
||||||
ck_tile::fp8_t>{});
|
ck_tile::fp8_t>{});
|
||||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||||
a_layout, b_layout, argc, argv);
|
a_layout, b_layout, argc, argv);
|
||||||
@@ -209,29 +213,18 @@ int run_gemm_example(int argc, char* argv[])
|
|||||||
{
|
{
|
||||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||||
ck_tile::bf8_t,
|
ck_tile::bf8_t,
|
||||||
float,
|
ck_tile::half_t,
|
||||||
ck_tile::bf8_t>{});
|
ck_tile::bf8_t>{});
|
||||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||||
a_layout, b_layout, argc, argv);
|
a_layout, b_layout, argc, argv);
|
||||||
}
|
}
|
||||||
else if(data_type == "i4f32fp8")
|
|
||||||
{
|
|
||||||
using TypeConfig =
|
|
||||||
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>{});
|
|
||||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
|
||||||
a_layout, b_layout, argc, argv);
|
|
||||||
}
|
|
||||||
else if(data_type == "i4f32bf8")
|
|
||||||
{
|
|
||||||
using TypeConfig =
|
|
||||||
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>{});
|
|
||||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
|
||||||
a_layout, b_layout, argc, argv);
|
|
||||||
}
|
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigPreshufle_AQ>(argc, argv); }
|
int main(int argc, char* argv[])
|
||||||
|
{
|
||||||
|
return !run_gemm_example<GemmConfigPreshuffleQuant>(argc, argv);
|
||||||
|
}
|
||||||
|
|||||||
@@ -11,11 +11,9 @@
|
|||||||
#include "ck_tile/ops/gemm.hpp"
|
#include "ck_tile/ops/gemm.hpp"
|
||||||
#include "ck_tile/ops/gemm_group_quant.hpp"
|
#include "ck_tile/ops/gemm_group_quant.hpp"
|
||||||
|
|
||||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
#define CK_TILE_PIPELINE_PREFILL 1
|
||||||
#define CK_TILE_PIPELINE_MEMORY 2
|
#define CK_TILE_PIPELINE_DECODE 2
|
||||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
#define CK_TILE_PIPELINE_PRESHUFFLEQUANT 3
|
||||||
#define CK_TILE_PIPELINE_COMPUTE_V5 4
|
|
||||||
#define CK_TILE_PIPELINE_PRESHUFFLE 5
|
|
||||||
|
|
||||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||||
constexpr ck_tile::index_t get_k_warp_tile()
|
constexpr ck_tile::index_t get_k_warp_tile()
|
||||||
@@ -87,196 +85,32 @@ struct GemmConfigBase
|
|||||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
|
||||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||||
static constexpr bool Preshuffle = false;
|
static constexpr bool PreshuffleQuant = false;
|
||||||
|
static constexpr bool DoubleSmemBuffer = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename PrecType>
|
template <typename PrecType>
|
||||||
struct GemmConfigMemoryInterwave : public GemmConfigBase
|
struct GemmConfigDecode : public GemmConfigBase
|
||||||
{
|
{
|
||||||
// Memory friendly for Interwave scheduler
|
static constexpr ck_tile::index_t M_Tile = 16;
|
||||||
static constexpr ck_tile::index_t M_Tile = 128;
|
static constexpr ck_tile::index_t N_Tile = 64;
|
||||||
static constexpr ck_tile::index_t N_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp = 4;
|
|
||||||
static constexpr ck_tile::index_t N_Warp = 1;
|
|
||||||
static constexpr ck_tile::index_t K_Warp = 1;
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
|
|
||||||
|
|
||||||
static constexpr bool DoubleSmemBuffer = false;
|
|
||||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
|
|
||||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename PrecType>
|
|
||||||
struct GemmConfigMemoryIntrawave : public GemmConfigBase
|
|
||||||
{
|
|
||||||
static constexpr ck_tile::index_t M_Tile = 128;
|
|
||||||
static constexpr ck_tile::index_t N_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp = 4;
|
|
||||||
static constexpr ck_tile::index_t N_Warp = 1;
|
|
||||||
static constexpr ck_tile::index_t K_Warp = 1;
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
|
|
||||||
|
|
||||||
static constexpr bool DoubleSmemBuffer = false;
|
|
||||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename PrecType>
|
|
||||||
struct GemmConfigComputeV3 : public GemmConfigBase
|
|
||||||
{
|
|
||||||
// Compute V3 only support Intrawave scheduler
|
|
||||||
static constexpr ck_tile::index_t M_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t N_Tile = 128;
|
|
||||||
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
|
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp = 1;
|
static constexpr ck_tile::index_t M_Warp = 1;
|
||||||
static constexpr ck_tile::index_t N_Warp = 4;
|
static constexpr ck_tile::index_t N_Warp = 4;
|
||||||
static constexpr ck_tile::index_t K_Warp = 1;
|
static constexpr ck_tile::index_t K_Warp = 1;
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
|
||||||
|
|
||||||
static constexpr bool DoubleSmemBuffer = false;
|
|
||||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename PrecType>
|
|
||||||
struct GemmConfigComputeV3_1 : public GemmConfigBase
|
|
||||||
{
|
|
||||||
static constexpr ck_tile::index_t M_Tile = 256;
|
|
||||||
static constexpr ck_tile::index_t N_Tile = 256;
|
|
||||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp = 2;
|
|
||||||
static constexpr ck_tile::index_t N_Warp = 2;
|
|
||||||
static constexpr ck_tile::index_t K_Warp = 1;
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
|
||||||
|
|
||||||
static constexpr bool DoubleSmemBuffer = false;
|
|
||||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename PrecType>
|
|
||||||
struct GemmConfigComputeV3_2 : public GemmConfigBase
|
|
||||||
{
|
|
||||||
static constexpr ck_tile::index_t M_Tile = 128;
|
|
||||||
static constexpr ck_tile::index_t N_Tile = 128;
|
|
||||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp = 2;
|
|
||||||
static constexpr ck_tile::index_t N_Warp = 2;
|
|
||||||
static constexpr ck_tile::index_t K_Warp = 1;
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||||
|
|
||||||
static constexpr bool DoubleSmemBuffer = false;
|
|
||||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
|
||||||
|
|
||||||
static constexpr int kBlockPerCu = 2;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename PrecType>
|
|
||||||
struct GemmConfigComputeV4 : public GemmConfigBase
|
|
||||||
{
|
|
||||||
// Compute V4 only support Intrawave scheduler
|
|
||||||
// Using the ping pong reader in the lds level
|
|
||||||
static constexpr ck_tile::index_t M_Tile = 256;
|
|
||||||
static constexpr ck_tile::index_t N_Tile = 256;
|
|
||||||
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp = 2;
|
|
||||||
static constexpr ck_tile::index_t N_Warp = 2;
|
|
||||||
static constexpr ck_tile::index_t K_Warp = 1;
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
|
||||||
|
|
||||||
static constexpr bool DoubleSmemBuffer = true;
|
|
||||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename PrecType>
|
|
||||||
struct GemmConfigComputeV4_1 : public GemmConfigBase
|
|
||||||
{
|
|
||||||
static constexpr ck_tile::index_t M_Tile = 256;
|
|
||||||
static constexpr ck_tile::index_t N_Tile = 256;
|
|
||||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp = 2;
|
|
||||||
static constexpr ck_tile::index_t N_Warp = 2;
|
|
||||||
static constexpr ck_tile::index_t K_Warp = 1;
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
|
||||||
|
|
||||||
static constexpr bool DoubleSmemBuffer = true;
|
|
||||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename PrecType>
|
|
||||||
struct GemmConfigComputeV5 : public GemmConfigBase
|
|
||||||
{
|
|
||||||
static constexpr ck_tile::index_t M_Tile = 128;
|
|
||||||
static constexpr ck_tile::index_t N_Tile = 128;
|
|
||||||
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp = 1;
|
|
||||||
static constexpr ck_tile::index_t N_Warp = 1;
|
|
||||||
static constexpr ck_tile::index_t K_Warp = 2;
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
|
||||||
|
|
||||||
static constexpr bool DoubleSmemBuffer = false;
|
|
||||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
|
|
||||||
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename PrecType>
|
|
||||||
struct GemmConfigPreshufle_1 : public GemmConfigBase
|
|
||||||
{
|
|
||||||
static constexpr ck_tile::index_t M_Tile = 128;
|
|
||||||
static constexpr ck_tile::index_t N_Tile = 128;
|
|
||||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp = 1;
|
|
||||||
static constexpr ck_tile::index_t N_Warp = 4;
|
|
||||||
static constexpr ck_tile::index_t K_Warp = 1;
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
|
||||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
|
||||||
|
|
||||||
static constexpr int kBlockPerCu = 2;
|
|
||||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
|
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_DECODE;
|
||||||
static constexpr bool Preshuffle = true;
|
|
||||||
static constexpr bool DoubleSmemBuffer = false;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename PrecType>
|
template <typename PrecType>
|
||||||
struct GemmConfigPreshufle_2 : public GemmConfigBase
|
struct GemmConfigPrefill : public GemmConfigBase
|
||||||
{
|
{
|
||||||
static constexpr ck_tile::index_t M_Tile = 128;
|
static constexpr ck_tile::index_t M_Tile = 128;
|
||||||
static constexpr ck_tile::index_t N_Tile = 128;
|
static constexpr ck_tile::index_t N_Tile = 128;
|
||||||
@@ -288,18 +122,15 @@ struct GemmConfigPreshufle_2 : public GemmConfigBase
|
|||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
|
||||||
|
|
||||||
static constexpr int kBlockPerCu = 2;
|
static constexpr int kBlockPerCu = 2;
|
||||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
|
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PREFILL;
|
||||||
static constexpr bool Preshuffle = true;
|
|
||||||
static constexpr bool DoubleSmemBuffer = false;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename PrecType>
|
template <typename PrecType>
|
||||||
struct GemmConfigPreshufle_AQ : public GemmConfigBase
|
struct GemmConfigPreshuffleQuant : public GemmConfigBase
|
||||||
{
|
{
|
||||||
static constexpr ck_tile::index_t M_Tile = 16;
|
static constexpr ck_tile::index_t M_Tile = 16;
|
||||||
static constexpr ck_tile::index_t N_Tile = 64;
|
static constexpr ck_tile::index_t N_Tile = 64;
|
||||||
@@ -314,9 +145,9 @@ struct GemmConfigPreshufle_AQ : public GemmConfigBase
|
|||||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
||||||
|
|
||||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
|
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||||
static constexpr bool Preshuffle = true;
|
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLEQUANT;
|
||||||
static constexpr bool DoubleSmemBuffer = false;
|
static constexpr bool PreshuffleQuant = true;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename ADataType_,
|
template <typename ADataType_,
|
||||||
@@ -332,176 +163,6 @@ struct GemmQuantTypeConfig
|
|||||||
using CDataType = CDataType_;
|
using CDataType = CDataType_;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::half_t>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::half_t;
|
|
||||||
using QDataType = float;
|
|
||||||
using BDataType = ck_tile::half_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = ck_tile::half_t;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::bf16_t;
|
|
||||||
using QDataType = float;
|
|
||||||
using BDataType = ck_tile::bf16_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = ck_tile::bf16_t;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::fp8_t;
|
|
||||||
using QDataType = float;
|
|
||||||
using BDataType = ck_tile::fp8_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = ck_tile::half_t;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::bf8_t;
|
|
||||||
using QDataType = float;
|
|
||||||
using BDataType = ck_tile::bf8_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = ck_tile::half_t;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::half_t;
|
|
||||||
using QDataType = float;
|
|
||||||
using BDataType = ck_tile::pk_int4_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = ck_tile::half_t;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, float>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::fp8_t;
|
|
||||||
using QDataType = float;
|
|
||||||
using BDataType = ck_tile::fp8_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = ck_tile::half_t;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::bf8_t;
|
|
||||||
using QDataType = float;
|
|
||||||
using BDataType = ck_tile::bf8_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = ck_tile::half_t;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, ck_tile::fp8_t>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::pk_int4_t;
|
|
||||||
using QDataType = ck_tile::fp8_t;
|
|
||||||
using BDataType = ck_tile::fp8_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = ck_tile::half_t;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, float, ck_tile::fp8_t>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::fp8_t;
|
|
||||||
using QDataType = ck_tile::fp8_t;
|
|
||||||
using BDataType = ck_tile::fp8_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = ck_tile::half_t;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float, ck_tile::bf8_t>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::bf8_t;
|
|
||||||
using QDataType = ck_tile::bf8_t;
|
|
||||||
using BDataType = ck_tile::bf8_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = ck_tile::half_t;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, ck_tile::bf8_t>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::pk_int4_t;
|
|
||||||
using QDataType = ck_tile::bf8_t;
|
|
||||||
using BDataType = ck_tile::bf8_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = ck_tile::half_t;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::pk_int4_t;
|
|
||||||
using QDataType = float;
|
|
||||||
using BDataType = ck_tile::fp8_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = ck_tile::half_t;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::pk_int4_t;
|
|
||||||
using QDataType = float;
|
|
||||||
using BDataType = ck_tile::bf8_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = ck_tile::half_t;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, float, ck_tile::fp8_t>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::fp8_t;
|
|
||||||
using QDataType = ck_tile::fp8_t;
|
|
||||||
using BDataType = ck_tile::pk_int4_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = ck_tile::half_t;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, float, ck_tile::bf8_t>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::bf8_t;
|
|
||||||
using QDataType = ck_tile::bf8_t;
|
|
||||||
using BDataType = ck_tile::pk_int4_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = ck_tile::half_t;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, float, float>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::fp8_t;
|
|
||||||
using QDataType = float;
|
|
||||||
using BDataType = ck_tile::pk_int4_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = ck_tile::half_t;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, float, float>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::bf8_t;
|
|
||||||
using QDataType = float;
|
|
||||||
using BDataType = ck_tile::pk_int4_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = ck_tile::half_t;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct DataTypeTraits;
|
struct DataTypeTraits;
|
||||||
|
|
||||||
@@ -559,55 +220,6 @@ struct DataTypeTraits<ck_tile::int8_t>
|
|||||||
static constexpr const char* name = "int8";
|
static constexpr const char* name = "int8";
|
||||||
};
|
};
|
||||||
|
|
||||||
template <ck_tile::index_t PipelineId>
|
|
||||||
struct PipelineTypeTraits;
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
|
||||||
{
|
|
||||||
template <typename PipelineProblem>
|
|
||||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
|
|
||||||
template <typename PipelineProblem>
|
|
||||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
|
||||||
{
|
|
||||||
template <typename PipelineProblem>
|
|
||||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
|
||||||
template <typename PipelineProblem>
|
|
||||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
|
||||||
{
|
|
||||||
template <typename PipelineProblem>
|
|
||||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
|
|
||||||
template <typename PipelineProblem>
|
|
||||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
|
|
||||||
{
|
|
||||||
template <typename PipelineProblem>
|
|
||||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
|
|
||||||
template <typename PipelineProblem>
|
|
||||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5<PipelineProblem>;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE>
|
|
||||||
{
|
|
||||||
template <typename PipelineProblem>
|
|
||||||
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
|
|
||||||
template <typename PipelineProblem>
|
|
||||||
using UniversalGemmPipeline =
|
|
||||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
|
|
||||||
};
|
|
||||||
|
|
||||||
auto create_args(int argc, char* argv[])
|
auto create_args(int argc, char* argv[])
|
||||||
{
|
{
|
||||||
ck_tile::ArgParser arg_parser;
|
ck_tile::ArgParser arg_parser;
|
||||||
|
|||||||
@@ -31,7 +31,8 @@ auto shuffle_aq(const ck_tile::HostTensor<T>& t, int block_aq_k)
|
|||||||
return ck_tile::reference_permute(t_view, {1, 0, 2});
|
return ck_tile::reference_permute(t_view, {1, 0, 2});
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename ADataType,
|
template <typename GemmConfig,
|
||||||
|
typename ADataType,
|
||||||
typename AQDataType,
|
typename AQDataType,
|
||||||
typename BDataType,
|
typename BDataType,
|
||||||
typename AccDataType,
|
typename AccDataType,
|
||||||
@@ -40,8 +41,7 @@ template <typename ADataType,
|
|||||||
typename AQLayout,
|
typename AQLayout,
|
||||||
typename BLayout,
|
typename BLayout,
|
||||||
typename CLayout,
|
typename CLayout,
|
||||||
uint32_t QuantGroupSize,
|
uint32_t QuantGroupSize>
|
||||||
bool Preshuffle = false>
|
|
||||||
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||||
ck_tile::DeviceMem& aq_m_aqk_dev_buf,
|
ck_tile::DeviceMem& aq_m_aqk_dev_buf,
|
||||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||||
@@ -73,7 +73,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
|||||||
args.stride_C = stride_C;
|
args.stride_C = stride_C;
|
||||||
args.stride_AQ = stride_AQ;
|
args.stride_AQ = stride_AQ;
|
||||||
|
|
||||||
float ave_time = gemm_calc_aquant<ADataType,
|
float ave_time = gemm_calc_aquant<GemmConfig,
|
||||||
|
ADataType,
|
||||||
AQDataType,
|
AQDataType,
|
||||||
BDataType,
|
BDataType,
|
||||||
AccDataType,
|
AccDataType,
|
||||||
@@ -82,8 +83,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
|||||||
ALayout,
|
ALayout,
|
||||||
BLayout,
|
BLayout,
|
||||||
CLayout,
|
CLayout,
|
||||||
QuantGroupSize,
|
QuantGroupSize>(
|
||||||
Preshuffle>(
|
|
||||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||||
|
|
||||||
std::size_t flop = std::size_t(2) * M * N * K;
|
std::size_t flop = std::size_t(2) * M * N * K;
|
||||||
@@ -206,7 +206,7 @@ int run_gemm_example_with_layouts(int argc,
|
|||||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||||
|
|
||||||
if constexpr(GemmConfig::Preshuffle)
|
if constexpr(GemmConfig::PreshuffleQuant)
|
||||||
{
|
{
|
||||||
ck_tile::HostTensor<AQDataType> aq_shuffle_host =
|
ck_tile::HostTensor<AQDataType> aq_shuffle_host =
|
||||||
shuffle_aq(aq_m_aqk, GemmConfig::K_Tile / QuantGroupSize);
|
shuffle_aq(aq_m_aqk, GemmConfig::K_Tile / QuantGroupSize);
|
||||||
@@ -222,7 +222,8 @@ int run_gemm_example_with_layouts(int argc,
|
|||||||
c_m_n_dev_buf.SetZero();
|
c_m_n_dev_buf.SetZero();
|
||||||
c_m_n_dev_result.SetZero();
|
c_m_n_dev_result.SetZero();
|
||||||
|
|
||||||
invoke_gemm<ADataType,
|
invoke_gemm<GemmConfig,
|
||||||
|
ADataType,
|
||||||
AQDataType,
|
AQDataType,
|
||||||
BDataType,
|
BDataType,
|
||||||
AccDataType,
|
AccDataType,
|
||||||
@@ -231,22 +232,21 @@ int run_gemm_example_with_layouts(int argc,
|
|||||||
AQLayout,
|
AQLayout,
|
||||||
BLayout,
|
BLayout,
|
||||||
CLayout,
|
CLayout,
|
||||||
QuantGroupSize,
|
QuantGroupSize>(a_m_k_dev_buf,
|
||||||
GemmConfig::Preshuffle>(a_m_k_dev_buf,
|
aq_m_aqk_dev_buf,
|
||||||
aq_m_aqk_dev_buf,
|
b_k_n_dev_buf,
|
||||||
b_k_n_dev_buf,
|
c_m_n_dev_buf,
|
||||||
c_m_n_dev_buf,
|
M,
|
||||||
M,
|
N,
|
||||||
N,
|
K,
|
||||||
K,
|
AQK,
|
||||||
AQK,
|
stride_A,
|
||||||
stride_A,
|
stride_AQ,
|
||||||
stride_AQ,
|
stride_B,
|
||||||
stride_B,
|
stride_C,
|
||||||
stride_C,
|
kbatch,
|
||||||
kbatch,
|
n_warmup,
|
||||||
n_warmup,
|
n_repeat);
|
||||||
n_repeat);
|
|
||||||
|
|
||||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||||
bool pass = true;
|
bool pass = true;
|
||||||
|
|||||||
@@ -157,7 +157,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
|
|||||||
static constexpr index_t KPack = WarpGemm::kKPerThread;
|
static constexpr index_t KPack = WarpGemm::kKPerThread;
|
||||||
static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
|
static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
|
||||||
|
|
||||||
static constexpr bool Preshuffle = Problem::Traits::Preshuffle;
|
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||||
};
|
};
|
||||||
|
|
||||||
public:
|
public:
|
||||||
@@ -357,7 +357,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
if constexpr(Traits::Preshuffle)
|
if constexpr(Traits::PreshuffleQuant)
|
||||||
{
|
{
|
||||||
// A view is created on top of the preshuffled AQ, where each row of the
|
// A view is created on top of the preshuffled AQ, where each row of the
|
||||||
// view is composed of a row from a warp tile within an AQ block tile.
|
// view is composed of a row from a warp tile within an AQ block tile.
|
||||||
@@ -392,12 +392,27 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
|
|||||||
|
|
||||||
// Thread 0 can read AQ_tile[0, 0] from itself, AQ_tile[1, 0]
|
// Thread 0 can read AQ_tile[0, 0] from itself, AQ_tile[1, 0]
|
||||||
// from thread 1, ..., and AQ_tile[3, 0] from thread 3.
|
// from thread 1, ..., and AQ_tile[3, 0] from thread 3.
|
||||||
auto pull_from_lane =
|
decltype(threadIdx.x) pull_from_lane = 0;
|
||||||
((threadIdx.x & (warp_size - 1)) / Traits::WarpGemm::kN *
|
if constexpr(WarpGemm::kM == 16)
|
||||||
kTileRowsOfCPerThread +
|
{
|
||||||
c_row) *
|
pull_from_lane = (__lane_id() / Traits::WarpGemm::kN *
|
||||||
Traits::QScalesPerBlockRow +
|
kTileRowsOfCPerThread +
|
||||||
kQScale;
|
c_row) *
|
||||||
|
Traits::QScalesPerBlockRow +
|
||||||
|
kQScale;
|
||||||
|
}
|
||||||
|
else if constexpr(WarpGemm::kM == 32)
|
||||||
|
{
|
||||||
|
pull_from_lane = (__lane_id() / Traits::WarpGemm::kN *
|
||||||
|
kTileRowsOfCPerThread +
|
||||||
|
((c_row >> 2) << 3) + (c_row & 0b11)) *
|
||||||
|
Traits::QScalesPerBlockRow +
|
||||||
|
kQScale;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
static_assert(false, "WarpGemm::kM is not 16 nor 32.");
|
||||||
|
}
|
||||||
auto& scale_reg = aq_block_tensor.get_thread_buffer()[mIter];
|
auto& scale_reg = aq_block_tensor.get_thread_buffer()[mIter];
|
||||||
|
|
||||||
// cross lane ops
|
// cross lane ops
|
||||||
|
|||||||
@@ -99,15 +99,15 @@ struct AQuantGemmKernelArgs
|
|||||||
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||||
struct AQuantGemmKernel
|
struct AQuantGemmKernel
|
||||||
{
|
{
|
||||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||||
using AQLayout = remove_cvref_t<typename GemmPipeline::AQLayout>;
|
using AQLayout = remove_cvref_t<typename GemmPipeline::AQLayout>;
|
||||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||||
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
|
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
|
||||||
static constexpr bool Preshuffle = GemmPipeline::Preshuffle;
|
static constexpr bool PreshuffleQuant = GemmPipeline::PreshuffleQuant;
|
||||||
|
|
||||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||||
using AQDataType = remove_cvref_t<typename GemmPipeline::AQDataType>;
|
using AQDataType = remove_cvref_t<typename GemmPipeline::AQDataType>;
|
||||||
@@ -422,9 +422,9 @@ struct AQuantGemmKernel
|
|||||||
ck_tile::integer_least_multiple(wave_tile_size, get_warp_size());
|
ck_tile::integer_least_multiple(wave_tile_size, get_warp_size());
|
||||||
const auto aq_merge_pad1_desc = transform_tensor_descriptor(
|
const auto aq_merge_pad1_desc = transform_tensor_descriptor(
|
||||||
aq_pad1_desc,
|
aq_pad1_desc,
|
||||||
make_tuple(make_merge_transform(make_tuple(wave_tile_count_x, aq_y)),
|
make_tuple(make_merge_transform(make_tuple(aq_y, wave_tile_count_x)),
|
||||||
make_pass_through_transform(pad_wave_size)),
|
make_pass_through_transform(pad_wave_size)),
|
||||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||||
|
|
||||||
return make_tensor_view<address_space_enum::global>(aq_ptr, aq_merge_pad1_desc);
|
return make_tensor_view<address_space_enum::global>(aq_ptr, aq_merge_pad1_desc);
|
||||||
@@ -432,7 +432,7 @@ struct AQuantGemmKernel
|
|||||||
|
|
||||||
const auto& aq_tensor_view = [&]() {
|
const auto& aq_tensor_view = [&]() {
|
||||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||||
if constexpr(Preshuffle)
|
if constexpr(PreshuffleQuant)
|
||||||
{
|
{
|
||||||
return make_preshuffled_aq_tensor_view();
|
return make_preshuffled_aq_tensor_view();
|
||||||
}
|
}
|
||||||
@@ -599,10 +599,8 @@ struct AQuantGemmKernel
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename PadView>
|
template <typename PadView>
|
||||||
CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
|
CK_TILE_DEVICE static auto
|
||||||
const AQuantGemmKernelArgs& kargs,
|
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
|
||||||
const index_t i_m,
|
|
||||||
const index_t i_n)
|
|
||||||
{
|
{
|
||||||
const auto& a_pad_view = views.at(I0);
|
const auto& a_pad_view = views.at(I0);
|
||||||
const auto& aq_pad_view = views.at(I1);
|
const auto& aq_pad_view = views.at(I1);
|
||||||
@@ -628,24 +626,27 @@ struct AQuantGemmKernel
|
|||||||
|
|
||||||
const auto& aq_block_window = [&]() {
|
const auto& aq_block_window = [&]() {
|
||||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||||
if constexpr(Preshuffle)
|
constexpr auto block_m = TilePartitioner::MPerBlock;
|
||||||
|
constexpr auto block_k = TilePartitioner::KPerBlock;
|
||||||
|
constexpr auto warp_m = TilePartitioner::BlockGemmShape::WarpTile::at(I0);
|
||||||
|
constexpr auto aqk_per_block =
|
||||||
|
TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize;
|
||||||
|
if constexpr(PreshuffleQuant)
|
||||||
{
|
{
|
||||||
constexpr auto tile_window_width = get_warp_size();
|
constexpr auto tile_window_width =
|
||||||
constexpr auto tile_window_height =
|
ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size());
|
||||||
TilePartitioner::MPerBlock / TilePartitioner::BlockGemmShape::WarpTile::at(I0);
|
constexpr auto tile_window_height = block_m / warp_m;
|
||||||
auto block_m_idx = i_m / TilePartitioner::MPerBlock;
|
auto block_m_idx = i_m / block_m;
|
||||||
return make_tile_window(
|
return make_tile_window(
|
||||||
aq_pad_view,
|
aq_pad_view,
|
||||||
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
|
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
|
||||||
{block_m_idx * kargs.K / TilePartitioner::BlockGemmShape::BlockTile::at(I2),
|
{block_m_idx * tile_window_height, 0});
|
||||||
0});
|
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
return make_tile_window(
|
return make_tile_window(
|
||||||
aq_pad_view,
|
aq_pad_view,
|
||||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
make_tuple(number<block_m>{}, number<block_k / GemmPipeline::QuantGroupSize>{}),
|
||||||
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
|
|
||||||
{i_m, 0});
|
{i_m, 0});
|
||||||
}
|
}
|
||||||
}();
|
}();
|
||||||
@@ -706,8 +707,7 @@ struct AQuantGemmKernel
|
|||||||
a_ptr, b_ptr, aq_ptr, c_ptr, kargs, splitk_batch_offset);
|
a_ptr, b_ptr, aq_ptr, c_ptr, kargs, splitk_batch_offset);
|
||||||
|
|
||||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||||
auto gemm_tile_windows =
|
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||||
MakeGemmTileWindows(gemm_pad_views, kargs, block_idx_m, block_idx_n);
|
|
||||||
|
|
||||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(
|
const index_t num_loop = __builtin_amdgcn_readfirstlane(
|
||||||
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||||
@@ -718,7 +718,7 @@ struct AQuantGemmKernel
|
|||||||
const auto& b_block_window = gemm_tile_windows.at(I2);
|
const auto& b_block_window = gemm_tile_windows.at(I2);
|
||||||
|
|
||||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||||
a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr_0);
|
a_block_window, b_block_window, aq_block_window, kargs.M, num_loop, smem_ptr_0);
|
||||||
|
|
||||||
// Run Epilogue Pipeline
|
// Run Epilogue Pipeline
|
||||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||||
|
|||||||
@@ -37,23 +37,23 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
|||||||
using AQLayout = remove_cvref_t<typename Problem::AQLayout>;
|
using AQLayout = remove_cvref_t<typename Problem::AQLayout>;
|
||||||
using BlockGemmShape = typename Problem::BlockGemmShape;
|
using BlockGemmShape = typename Problem::BlockGemmShape;
|
||||||
|
|
||||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||||
constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize;
|
constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize;
|
||||||
constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
|
constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
|
||||||
constexpr bool Preshuffle = Problem::Traits::Preshuffle;
|
constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
||||||
typename Problem::ComputeDataType,
|
typename Problem::ComputeDataType,
|
||||||
typename Problem::CDataType,
|
typename Problem::CDataType,
|
||||||
WarpTile::at(I0),
|
WarpTile::at(I0),
|
||||||
WarpTile::at(I1),
|
WarpTile::at(I1),
|
||||||
WarpTile::at(I2),
|
WarpTile::at(I2),
|
||||||
false>;
|
false>;
|
||||||
|
|
||||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||||
if constexpr(Preshuffle)
|
if constexpr(PreshuffleQuant)
|
||||||
{
|
{
|
||||||
using TileEncodingPattern =
|
using TileEncodingPattern =
|
||||||
TileDistributionEncodingPatternAQ<BlockGemmShape,
|
TileDistributionEncodingPatternAQ<BlockGemmShape,
|
||||||
@@ -64,7 +64,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
|||||||
WarpGemm::kM * KPerBlockAQ, get_warp_size()),
|
WarpGemm::kM * KPerBlockAQ, get_warp_size()),
|
||||||
KPerBlockAQ,
|
KPerBlockAQ,
|
||||||
VecLoadSize,
|
VecLoadSize,
|
||||||
Preshuffle>;
|
PreshuffleQuant>;
|
||||||
|
|
||||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||||
}
|
}
|
||||||
@@ -77,7 +77,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
|||||||
KPerBlockAQ,
|
KPerBlockAQ,
|
||||||
KPerBlockAQ,
|
KPerBlockAQ,
|
||||||
VecLoadSize,
|
VecLoadSize,
|
||||||
Preshuffle>;
|
PreshuffleQuant>;
|
||||||
|
|
||||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@
|
|||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include "ck_tile/core.hpp"
|
#include "ck_tile/core.hpp"
|
||||||
|
#include "ck_tile/core/numeric/math.hpp"
|
||||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
|
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
|
||||||
#include "ck_tile/host/concat.hpp"
|
#include "ck_tile/host/concat.hpp"
|
||||||
@@ -133,7 +134,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
|
|||||||
static constexpr bool kPadK = Problem::kPadK;
|
static constexpr bool kPadK = Problem::kPadK;
|
||||||
|
|
||||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||||
static constexpr bool Preshuffle = Problem::Traits::Preshuffle;
|
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||||
|
|
||||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||||
static constexpr auto TailNum = Problem::TailNum;
|
static constexpr auto TailNum = Problem::TailNum;
|
||||||
@@ -235,6 +236,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
|
|||||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||||
const BElementFunction& b_element_func,
|
const BElementFunction& b_element_func,
|
||||||
const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
|
const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
|
||||||
|
index_t m,
|
||||||
index_t num_loop,
|
index_t num_loop,
|
||||||
void* p_smem) const
|
void* p_smem) const
|
||||||
{
|
{
|
||||||
@@ -311,9 +313,11 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
|
|||||||
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||||
|
|
||||||
// only row_major for AQ
|
// only row_major for AQ
|
||||||
constexpr AQDramTileWindowStep aq_dram_tile_window_step =
|
const AQDramTileWindowStep aq_dram_tile_window_step =
|
||||||
Preshuffle ? make_array(MPerBlock / BlockGemm::WarpGemm::kM, 0)
|
PreshuffleQuant ? make_array(ck_tile::integer_least_multiple(m, MPerBlock) /
|
||||||
: make_array(0, KPerBlockAQ);
|
BlockGemm::WarpGemm::kM,
|
||||||
|
0)
|
||||||
|
: make_array(0, KPerBlockAQ);
|
||||||
|
|
||||||
// DRAM prefetch (global read 0)
|
// DRAM prefetch (global read 0)
|
||||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||||
@@ -458,6 +462,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
|
|||||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||||
const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
|
const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
|
||||||
|
index_t m,
|
||||||
index_t num_loop,
|
index_t num_loop,
|
||||||
void* p_smem) const
|
void* p_smem) const
|
||||||
{
|
{
|
||||||
@@ -467,6 +472,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
|
|||||||
b_dram_block_window_tmp,
|
b_dram_block_window_tmp,
|
||||||
[](const BDataType& b) { return b; },
|
[](const BDataType& b) { return b; },
|
||||||
aq_dram_block_window_tmp,
|
aq_dram_block_window_tmp,
|
||||||
|
m,
|
||||||
num_loop,
|
num_loop,
|
||||||
p_smem);
|
p_smem);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ template <typename BlockGemmShape,
|
|||||||
index_t XPerTile,
|
index_t XPerTile,
|
||||||
index_t KPerBlockAQ,
|
index_t KPerBlockAQ,
|
||||||
index_t VecSize,
|
index_t VecSize,
|
||||||
bool Preshuffle>
|
bool PreshuffleQuant>
|
||||||
struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPattern
|
struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPattern
|
||||||
{
|
{
|
||||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||||
@@ -72,20 +72,20 @@ struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPatter
|
|||||||
|
|
||||||
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
|
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
|
||||||
{
|
{
|
||||||
if constexpr(Preshuffle)
|
if constexpr(PreshuffleQuant)
|
||||||
{
|
{
|
||||||
// # of elements per thread
|
// # of elements per thread
|
||||||
constexpr index_t X2 = KPerBlockAQ;
|
static_assert(XPerTile >= warp_size && XPerTile % warp_size == 0);
|
||||||
constexpr index_t X1 = warp_size / X2;
|
constexpr index_t X1 = warp_size;
|
||||||
constexpr index_t X0 = XPerTile / warp_size;
|
constexpr index_t X0 = XPerTile / warp_size;
|
||||||
|
|
||||||
constexpr index_t Y1 = MWarps;
|
constexpr index_t Y1 = MWarps;
|
||||||
constexpr index_t Y0 = YPerTile / Y1;
|
constexpr index_t Y0 = YPerTile / Y1;
|
||||||
return make_static_tile_distribution(
|
return make_static_tile_distribution(
|
||||||
tile_distribution_encoding<sequence<NWarps>,
|
tile_distribution_encoding<sequence<NWarps>,
|
||||||
tuple<sequence<Y0, Y1>, sequence<X0, X1, X2>>,
|
tuple<sequence<Y0, Y1>, sequence<X0, X1>>,
|
||||||
tuple<sequence<1, 0>, sequence<2, 2>>,
|
tuple<sequence<1, 0>, sequence<2>>,
|
||||||
tuple<sequence<1, 0>, sequence<1, 2>>,
|
tuple<sequence<1, 0>, sequence<1>>,
|
||||||
sequence<1, 2>,
|
sequence<1, 2>,
|
||||||
sequence<0, 0>>{});
|
sequence<0, 0>>{});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ namespace ck_tile {
|
|||||||
template <bool kPadM_,
|
template <bool kPadM_,
|
||||||
bool kPadN_,
|
bool kPadN_,
|
||||||
bool kPadK_,
|
bool kPadK_,
|
||||||
bool Preshuffle_,
|
bool PreshuffleQuant_,
|
||||||
typename ALayout_,
|
typename ALayout_,
|
||||||
typename BLayout_,
|
typename BLayout_,
|
||||||
typename CLayout_,
|
typename CLayout_,
|
||||||
@@ -30,7 +30,7 @@ struct TileGemmAQuantTraits
|
|||||||
|
|
||||||
static constexpr bool UseStructuredSparsity = false;
|
static constexpr bool UseStructuredSparsity = false;
|
||||||
static constexpr index_t NumWaveGroups = 1;
|
static constexpr index_t NumWaveGroups = 1;
|
||||||
static constexpr bool Preshuffle = Preshuffle_;
|
static constexpr bool PreshuffleQuant = PreshuffleQuant_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace ck_tile
|
} // namespace ck_tile
|
||||||
|
|||||||
@@ -11,11 +11,9 @@
|
|||||||
#include "ck_tile/ops/gemm.hpp"
|
#include "ck_tile/ops/gemm.hpp"
|
||||||
#include "ck_tile/ops/gemm_group_quant.hpp"
|
#include "ck_tile/ops/gemm_group_quant.hpp"
|
||||||
|
|
||||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
#define CK_TILE_PIPELINE_PREFILL 1
|
||||||
#define CK_TILE_PIPELINE_MEMORY 2
|
#define CK_TILE_PIPELINE_DECODE 2
|
||||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
#define CK_TILE_PIPELINE_PRESHUFFLEQUANT 3
|
||||||
#define CK_TILE_PIPELINE_COMPUTE_V5 4
|
|
||||||
#define CK_TILE_PIPELINE_PRESHUFFLE 5
|
|
||||||
|
|
||||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||||
constexpr ck_tile::index_t get_k_warp_tile()
|
constexpr ck_tile::index_t get_k_warp_tile()
|
||||||
@@ -34,21 +32,6 @@ constexpr ck_tile::index_t get_k_warp_tile()
|
|||||||
return 32;
|
return 32;
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
|
||||||
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
|
|
||||||
{
|
|
||||||
#if defined(__gfx950__)
|
|
||||||
if constexpr(M_Warp_Tile == 32)
|
|
||||||
return sizeof(PrecType) == 2 ? 16 : 64;
|
|
||||||
else
|
|
||||||
return sizeof(PrecType) == 2 ? 32 : 128;
|
|
||||||
#else
|
|
||||||
if constexpr(M_Warp_Tile == 32)
|
|
||||||
return sizeof(PrecType) == 2 ? 16 : 32;
|
|
||||||
else
|
|
||||||
return sizeof(PrecType) == 2 ? 32 : 64;
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||||
@@ -93,195 +76,32 @@ struct GemmConfigBase
|
|||||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
|
||||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||||
static constexpr bool Preshuffle = false;
|
static constexpr bool PreshuffleQuant = false;
|
||||||
|
static constexpr bool DoubleSmemBuffer = true;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename PrecType>
|
template <typename PrecType>
|
||||||
struct GemmConfigMemoryInterwave : public GemmConfigBase
|
struct GemmConfigDecode : public GemmConfigBase
|
||||||
{
|
{
|
||||||
// Memory friendly for Interwave scheduler
|
static constexpr ck_tile::index_t M_Tile = 16;
|
||||||
static constexpr ck_tile::index_t M_Tile = 128;
|
static constexpr ck_tile::index_t N_Tile = 64;
|
||||||
static constexpr ck_tile::index_t N_Tile = 32;
|
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
|
||||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp = 4;
|
|
||||||
static constexpr ck_tile::index_t N_Warp = 1;
|
|
||||||
static constexpr ck_tile::index_t K_Warp = 1;
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
|
|
||||||
|
|
||||||
static constexpr bool DoubleSmemBuffer = false;
|
|
||||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
|
|
||||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename PrecType>
|
|
||||||
struct GemmConfigMemoryIntrawave : public GemmConfigBase
|
|
||||||
{
|
|
||||||
static constexpr ck_tile::index_t M_Tile = 128;
|
|
||||||
static constexpr ck_tile::index_t N_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp = 4;
|
|
||||||
static constexpr ck_tile::index_t N_Warp = 1;
|
|
||||||
static constexpr ck_tile::index_t K_Warp = 1;
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
|
|
||||||
|
|
||||||
static constexpr bool DoubleSmemBuffer = false;
|
|
||||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename PrecType>
|
|
||||||
struct GemmConfigComputeV3 : public GemmConfigBase
|
|
||||||
{
|
|
||||||
// Compute V3 only support Intrawave scheduler
|
|
||||||
static constexpr ck_tile::index_t M_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t N_Tile = 128;
|
|
||||||
static constexpr ck_tile::index_t K_Tile = 256;
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp = 1;
|
static constexpr ck_tile::index_t M_Warp = 1;
|
||||||
static constexpr ck_tile::index_t N_Warp = 4;
|
static constexpr ck_tile::index_t N_Warp = 4;
|
||||||
static constexpr ck_tile::index_t K_Warp = 1;
|
static constexpr ck_tile::index_t K_Warp = 1;
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
|
||||||
|
|
||||||
static constexpr bool DoubleSmemBuffer = false;
|
|
||||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename PrecType>
|
|
||||||
struct GemmConfigComputeV3_1 : public GemmConfigBase
|
|
||||||
{
|
|
||||||
static constexpr ck_tile::index_t M_Tile = 256;
|
|
||||||
static constexpr ck_tile::index_t N_Tile = 256;
|
|
||||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp = 2;
|
|
||||||
static constexpr ck_tile::index_t N_Warp = 2;
|
|
||||||
static constexpr ck_tile::index_t K_Warp = 1;
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
|
||||||
|
|
||||||
static constexpr bool DoubleSmemBuffer = false;
|
|
||||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename PrecType>
|
|
||||||
struct GemmConfigComputeV3_2 : public GemmConfigBase
|
|
||||||
{
|
|
||||||
static constexpr ck_tile::index_t M_Tile = 128;
|
|
||||||
static constexpr ck_tile::index_t N_Tile = 128;
|
|
||||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp = 2;
|
|
||||||
static constexpr ck_tile::index_t N_Warp = 2;
|
|
||||||
static constexpr ck_tile::index_t K_Warp = 1;
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||||
|
|
||||||
static constexpr bool DoubleSmemBuffer = false;
|
|
||||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
|
||||||
|
|
||||||
static constexpr int kBlockPerCu = 2;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename PrecType>
|
|
||||||
struct GemmConfigComputeV4 : public GemmConfigBase
|
|
||||||
{
|
|
||||||
// Compute V4 only support Intrawave scheduler
|
|
||||||
// Using the ping pong reader in the lds level
|
|
||||||
static constexpr ck_tile::index_t M_Tile = 256;
|
|
||||||
static constexpr ck_tile::index_t N_Tile = 256;
|
|
||||||
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp = 2;
|
|
||||||
static constexpr ck_tile::index_t N_Warp = 2;
|
|
||||||
static constexpr ck_tile::index_t K_Warp = 1;
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
|
||||||
|
|
||||||
static constexpr bool DoubleSmemBuffer = true;
|
|
||||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename PrecType>
|
|
||||||
struct GemmConfigComputeV4_1 : public GemmConfigBase
|
|
||||||
{
|
|
||||||
static constexpr ck_tile::index_t M_Tile = 256;
|
|
||||||
static constexpr ck_tile::index_t N_Tile = 256;
|
|
||||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp = 2;
|
|
||||||
static constexpr ck_tile::index_t N_Warp = 2;
|
|
||||||
static constexpr ck_tile::index_t K_Warp = 1;
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
|
||||||
|
|
||||||
static constexpr bool DoubleSmemBuffer = true;
|
|
||||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename PrecType>
|
|
||||||
struct GemmConfigComputeV5 : public GemmConfigBase
|
|
||||||
{
|
|
||||||
static constexpr ck_tile::index_t M_Tile = 128;
|
|
||||||
static constexpr ck_tile::index_t N_Tile = 128;
|
|
||||||
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp = 1;
|
|
||||||
static constexpr ck_tile::index_t N_Warp = 1;
|
|
||||||
static constexpr ck_tile::index_t K_Warp = 2;
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
|
||||||
|
|
||||||
static constexpr bool DoubleSmemBuffer = false;
|
|
||||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
|
|
||||||
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename PrecType>
|
|
||||||
struct GemmConfigPreshufle_1 : public GemmConfigBase
|
|
||||||
{
|
|
||||||
static constexpr ck_tile::index_t M_Tile = 128;
|
|
||||||
static constexpr ck_tile::index_t N_Tile = 128;
|
|
||||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp = 1;
|
|
||||||
static constexpr ck_tile::index_t N_Warp = 4;
|
|
||||||
static constexpr ck_tile::index_t K_Warp = 1;
|
|
||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
|
||||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
|
||||||
|
|
||||||
static constexpr int kBlockPerCu = 2;
|
|
||||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
|
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_DECODE;
|
||||||
static constexpr bool Preshuffle = true;
|
|
||||||
static constexpr bool DoubleSmemBuffer = false;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename PrecType>
|
template <typename PrecType>
|
||||||
struct GemmConfigPreshufle_2 : public GemmConfigBase
|
struct GemmConfigPrefill : public GemmConfigBase
|
||||||
{
|
{
|
||||||
static constexpr ck_tile::index_t M_Tile = 128;
|
static constexpr ck_tile::index_t M_Tile = 128;
|
||||||
static constexpr ck_tile::index_t N_Tile = 128;
|
static constexpr ck_tile::index_t N_Tile = 128;
|
||||||
@@ -293,71 +113,32 @@ struct GemmConfigPreshufle_2 : public GemmConfigBase
|
|||||||
|
|
||||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||||
|
|
||||||
static constexpr int kBlockPerCu = 2;
|
static constexpr int kBlockPerCu = 2;
|
||||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
|
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PREFILL;
|
||||||
static constexpr bool Preshuffle = true;
|
|
||||||
static constexpr bool DoubleSmemBuffer = false;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
|
template <typename PrecType>
|
||||||
struct GemmTypeConfig;
|
struct GemmConfigPreshuffleQuant : public GemmConfigBase
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmTypeConfig<ck_tile::half_t>
|
|
||||||
{
|
{
|
||||||
using ADataType = ck_tile::half_t;
|
static constexpr ck_tile::index_t M_Tile = 16;
|
||||||
using BDataType = ck_tile::half_t;
|
static constexpr ck_tile::index_t N_Tile = 64;
|
||||||
using AccDataType = float;
|
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
|
||||||
using CDataType = ck_tile::half_t;
|
|
||||||
// ToDo: Add more bias config to support different categories of GEMM.
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
static constexpr ck_tile::index_t M_Warp = 1;
|
||||||
struct GemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t>
|
static constexpr ck_tile::index_t N_Warp = 4;
|
||||||
{
|
static constexpr ck_tile::index_t K_Warp = 1;
|
||||||
using ADataType = ck_tile::bf16_t;
|
|
||||||
using BDataType = ck_tile::bf16_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = ck_tile::bf16_t;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||||
struct GemmTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>
|
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||||
{
|
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||||
using ADataType = ck_tile::fp8_t;
|
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
||||||
using BDataType = ck_tile::fp8_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = ck_tile::half_t;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||||
struct GemmTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>
|
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLEQUANT;
|
||||||
{
|
static constexpr bool PreshuffleQuant = true;
|
||||||
using ADataType = ck_tile::bf8_t;
|
|
||||||
using BDataType = ck_tile::bf8_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = ck_tile::half_t;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::half_t;
|
|
||||||
using BDataType = ck_tile::pk_int4_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = ck_tile::half_t;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, int32_t>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::int8_t;
|
|
||||||
using BDataType = ck_tile::int8_t;
|
|
||||||
using AccDataType = int32_t;
|
|
||||||
using CDataType = int32_t;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename ADataType_,
|
template <typename ADataType_,
|
||||||
@@ -373,176 +154,6 @@ struct GemmQuantTypeConfig
|
|||||||
using CDataType = CDataType_;
|
using CDataType = CDataType_;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::half_t>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::half_t;
|
|
||||||
using QDataType = float;
|
|
||||||
using BDataType = ck_tile::half_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = ck_tile::half_t;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::bf16_t;
|
|
||||||
using QDataType = float;
|
|
||||||
using BDataType = ck_tile::bf16_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = ck_tile::bf16_t;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::fp8_t;
|
|
||||||
using QDataType = float;
|
|
||||||
using BDataType = ck_tile::fp8_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = ck_tile::half_t;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::bf8_t;
|
|
||||||
using QDataType = float;
|
|
||||||
using BDataType = ck_tile::bf8_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = ck_tile::half_t;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::half_t;
|
|
||||||
using QDataType = float;
|
|
||||||
using BDataType = ck_tile::pk_int4_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = ck_tile::half_t;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, float>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::fp8_t;
|
|
||||||
using QDataType = float;
|
|
||||||
using BDataType = ck_tile::fp8_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = float;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::bf8_t;
|
|
||||||
using QDataType = float;
|
|
||||||
using BDataType = ck_tile::bf8_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = float;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, ck_tile::fp8_t>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::pk_int4_t;
|
|
||||||
using QDataType = ck_tile::fp8_t;
|
|
||||||
using BDataType = ck_tile::fp8_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = float;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, float, ck_tile::fp8_t>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::fp8_t;
|
|
||||||
using QDataType = ck_tile::fp8_t;
|
|
||||||
using BDataType = ck_tile::fp8_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = float;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float, ck_tile::bf8_t>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::bf8_t;
|
|
||||||
using QDataType = ck_tile::bf8_t;
|
|
||||||
using BDataType = ck_tile::bf8_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = float;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, ck_tile::bf8_t>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::pk_int4_t;
|
|
||||||
using QDataType = ck_tile::bf8_t;
|
|
||||||
using BDataType = ck_tile::bf8_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = float;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::pk_int4_t;
|
|
||||||
using QDataType = float;
|
|
||||||
using BDataType = ck_tile::fp8_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = float;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::pk_int4_t;
|
|
||||||
using QDataType = float;
|
|
||||||
using BDataType = ck_tile::bf8_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = float;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, float, ck_tile::fp8_t>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::fp8_t;
|
|
||||||
using QDataType = ck_tile::fp8_t;
|
|
||||||
using BDataType = ck_tile::pk_int4_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = float;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, float, ck_tile::bf8_t>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::bf8_t;
|
|
||||||
using QDataType = ck_tile::bf8_t;
|
|
||||||
using BDataType = ck_tile::pk_int4_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = float;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, float, float>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::fp8_t;
|
|
||||||
using QDataType = float;
|
|
||||||
using BDataType = ck_tile::pk_int4_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = float;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, float, float>
|
|
||||||
{
|
|
||||||
using ADataType = ck_tile::bf8_t;
|
|
||||||
using QDataType = float;
|
|
||||||
using BDataType = ck_tile::pk_int4_t;
|
|
||||||
using AccDataType = float;
|
|
||||||
using CDataType = float;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct DataTypeTraits;
|
struct DataTypeTraits;
|
||||||
|
|
||||||
@@ -600,55 +211,6 @@ struct DataTypeTraits<ck_tile::int8_t>
|
|||||||
static constexpr const char* name = "int8";
|
static constexpr const char* name = "int8";
|
||||||
};
|
};
|
||||||
|
|
||||||
template <ck_tile::index_t PipelineId>
|
|
||||||
struct PipelineTypeTraits;
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
|
||||||
{
|
|
||||||
template <typename PipelineProblem>
|
|
||||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
|
|
||||||
template <typename PipelineProblem>
|
|
||||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
|
||||||
{
|
|
||||||
template <typename PipelineProblem>
|
|
||||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
|
||||||
template <typename PipelineProblem>
|
|
||||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
|
||||||
{
|
|
||||||
template <typename PipelineProblem>
|
|
||||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
|
|
||||||
template <typename PipelineProblem>
|
|
||||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
|
|
||||||
{
|
|
||||||
template <typename PipelineProblem>
|
|
||||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
|
|
||||||
template <typename PipelineProblem>
|
|
||||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5<PipelineProblem>;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE>
|
|
||||||
{
|
|
||||||
template <typename PipelineProblem>
|
|
||||||
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
|
|
||||||
template <typename PipelineProblem>
|
|
||||||
using UniversalGemmPipeline =
|
|
||||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
|
|
||||||
};
|
|
||||||
|
|
||||||
auto create_args(int argc, char* argv[])
|
auto create_args(int argc, char* argv[])
|
||||||
{
|
{
|
||||||
ck_tile::ArgParser arg_parser;
|
ck_tile::ArgParser arg_parser;
|
||||||
|
|||||||
@@ -15,7 +15,8 @@
|
|||||||
#include "ck_tile/host.hpp"
|
#include "ck_tile/host.hpp"
|
||||||
#include "test_gemm_aquant_utils.hpp"
|
#include "test_gemm_aquant_utils.hpp"
|
||||||
|
|
||||||
template <typename ADataType,
|
template <typename GemmConfig,
|
||||||
|
typename ADataType,
|
||||||
typename AQDataType,
|
typename AQDataType,
|
||||||
typename BDataType,
|
typename BDataType,
|
||||||
typename AccDataType,
|
typename AccDataType,
|
||||||
@@ -24,8 +25,7 @@ template <typename ADataType,
|
|||||||
typename ALayout,
|
typename ALayout,
|
||||||
typename BLayout,
|
typename BLayout,
|
||||||
typename CLayout,
|
typename CLayout,
|
||||||
uint32_t QuantGroupSize,
|
uint32_t QuantGroupSize>
|
||||||
bool Preshuffle = false>
|
|
||||||
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||||
{
|
{
|
||||||
constexpr bool kPadM = false;
|
constexpr bool kPadM = false;
|
||||||
@@ -36,17 +36,17 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
|
|||||||
|
|
||||||
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||||
|
|
||||||
constexpr ck_tile::index_t M_Tile = 16;
|
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
|
||||||
constexpr ck_tile::index_t N_Tile = 64;
|
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
|
||||||
constexpr ck_tile::index_t K_Tile = 256;
|
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
|
||||||
|
|
||||||
constexpr ck_tile::index_t M_Warp = 1;
|
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
|
||||||
constexpr ck_tile::index_t N_Warp = 4;
|
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
|
||||||
constexpr ck_tile::index_t K_Warp = 1;
|
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
|
||||||
|
|
||||||
constexpr ck_tile::index_t M_Warp_Tile = 16;
|
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
|
||||||
constexpr ck_tile::index_t N_Warp_Tile = 16;
|
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
|
||||||
constexpr ck_tile::index_t K_Warp_Tile = 32;
|
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
|
||||||
|
|
||||||
using CodegenGemmShape =
|
using CodegenGemmShape =
|
||||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||||
@@ -55,8 +55,13 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
|
|||||||
|
|
||||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||||
|
|
||||||
using CodegenGemmTraits =
|
using CodegenGemmTraits = ck_tile::TileGemmAQuantTraits<kPadM,
|
||||||
ck_tile::TileGemmAQuantTraits<kPadM, kPadN, kPadK, Preshuffle, ALayout, BLayout, CLayout>;
|
kPadN,
|
||||||
|
kPadK,
|
||||||
|
GemmConfig::PreshuffleQuant,
|
||||||
|
ALayout,
|
||||||
|
BLayout,
|
||||||
|
CLayout>;
|
||||||
|
|
||||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
|
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
|
||||||
BDataType,
|
BDataType,
|
||||||
@@ -152,7 +157,8 @@ static constexpr inline auto is_row_major(Layout layout_)
|
|||||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename ADataType,
|
template <typename GemmConfig,
|
||||||
|
typename ADataType,
|
||||||
typename AQDataType,
|
typename AQDataType,
|
||||||
typename BDataType,
|
typename BDataType,
|
||||||
typename AccDataType,
|
typename AccDataType,
|
||||||
@@ -161,8 +167,7 @@ template <typename ADataType,
|
|||||||
typename AQLayout,
|
typename AQLayout,
|
||||||
typename BLayout,
|
typename BLayout,
|
||||||
typename CLayout,
|
typename CLayout,
|
||||||
uint32_t QuantGroupSize,
|
uint32_t QuantGroupSize>
|
||||||
bool Preshuffle = false>
|
|
||||||
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||||
ck_tile::DeviceMem& aq_m_aqk_dev_buf,
|
ck_tile::DeviceMem& aq_m_aqk_dev_buf,
|
||||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||||
@@ -194,7 +199,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
|||||||
args.stride_C = stride_C;
|
args.stride_C = stride_C;
|
||||||
args.stride_AQ = stride_AQ;
|
args.stride_AQ = stride_AQ;
|
||||||
|
|
||||||
float ave_time = gemm_calc_aquant<ADataType,
|
float ave_time = gemm_calc_aquant<GemmConfig,
|
||||||
|
ADataType,
|
||||||
AQDataType,
|
AQDataType,
|
||||||
BDataType,
|
BDataType,
|
||||||
AccDataType,
|
AccDataType,
|
||||||
@@ -203,8 +209,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
|||||||
ALayout,
|
ALayout,
|
||||||
BLayout,
|
BLayout,
|
||||||
CLayout,
|
CLayout,
|
||||||
QuantGroupSize,
|
QuantGroupSize>(
|
||||||
Preshuffle>(
|
|
||||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||||
|
|
||||||
std::size_t flop = std::size_t(2) * M * N * K;
|
std::size_t flop = std::size_t(2) * M * N * K;
|
||||||
@@ -227,7 +232,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
|||||||
return ave_time;
|
return ave_time;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename TypeConfig,
|
template <typename GemmConfig,
|
||||||
|
typename TypeConfig,
|
||||||
uint32_t QuantGroupSize,
|
uint32_t QuantGroupSize,
|
||||||
typename ALayout,
|
typename ALayout,
|
||||||
typename AQLayout,
|
typename AQLayout,
|
||||||
@@ -332,7 +338,8 @@ bool run_gemm_test_with_layouts(int argc,
|
|||||||
c_m_n_dev_buf.SetZero();
|
c_m_n_dev_buf.SetZero();
|
||||||
c_m_n_dev_result.SetZero();
|
c_m_n_dev_result.SetZero();
|
||||||
|
|
||||||
invoke_gemm<ADataType,
|
invoke_gemm<GemmConfig,
|
||||||
|
ADataType,
|
||||||
AQDataType,
|
AQDataType,
|
||||||
BDataType,
|
BDataType,
|
||||||
AccDataType,
|
AccDataType,
|
||||||
@@ -400,7 +407,7 @@ bool run_gemm_test_with_layouts(int argc,
|
|||||||
return pass;
|
return pass;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename TypeConfig, uint32_t QuantGroupSize>
|
template <typename GemmConfig, typename TypeConfig, uint32_t QuantGroupSize>
|
||||||
bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||||
{
|
{
|
||||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||||
@@ -412,7 +419,7 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg
|
|||||||
{
|
{
|
||||||
if(a_layout == "R" && b_layout == "C")
|
if(a_layout == "R" && b_layout == "C")
|
||||||
{
|
{
|
||||||
return run_gemm_test_with_layouts<TypeConfig, QuantGroupSize>(
|
return run_gemm_test_with_layouts<GemmConfig, TypeConfig, QuantGroupSize>(
|
||||||
argc, argv, Row{}, Row{}, Col{}, Row{});
|
argc, argv, Row{}, Row{}, Col{}, Row{});
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
@@ -428,6 +435,7 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <template <typename PreType> typename GemmConfig>
|
||||||
bool run_gemm_test(int argc, char* argv[])
|
bool run_gemm_test(int argc, char* argv[])
|
||||||
{
|
{
|
||||||
auto [result, arg_parser] = create_args(argc, argv);
|
auto [result, arg_parser] = create_args(argc, argv);
|
||||||
@@ -441,41 +449,52 @@ bool run_gemm_test(int argc, char* argv[])
|
|||||||
if(data_type == "fp8")
|
if(data_type == "fp8")
|
||||||
{
|
{
|
||||||
using TypeConfig =
|
using TypeConfig =
|
||||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>{});
|
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||||
return run_gemm_test_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
|
return run_gemm_test_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
|
||||||
|
a_layout, b_layout, argc, argv);
|
||||||
}
|
}
|
||||||
else if(data_type == "bf8")
|
else if(data_type == "bf8")
|
||||||
{
|
{
|
||||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>{});
|
using TypeConfig =
|
||||||
return run_gemm_test_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
|
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||||
|
return run_gemm_test_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
|
||||||
|
a_layout, b_layout, argc, argv);
|
||||||
}
|
}
|
||||||
else if(data_type == "i4fp8")
|
else if(data_type == "i4fp8")
|
||||||
{
|
{
|
||||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||||
ck_tile::fp8_t,
|
ck_tile::fp8_t,
|
||||||
float,
|
ck_tile::half_t,
|
||||||
ck_tile::fp8_t>{});
|
ck_tile::fp8_t>{});
|
||||||
return run_gemm_test_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
|
return run_gemm_test_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||||
|
a_layout, b_layout, argc, argv);
|
||||||
}
|
}
|
||||||
else if(data_type == "i4bf8")
|
else if(data_type == "i4bf8")
|
||||||
{
|
{
|
||||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||||
ck_tile::bf8_t,
|
ck_tile::bf8_t,
|
||||||
float,
|
ck_tile::half_t,
|
||||||
ck_tile::bf8_t>{});
|
ck_tile::bf8_t>{});
|
||||||
return run_gemm_test_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
|
return run_gemm_test_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||||
|
a_layout, b_layout, argc, argv);
|
||||||
}
|
}
|
||||||
else if(data_type == "i4f32fp8")
|
else if(data_type == "i4f32fp8")
|
||||||
{
|
{
|
||||||
using TypeConfig =
|
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||||
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>{});
|
ck_tile::fp8_t,
|
||||||
return run_gemm_test_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
|
ck_tile::half_t,
|
||||||
|
float>{});
|
||||||
|
return run_gemm_test_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||||
|
a_layout, b_layout, argc, argv);
|
||||||
}
|
}
|
||||||
else if(data_type == "i4f32bf8")
|
else if(data_type == "i4f32bf8")
|
||||||
{
|
{
|
||||||
using TypeConfig =
|
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||||
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>{});
|
ck_tile::bf8_t,
|
||||||
return run_gemm_test_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
|
ck_tile::half_t,
|
||||||
|
float>{});
|
||||||
|
return run_gemm_test_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||||
|
a_layout, b_layout, argc, argv);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
@@ -564,7 +583,7 @@ int run_gemm_combinations(std::string const& data_type)
|
|||||||
// Call the function with the current configuration
|
// Call the function with the current configuration
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
is_success = run_gemm_test(ARG_COUNT, argv) && is_success;
|
is_success = run_gemm_test<GemmConfigDecode>(ARG_COUNT, argv) && is_success;
|
||||||
}
|
}
|
||||||
catch(const ArgumentsNotSupportedException& e)
|
catch(const ArgumentsNotSupportedException& e)
|
||||||
{
|
{
|
||||||
|
|||||||
Reference in New Issue
Block a user