mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-24 00:57:39 +00:00
[CK TILE] Fix bugs in AQuant preshuffle (#2700)
* [CK TILE] Fix bugs in AQuant preshuffle
- Make Aquant works with block Mx64x256. `M` could be 16, 32, 64
- Make Aquant works with warp 16x16x32 and 32x32x16.
* [CK TILE] Rename Preshuffle to PreshuffleQuant
The new name, PreshuffleQuant, explicitly states the function's purpose:
to preshuffle the quantization matrix.
* [CK TILE Block Scale] Use GemmConfig to save tile properties
- Remove specialization of GemmQuantTypeConfig
- Pass GemmConfig around which contains tile properties. Stop using hard
coded tile properties in `gemm_calc_aquant()`
* [CK TILE Block Scale] Rename GemmConfig used in block scale
- Remove unused GemmConfig
- Rename GemmConfig used in block scale
---------
Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
@@ -8,11 +8,10 @@
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_utils.hpp"
|
||||
|
||||
template <typename ADataType,
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
@@ -21,8 +20,7 @@ template <typename ADataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
uint32_t QuantGroupSize,
|
||||
bool Preshuffle = false>
|
||||
uint32_t QuantGroupSize>
|
||||
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr bool kPadM = false;
|
||||
@@ -33,17 +31,17 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
|
||||
|
||||
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
|
||||
constexpr ck_tile::index_t M_Tile = 16;
|
||||
constexpr ck_tile::index_t N_Tile = 64;
|
||||
constexpr ck_tile::index_t K_Tile = 256;
|
||||
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
|
||||
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
|
||||
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 1;
|
||||
constexpr ck_tile::index_t N_Warp = 4;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
|
||||
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
|
||||
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
|
||||
|
||||
using CodegenGemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
@@ -52,8 +50,13 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||
|
||||
using CodegenGemmTraits =
|
||||
ck_tile::TileGemmAQuantTraits<kPadM, kPadN, kPadK, Preshuffle, ALayout, BLayout, CLayout>;
|
||||
using CodegenGemmTraits = ck_tile::TileGemmAQuantTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
GemmConfig::PreshuffleQuant,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
|
||||
BDataType,
|
||||
@@ -186,13 +189,14 @@ int run_gemm_example(int argc, char* argv[])
|
||||
if(data_type == "fp8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>{});
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>{});
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
@@ -200,7 +204,7 @@ int run_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::fp8_t,
|
||||
float,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
@@ -209,29 +213,15 @@ int run_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::bf8_t,
|
||||
float,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "i4f32fp8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "i4f32bf8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigComputeV3>(argc, argv); }
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigDecode>(argc, argv); }
|
||||
|
||||
@@ -8,11 +8,10 @@
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_utils.hpp"
|
||||
|
||||
template <typename ADataType,
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
@@ -21,8 +20,7 @@ template <typename ADataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
uint32_t QuantGroupSize,
|
||||
bool Preshuffle = false>
|
||||
uint32_t QuantGroupSize>
|
||||
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr bool kPadM = false;
|
||||
@@ -33,17 +31,17 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
|
||||
|
||||
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
|
||||
constexpr ck_tile::index_t M_Tile = 16;
|
||||
constexpr ck_tile::index_t N_Tile = 64;
|
||||
constexpr ck_tile::index_t K_Tile = 256;
|
||||
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
|
||||
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
|
||||
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 1;
|
||||
constexpr ck_tile::index_t N_Warp = 4;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
|
||||
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
|
||||
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
|
||||
|
||||
using CodegenGemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
@@ -52,8 +50,13 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||
|
||||
using CodegenGemmTraits =
|
||||
ck_tile::TileGemmAQuantTraits<kPadM, kPadN, kPadK, Preshuffle, ALayout, BLayout, CLayout>;
|
||||
using CodegenGemmTraits = ck_tile::TileGemmAQuantTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
GemmConfig::PreshuffleQuant,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
|
||||
BDataType,
|
||||
@@ -186,13 +189,14 @@ int run_gemm_example(int argc, char* argv[])
|
||||
if(data_type == "fp8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>{});
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>{});
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
@@ -200,7 +204,7 @@ int run_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::fp8_t,
|
||||
float,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
@@ -209,29 +213,18 @@ int run_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::bf8_t,
|
||||
float,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "i4f32fp8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "i4f32bf8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigPreshufle_AQ>(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
return !run_gemm_example<GemmConfigPreshuffleQuant>(argc, argv);
|
||||
}
|
||||
|
||||
@@ -11,11 +11,9 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V5 4
|
||||
#define CK_TILE_PIPELINE_PRESHUFFLE 5
|
||||
#define CK_TILE_PIPELINE_PREFILL 1
|
||||
#define CK_TILE_PIPELINE_DECODE 2
|
||||
#define CK_TILE_PIPELINE_PRESHUFFLEQUANT 3
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
@@ -87,196 +85,32 @@ struct GemmConfigBase
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool Preshuffle = false;
|
||||
static constexpr bool PreshuffleQuant = false;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigMemoryInterwave : public GemmConfigBase
|
||||
struct GemmConfigDecode : public GemmConfigBase
|
||||
{
|
||||
// Memory friendly for Interwave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigMemoryIntrawave : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3 : public GemmConfigBase
|
||||
{
|
||||
// Compute V3 only support Intrawave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_1 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV4 : public GemmConfigBase
|
||||
{
|
||||
// Compute V4 only support Intrawave scheduler
|
||||
// Using the ping pong reader in the lds level
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV4_1 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV5 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 2;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
|
||||
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshufle_1 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_DECODE;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshufle_2 : public GemmConfigBase
|
||||
struct GemmConfigPrefill : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
@@ -288,18 +122,15 @@ struct GemmConfigPreshufle_2 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PREFILL;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshufle_AQ : public GemmConfigBase
|
||||
struct GemmConfigPreshuffleQuant : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
@@ -314,9 +145,9 @@ struct GemmConfigPreshufle_AQ : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLEQUANT;
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
};
|
||||
|
||||
template <typename ADataType_,
|
||||
@@ -332,176 +163,6 @@ struct GemmQuantTypeConfig
|
||||
using CDataType = CDataType_;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf16_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::bf16_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::fp8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf8_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::bf8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::pk_int4_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, float>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::fp8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>
|
||||
{
|
||||
using ADataType = ck_tile::bf8_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::bf8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, ck_tile::fp8_t>
|
||||
{
|
||||
using ADataType = ck_tile::pk_int4_t;
|
||||
using QDataType = ck_tile::fp8_t;
|
||||
using BDataType = ck_tile::fp8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, float, ck_tile::fp8_t>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using QDataType = ck_tile::fp8_t;
|
||||
using BDataType = ck_tile::fp8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float, ck_tile::bf8_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf8_t;
|
||||
using QDataType = ck_tile::bf8_t;
|
||||
using BDataType = ck_tile::bf8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, ck_tile::bf8_t>
|
||||
{
|
||||
using ADataType = ck_tile::pk_int4_t;
|
||||
using QDataType = ck_tile::bf8_t;
|
||||
using BDataType = ck_tile::bf8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>
|
||||
{
|
||||
using ADataType = ck_tile::pk_int4_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::fp8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>
|
||||
{
|
||||
using ADataType = ck_tile::pk_int4_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::bf8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, float, ck_tile::fp8_t>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using QDataType = ck_tile::fp8_t;
|
||||
using BDataType = ck_tile::pk_int4_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, float, ck_tile::bf8_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf8_t;
|
||||
using QDataType = ck_tile::bf8_t;
|
||||
using BDataType = ck_tile::pk_int4_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, float, float>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::pk_int4_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, float, float>
|
||||
{
|
||||
using ADataType = ck_tile::bf8_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::pk_int4_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
@@ -559,55 +220,6 @@ struct DataTypeTraits<ck_tile::int8_t>
|
||||
static constexpr const char* name = "int8";
|
||||
};
|
||||
|
||||
template <ck_tile::index_t PipelineId>
|
||||
struct PipelineTypeTraits;
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline =
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
|
||||
@@ -31,7 +31,8 @@ auto shuffle_aq(const ck_tile::HostTensor<T>& t, int block_aq_k)
|
||||
return ck_tile::reference_permute(t_view, {1, 0, 2});
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
@@ -40,8 +41,7 @@ template <typename ADataType,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
uint32_t QuantGroupSize,
|
||||
bool Preshuffle = false>
|
||||
uint32_t QuantGroupSize>
|
||||
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& aq_m_aqk_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
@@ -73,7 +73,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
args.stride_C = stride_C;
|
||||
args.stride_AQ = stride_AQ;
|
||||
|
||||
float ave_time = gemm_calc_aquant<ADataType,
|
||||
float ave_time = gemm_calc_aquant<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
@@ -82,8 +83,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantGroupSize,
|
||||
Preshuffle>(
|
||||
QuantGroupSize>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
@@ -206,7 +206,7 @@ int run_gemm_example_with_layouts(int argc,
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
if constexpr(GemmConfig::Preshuffle)
|
||||
if constexpr(GemmConfig::PreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<AQDataType> aq_shuffle_host =
|
||||
shuffle_aq(aq_m_aqk, GemmConfig::K_Tile / QuantGroupSize);
|
||||
@@ -222,7 +222,8 @@ int run_gemm_example_with_layouts(int argc,
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
invoke_gemm<ADataType,
|
||||
invoke_gemm<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
@@ -231,22 +232,21 @@ int run_gemm_example_with_layouts(int argc,
|
||||
AQLayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantGroupSize,
|
||||
GemmConfig::Preshuffle>(a_m_k_dev_buf,
|
||||
aq_m_aqk_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
AQK,
|
||||
stride_A,
|
||||
stride_AQ,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
QuantGroupSize>(a_m_k_dev_buf,
|
||||
aq_m_aqk_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
AQK,
|
||||
stride_A,
|
||||
stride_AQ,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
bool pass = true;
|
||||
|
||||
@@ -157,7 +157,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
|
||||
static constexpr index_t KPack = WarpGemm::kKPerThread;
|
||||
static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
|
||||
|
||||
static constexpr bool Preshuffle = Problem::Traits::Preshuffle;
|
||||
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
};
|
||||
|
||||
public:
|
||||
@@ -357,7 +357,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
|
||||
}
|
||||
});
|
||||
|
||||
if constexpr(Traits::Preshuffle)
|
||||
if constexpr(Traits::PreshuffleQuant)
|
||||
{
|
||||
// A view is created on top of the preshuffled AQ, where each row of the
|
||||
// view is composed of a row from a warp tile within an AQ block tile.
|
||||
@@ -392,12 +392,27 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
|
||||
|
||||
// Thread 0 can read AQ_tile[0, 0] from itself, AQ_tile[1, 0]
|
||||
// from thread 1, ..., and AQ_tile[3, 0] from thread 3.
|
||||
auto pull_from_lane =
|
||||
((threadIdx.x & (warp_size - 1)) / Traits::WarpGemm::kN *
|
||||
kTileRowsOfCPerThread +
|
||||
c_row) *
|
||||
Traits::QScalesPerBlockRow +
|
||||
kQScale;
|
||||
decltype(threadIdx.x) pull_from_lane = 0;
|
||||
if constexpr(WarpGemm::kM == 16)
|
||||
{
|
||||
pull_from_lane = (__lane_id() / Traits::WarpGemm::kN *
|
||||
kTileRowsOfCPerThread +
|
||||
c_row) *
|
||||
Traits::QScalesPerBlockRow +
|
||||
kQScale;
|
||||
}
|
||||
else if constexpr(WarpGemm::kM == 32)
|
||||
{
|
||||
pull_from_lane = (__lane_id() / Traits::WarpGemm::kN *
|
||||
kTileRowsOfCPerThread +
|
||||
((c_row >> 2) << 3) + (c_row & 0b11)) *
|
||||
Traits::QScalesPerBlockRow +
|
||||
kQScale;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "WarpGemm::kM is not 16 nor 32.");
|
||||
}
|
||||
auto& scale_reg = aq_block_tensor.get_thread_buffer()[mIter];
|
||||
|
||||
// cross lane ops
|
||||
|
||||
@@ -99,15 +99,15 @@ struct AQuantGemmKernelArgs
|
||||
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||
struct AQuantGemmKernel
|
||||
{
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using AQLayout = remove_cvref_t<typename GemmPipeline::AQLayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
|
||||
static constexpr bool Preshuffle = GemmPipeline::Preshuffle;
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using AQLayout = remove_cvref_t<typename GemmPipeline::AQLayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
|
||||
static constexpr bool PreshuffleQuant = GemmPipeline::PreshuffleQuant;
|
||||
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using AQDataType = remove_cvref_t<typename GemmPipeline::AQDataType>;
|
||||
@@ -422,9 +422,9 @@ struct AQuantGemmKernel
|
||||
ck_tile::integer_least_multiple(wave_tile_size, get_warp_size());
|
||||
const auto aq_merge_pad1_desc = transform_tensor_descriptor(
|
||||
aq_pad1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(wave_tile_count_x, aq_y)),
|
||||
make_tuple(make_merge_transform(make_tuple(aq_y, wave_tile_count_x)),
|
||||
make_pass_through_transform(pad_wave_size)),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return make_tensor_view<address_space_enum::global>(aq_ptr, aq_merge_pad1_desc);
|
||||
@@ -432,7 +432,7 @@ struct AQuantGemmKernel
|
||||
|
||||
const auto& aq_tensor_view = [&]() {
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
if constexpr(Preshuffle)
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
return make_preshuffled_aq_tensor_view();
|
||||
}
|
||||
@@ -599,10 +599,8 @@ struct AQuantGemmKernel
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
|
||||
const AQuantGemmKernelArgs& kargs,
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
|
||||
{
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& aq_pad_view = views.at(I1);
|
||||
@@ -628,24 +626,27 @@ struct AQuantGemmKernel
|
||||
|
||||
const auto& aq_block_window = [&]() {
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
if constexpr(Preshuffle)
|
||||
constexpr auto block_m = TilePartitioner::MPerBlock;
|
||||
constexpr auto block_k = TilePartitioner::KPerBlock;
|
||||
constexpr auto warp_m = TilePartitioner::BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr auto aqk_per_block =
|
||||
TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize;
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
constexpr auto tile_window_width = get_warp_size();
|
||||
constexpr auto tile_window_height =
|
||||
TilePartitioner::MPerBlock / TilePartitioner::BlockGemmShape::WarpTile::at(I0);
|
||||
auto block_m_idx = i_m / TilePartitioner::MPerBlock;
|
||||
constexpr auto tile_window_width =
|
||||
ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size());
|
||||
constexpr auto tile_window_height = block_m / warp_m;
|
||||
auto block_m_idx = i_m / block_m;
|
||||
return make_tile_window(
|
||||
aq_pad_view,
|
||||
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
|
||||
{block_m_idx * kargs.K / TilePartitioner::BlockGemmShape::BlockTile::at(I2),
|
||||
0});
|
||||
{block_m_idx * tile_window_height, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(
|
||||
aq_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
|
||||
make_tuple(number<block_m>{}, number<block_k / GemmPipeline::QuantGroupSize>{}),
|
||||
{i_m, 0});
|
||||
}
|
||||
}();
|
||||
@@ -706,8 +707,7 @@ struct AQuantGemmKernel
|
||||
a_ptr, b_ptr, aq_ptr, c_ptr, kargs, splitk_batch_offset);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
MakeGemmTileWindows(gemm_pad_views, kargs, block_idx_m, block_idx_n);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(
|
||||
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
@@ -718,7 +718,7 @@ struct AQuantGemmKernel
|
||||
const auto& b_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr_0);
|
||||
a_block_window, b_block_window, aq_block_window, kargs.M, num_loop, smem_ptr_0);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
@@ -37,23 +37,23 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
using AQLayout = remove_cvref_t<typename Problem::AQLayout>;
|
||||
using BlockGemmShape = typename Problem::BlockGemmShape;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
|
||||
constexpr bool Preshuffle = Problem::Traits::Preshuffle;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
false>;
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
|
||||
constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
false>;
|
||||
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
if constexpr(Preshuffle)
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
using TileEncodingPattern =
|
||||
TileDistributionEncodingPatternAQ<BlockGemmShape,
|
||||
@@ -64,7 +64,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
WarpGemm::kM * KPerBlockAQ, get_warp_size()),
|
||||
KPerBlockAQ,
|
||||
VecLoadSize,
|
||||
Preshuffle>;
|
||||
PreshuffleQuant>;
|
||||
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
}
|
||||
@@ -77,7 +77,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
KPerBlockAQ,
|
||||
KPerBlockAQ,
|
||||
VecLoadSize,
|
||||
Preshuffle>;
|
||||
PreshuffleQuant>;
|
||||
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
@@ -133,7 +134,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr bool Preshuffle = Problem::Traits::Preshuffle;
|
||||
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
@@ -235,6 +236,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
|
||||
index_t m,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
@@ -311,9 +313,11 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
|
||||
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
|
||||
// only row_major for AQ
|
||||
constexpr AQDramTileWindowStep aq_dram_tile_window_step =
|
||||
Preshuffle ? make_array(MPerBlock / BlockGemm::WarpGemm::kM, 0)
|
||||
: make_array(0, KPerBlockAQ);
|
||||
const AQDramTileWindowStep aq_dram_tile_window_step =
|
||||
PreshuffleQuant ? make_array(ck_tile::integer_least_multiple(m, MPerBlock) /
|
||||
BlockGemm::WarpGemm::kM,
|
||||
0)
|
||||
: make_array(0, KPerBlockAQ);
|
||||
|
||||
// DRAM prefetch (global read 0)
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
@@ -458,6 +462,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
|
||||
index_t m,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
@@ -467,6 +472,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
aq_dram_block_window_tmp,
|
||||
m,
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ template <typename BlockGemmShape,
|
||||
index_t XPerTile,
|
||||
index_t KPerBlockAQ,
|
||||
index_t VecSize,
|
||||
bool Preshuffle>
|
||||
bool PreshuffleQuant>
|
||||
struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPattern
|
||||
{
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
@@ -72,20 +72,20 @@ struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPatter
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
|
||||
{
|
||||
if constexpr(Preshuffle)
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
// # of elements per thread
|
||||
constexpr index_t X2 = KPerBlockAQ;
|
||||
constexpr index_t X1 = warp_size / X2;
|
||||
static_assert(XPerTile >= warp_size && XPerTile % warp_size == 0);
|
||||
constexpr index_t X1 = warp_size;
|
||||
constexpr index_t X0 = XPerTile / warp_size;
|
||||
|
||||
constexpr index_t Y1 = MWarps;
|
||||
constexpr index_t Y0 = YPerTile / Y1;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<NWarps>,
|
||||
tuple<sequence<Y0, Y1>, sequence<X0, X1, X2>>,
|
||||
tuple<sequence<1, 0>, sequence<2, 2>>,
|
||||
tuple<sequence<1, 0>, sequence<1, 2>>,
|
||||
tuple<sequence<Y0, Y1>, sequence<X0, X1>>,
|
||||
tuple<sequence<1, 0>, sequence<2>>,
|
||||
tuple<sequence<1, 0>, sequence<1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{});
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ namespace ck_tile {
|
||||
template <bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_,
|
||||
bool Preshuffle_,
|
||||
bool PreshuffleQuant_,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename CLayout_,
|
||||
@@ -30,7 +30,7 @@ struct TileGemmAQuantTraits
|
||||
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
static constexpr index_t NumWaveGroups = 1;
|
||||
static constexpr bool Preshuffle = Preshuffle_;
|
||||
static constexpr bool PreshuffleQuant = PreshuffleQuant_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -11,11 +11,9 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V5 4
|
||||
#define CK_TILE_PIPELINE_PRESHUFFLE 5
|
||||
#define CK_TILE_PIPELINE_PREFILL 1
|
||||
#define CK_TILE_PIPELINE_DECODE 2
|
||||
#define CK_TILE_PIPELINE_PRESHUFFLEQUANT 3
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
@@ -34,21 +32,6 @@ constexpr ck_tile::index_t get_k_warp_tile()
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 64;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 128;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 32;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 64;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
@@ -93,195 +76,32 @@ struct GemmConfigBase
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool Preshuffle = false;
|
||||
static constexpr bool PreshuffleQuant = false;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigMemoryInterwave : public GemmConfigBase
|
||||
struct GemmConfigDecode : public GemmConfigBase
|
||||
{
|
||||
// Memory friendly for Interwave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigMemoryIntrawave : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3 : public GemmConfigBase
|
||||
{
|
||||
// Compute V3 only support Intrawave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_1 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV4 : public GemmConfigBase
|
||||
{
|
||||
// Compute V4 only support Intrawave scheduler
|
||||
// Using the ping pong reader in the lds level
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV4_1 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV5 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 2;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
|
||||
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshufle_1 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_DECODE;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshufle_2 : public GemmConfigBase
|
||||
struct GemmConfigPrefill : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
@@ -293,71 +113,32 @@ struct GemmConfigPreshufle_2 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PREFILL;
|
||||
};
|
||||
|
||||
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
|
||||
struct GemmTypeConfig;
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::half_t>
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleQuant : public GemmConfigBase
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
// ToDo: Add more bias config to support different categories of GEMM.
|
||||
};
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf16_t;
|
||||
using BDataType = ck_tile::bf16_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::bf16_t;
|
||||
};
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using BDataType = ck_tile::fp8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf8_t;
|
||||
using BDataType = ck_tile::bf8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::pk_int4_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, int32_t>
|
||||
{
|
||||
using ADataType = ck_tile::int8_t;
|
||||
using BDataType = ck_tile::int8_t;
|
||||
using AccDataType = int32_t;
|
||||
using CDataType = int32_t;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLEQUANT;
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
};
|
||||
|
||||
template <typename ADataType_,
|
||||
@@ -373,176 +154,6 @@ struct GemmQuantTypeConfig
|
||||
using CDataType = CDataType_;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf16_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::bf16_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::fp8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf8_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::bf8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::pk_int4_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, float>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::fp8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>
|
||||
{
|
||||
using ADataType = ck_tile::bf8_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::bf8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, ck_tile::fp8_t>
|
||||
{
|
||||
using ADataType = ck_tile::pk_int4_t;
|
||||
using QDataType = ck_tile::fp8_t;
|
||||
using BDataType = ck_tile::fp8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, float, ck_tile::fp8_t>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using QDataType = ck_tile::fp8_t;
|
||||
using BDataType = ck_tile::fp8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float, ck_tile::bf8_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf8_t;
|
||||
using QDataType = ck_tile::bf8_t;
|
||||
using BDataType = ck_tile::bf8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, ck_tile::bf8_t>
|
||||
{
|
||||
using ADataType = ck_tile::pk_int4_t;
|
||||
using QDataType = ck_tile::bf8_t;
|
||||
using BDataType = ck_tile::bf8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>
|
||||
{
|
||||
using ADataType = ck_tile::pk_int4_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::fp8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>
|
||||
{
|
||||
using ADataType = ck_tile::pk_int4_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::bf8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, float, ck_tile::fp8_t>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using QDataType = ck_tile::fp8_t;
|
||||
using BDataType = ck_tile::pk_int4_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, float, ck_tile::bf8_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf8_t;
|
||||
using QDataType = ck_tile::bf8_t;
|
||||
using BDataType = ck_tile::pk_int4_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, float, float>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::pk_int4_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, float, float>
|
||||
{
|
||||
using ADataType = ck_tile::bf8_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::pk_int4_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = float;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
@@ -600,55 +211,6 @@ struct DataTypeTraits<ck_tile::int8_t>
|
||||
static constexpr const char* name = "int8";
|
||||
};
|
||||
|
||||
template <ck_tile::index_t PipelineId>
|
||||
struct PipelineTypeTraits;
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline =
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
|
||||
@@ -15,7 +15,8 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_gemm_aquant_utils.hpp"
|
||||
|
||||
template <typename ADataType,
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
@@ -24,8 +25,7 @@ template <typename ADataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
uint32_t QuantGroupSize,
|
||||
bool Preshuffle = false>
|
||||
uint32_t QuantGroupSize>
|
||||
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr bool kPadM = false;
|
||||
@@ -36,17 +36,17 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
|
||||
|
||||
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
|
||||
constexpr ck_tile::index_t M_Tile = 16;
|
||||
constexpr ck_tile::index_t N_Tile = 64;
|
||||
constexpr ck_tile::index_t K_Tile = 256;
|
||||
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
|
||||
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
|
||||
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 1;
|
||||
constexpr ck_tile::index_t N_Warp = 4;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
|
||||
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
|
||||
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
|
||||
|
||||
using CodegenGemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
@@ -55,8 +55,13 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||
|
||||
using CodegenGemmTraits =
|
||||
ck_tile::TileGemmAQuantTraits<kPadM, kPadN, kPadK, Preshuffle, ALayout, BLayout, CLayout>;
|
||||
using CodegenGemmTraits = ck_tile::TileGemmAQuantTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
GemmConfig::PreshuffleQuant,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
|
||||
BDataType,
|
||||
@@ -152,7 +157,8 @@ static constexpr inline auto is_row_major(Layout layout_)
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
@@ -161,8 +167,7 @@ template <typename ADataType,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
uint32_t QuantGroupSize,
|
||||
bool Preshuffle = false>
|
||||
uint32_t QuantGroupSize>
|
||||
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& aq_m_aqk_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
@@ -194,7 +199,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
args.stride_C = stride_C;
|
||||
args.stride_AQ = stride_AQ;
|
||||
|
||||
float ave_time = gemm_calc_aquant<ADataType,
|
||||
float ave_time = gemm_calc_aquant<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
@@ -203,8 +209,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantGroupSize,
|
||||
Preshuffle>(
|
||||
QuantGroupSize>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
@@ -227,7 +232,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename TypeConfig,
|
||||
template <typename GemmConfig,
|
||||
typename TypeConfig,
|
||||
uint32_t QuantGroupSize,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
@@ -332,7 +338,8 @@ bool run_gemm_test_with_layouts(int argc,
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
invoke_gemm<ADataType,
|
||||
invoke_gemm<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
@@ -400,7 +407,7 @@ bool run_gemm_test_with_layouts(int argc,
|
||||
return pass;
|
||||
}
|
||||
|
||||
template <typename TypeConfig, uint32_t QuantGroupSize>
|
||||
template <typename GemmConfig, typename TypeConfig, uint32_t QuantGroupSize>
|
||||
bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
@@ -412,7 +419,7 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_test_with_layouts<TypeConfig, QuantGroupSize>(
|
||||
return run_gemm_test_with_layouts<GemmConfig, TypeConfig, QuantGroupSize>(
|
||||
argc, argv, Row{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
@@ -428,6 +435,7 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg
|
||||
return true;
|
||||
}
|
||||
|
||||
template <template <typename PreType> typename GemmConfig>
|
||||
bool run_gemm_test(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
@@ -441,41 +449,52 @@ bool run_gemm_test(int argc, char* argv[])
|
||||
if(data_type == "fp8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>{});
|
||||
return run_gemm_test_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_test_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>{});
|
||||
return run_gemm_test_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_test_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "i4fp8")
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::fp8_t,
|
||||
float,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
return run_gemm_test_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
|
||||
return run_gemm_test_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "i4bf8")
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::bf8_t,
|
||||
float,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
return run_gemm_test_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
|
||||
return run_gemm_test_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "i4f32fp8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>{});
|
||||
return run_gemm_test_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_test_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "i4f32bf8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>{});
|
||||
return run_gemm_test_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_test_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -564,7 +583,7 @@ int run_gemm_combinations(std::string const& data_type)
|
||||
// Call the function with the current configuration
|
||||
try
|
||||
{
|
||||
is_success = run_gemm_test(ARG_COUNT, argv) && is_success;
|
||||
is_success = run_gemm_test<GemmConfigDecode>(ARG_COUNT, argv) && is_success;
|
||||
}
|
||||
catch(const ArgumentsNotSupportedException& e)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user