Merge branch 'develop' into streamhpc/grouped-conv-fwd-wmma

This commit is contained in:
Kiefer van Teutem
2025-12-16 09:47:31 +01:00
committed by GitHub
43 changed files with 1233 additions and 701 deletions

View File

@@ -766,6 +766,9 @@ if(CK_EXPERIMENTAL_BUILDER)
${PROJECT_SOURCE_DIR}/experimental/builder/include/ck_tile/builder
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck_tile
)
set(CK_TILE_SRC_FOLDER ${CMAKE_SOURCE_DIR}/include/ck_tile/)
rocm_install(DIRECTORY ${CK_TILE_SRC_FOLDER} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck_tile)
endif()
set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE")

View File

@@ -233,7 +233,20 @@ int run_contraction_bilinear_example(int argc, char* argv[])
}
}
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
if(ck::is_gfx11_supported())
{
return ck::utils::check_err(e_ms_ns_device_result,
e_ms_ns_host_result,
"Error: Incorrect results!",
1e-4,
1e-4)
? 0
: 1;
}
else
{
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
}
}
return 0;

View File

@@ -216,7 +216,20 @@ int run_contraction_scale_example(int argc, char* argv[])
}
}
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
if(ck::is_gfx11_supported())
{
return ck::utils::check_err(e_ms_ns_device_result,
e_ms_ns_host_result,
"Error: Incorrect results!",
1e-4,
1e-4)
? 0
: 1;
}
else
{
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
}
}
return 0;

View File

@@ -12,40 +12,6 @@
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/utility/json_dump.hpp"
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
#if defined(CK_GFX950_SUPPORT)
constexpr bool is_8bit_float =
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
if constexpr(M_Warp_Tile == 32)
return is_8bit_float ? 64 : 16;
else
return is_8bit_float ? 128 : 32;
#else
if constexpr(M_Warp_Tile == 32)
return 16;
else
return 32;
#endif
}
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
{
#if defined(CK_GFX950_SUPPORT)
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 64;
else
return sizeof(PrecType) == 2 ? 32 : 128;
#else
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 32;
else
return sizeof(PrecType) == 2 ? 32 : 64;
#endif
}
struct GemmConfigBase
{
static constexpr bool kPadM = false;
@@ -122,7 +88,8 @@ struct GemmConfigComputeV3 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
@@ -141,7 +108,8 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
@@ -160,7 +128,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
@@ -204,7 +173,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
@@ -223,7 +193,8 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
@@ -242,7 +213,8 @@ struct GemmConfigComputeV5 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5;
@@ -282,7 +254,8 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
@@ -306,7 +279,8 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;

View File

@@ -11,40 +11,6 @@
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/utility/json_dump.hpp"
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
#if defined(CK_GFX950_SUPPORT)
constexpr bool is_8bit_float =
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
if constexpr(M_Warp_Tile == 32)
return is_8bit_float ? 64 : 16;
else
return is_8bit_float ? 128 : 32;
#else
if constexpr(M_Warp_Tile == 32)
return 16;
else
return 32;
#endif
}
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
{
#if defined(CK_GFX950_SUPPORT)
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 64;
else
return sizeof(PrecType) == 2 ? 32 : 128;
#else
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 32;
else
return sizeof(PrecType) == 2 ? 32 : 64;
#endif
}
template <typename DataType>
struct GemmTypeConfig;
@@ -111,7 +77,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
@@ -134,7 +101,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
@@ -157,7 +125,8 @@ struct GemmConfigComputeV4_V2 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
@@ -178,7 +147,8 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
static constexpr bool kPadK = true;
@@ -203,7 +173,8 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;

View File

@@ -11,24 +11,6 @@
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/utility/json_dump.hpp"
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
#if defined(CK_GFX950_SUPPORT)
constexpr bool is_8bit_float =
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
if constexpr(M_Warp_Tile == 32)
return is_8bit_float ? 64 : 16;
else
return is_8bit_float ? 128 : 32;
#else
if constexpr(M_Warp_Tile == 32)
return 16;
else
return 32;
#endif
}
struct GemmConfigBase
{
static constexpr bool kPadM = false;

View File

@@ -10,40 +10,6 @@
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
#if defined(CK_GFX950_SUPPORT)
constexpr bool is_8bit_float =
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
if constexpr(M_Warp_Tile == 32)
return is_8bit_float ? 64 : 16;
else
return is_8bit_float ? 128 : 32;
#else
if constexpr(M_Warp_Tile == 32)
return 16;
else
return 32;
#endif
}
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile()
{
#if defined(CK_GFX950_SUPPORT)
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 64;
else
return sizeof(PrecType) == 2 ? 32 : 128;
#else
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 32;
else
return sizeof(PrecType) == 2 ? 32 : 64;
#endif
}
template <typename DataType>
struct GemmTypeConfig;
@@ -100,7 +66,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase<Persistent>
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
};
template <typename PrecType, bool Persistent>
@@ -117,7 +84,7 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase<Persistent>
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;

View File

@@ -24,39 +24,6 @@ inline size_t hash_multiple_strings(const std::vector<std::string>& inputs)
return combined_hash;
}
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
#if defined(CK_GFX950_SUPPORT)
constexpr bool is_8bit_float =
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
if constexpr(M_Warp_Tile == 32)
return is_8bit_float ? 64 : 16;
else
return is_8bit_float ? 128 : 32;
#else
if constexpr(M_Warp_Tile == 32)
return 16;
else
return 32;
#endif
}
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile()
{
#if defined(CK_GFX950_SUPPORT)
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 64;
else
return sizeof(PrecType) == 2 ? 32 : 128;
#else
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 32;
else
return sizeof(PrecType) == 2 ? 32 : 64;
#endif
}
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
@@ -124,7 +91,8 @@ struct GemmConfigQuantDecode : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
};
template <typename PrecType>
@@ -140,7 +108,8 @@ struct GemmConfigRowColQuant : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
};
template <typename PrecType>
@@ -157,7 +126,7 @@ struct GemmConfigPreshuffleQuantDecode : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
static constexpr bool PreshuffleQuant = true;
};
@@ -176,7 +145,7 @@ struct GemmConfigPreshuffleB_BQuant_Decode : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;
@@ -206,7 +175,7 @@ struct GemmConfigPreshuffleB_BQuant_Prefill : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;
@@ -236,7 +205,8 @@ struct GemmConfigQuantPrefill : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
};
template <typename PrecType>

View File

@@ -84,63 +84,46 @@ namespace ck_tile::builder::factory {
// CK Tile kernel
template <typename T>
consteval bool IsTileAlgorithm()
{
return ConvAlgorithmDescriptor<T> && SpecifiesTileThreadBlock<T> && SpecifiesTileTransfer<T> &&
SpecifiesTileConvSpecialization<T> && SpecifiesTileBlockGemm<T> &&
SpecifiesTileOptimizations<T>;
}
concept IsTileAlgorithm = ConvAlgorithmDescriptor<T> && SpecifiesTileThreadBlock<T> &&
SpecifiesTileTransfer<T> && SpecifiesTileConvSpecialization<T> &&
SpecifiesTileBlockGemm<T> && SpecifiesTileOptimizations<T>;
// XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline)
template <typename T>
consteval bool IsXdlV3Algorithm()
{
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesBlockGemm<T>;
}
concept IsXdlV3Algorithm =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConvSpecialization<T> &&
SpecifiesGemmSpecialization<T> && SpecifiesBlockGemm<T>;
// Standard XDL-based kernel (uses XDLops hardware instructions for matrix multiply)
template <typename T>
consteval bool IsXdlAlgorithm()
{
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesNumPrefetchStages<T> && SpecifiesNumGroupsToMerge<T> &&
SpecifiesLoopScheduler<T>;
}
concept IsXdlAlgorithm =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConvSpecialization<T> &&
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> &&
SpecifiesNumGroupsToMerge<T> && SpecifiesLoopScheduler<T>;
// WMMA-based kernel (uses Wavefront Matrix-Matrix Accumulate instructions)
template <typename T>
consteval bool IsWmmaAlgorithm()
{
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseWmmaGemm<T> &&
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T>;
}
concept IsWmmaAlgorithm =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseWmmaGemm<T> &&
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConvSpecialization<T> &&
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T>;
// Specialized DL kernel for specific NHWC/KYXC/NHWK data layouts
template <typename T>
consteval bool IsDlAlgorithm()
{
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> &&
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesDlThreadConfig<T> && SpecifiesDlThreadCluster<T> &&
SpecifiesDlBlockTransfer<T> && SpecifiesDlEpilogue<T>;
}
concept IsDlAlgorithm =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesFwdConvSpecialization<T> &&
SpecifiesGemmSpecialization<T> && SpecifiesDlThreadConfig<T> && SpecifiesDlThreadCluster<T> &&
SpecifiesDlBlockTransfer<T> && SpecifiesDlEpilogue<T>;
// XDL-based kernel with large tensor support
template <typename T>
consteval bool IsLargeTensorAlgorithm()
{
return IsXdlAlgorithm<decltype(T::base_algorithm)>() && SpecifiesLargeTensorSupport<T>;
}
concept IsLargeTensorAlgorithm =
IsXdlAlgorithm<decltype(T::base_algorithm)> && SpecifiesLargeTensorSupport<T>;
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
@@ -150,29 +133,29 @@ constexpr auto make_conv_instance()
using AlgoType = std::remove_const_t<decltype(ALGORITHM)>;
// CK Tile supports common factory for each direction
if constexpr(IsTileAlgorithm<AlgoType>())
if constexpr(IsTileAlgorithm<AlgoType>)
{
return typename ConvTileFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(ConvDirectionIsForward<SIGNATURE>)
{
if constexpr(IsXdlV3Algorithm<AlgoType>())
if constexpr(IsXdlV3Algorithm<AlgoType>)
{
return typename ConvFwdXdlV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(IsXdlAlgorithm<AlgoType>())
else if constexpr(IsXdlAlgorithm<AlgoType>)
{
return typename ConvFwdXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(IsWmmaAlgorithm<AlgoType>())
else if constexpr(IsWmmaAlgorithm<AlgoType>)
{
return typename ConvFwdWmmaFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(IsDlAlgorithm<AlgoType>())
else if constexpr(IsDlAlgorithm<AlgoType>)
{
return typename ConvFwdDlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(IsLargeTensorAlgorithm<AlgoType>())
else if constexpr(IsLargeTensorAlgorithm<AlgoType>)
{
return typename ConvFwdLargeTensorFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}

View File

@@ -31,13 +31,20 @@ constexpr size_t data_type_sizeof(DataType data_type)
{
case DataType::UNDEFINED_DATA_TYPE: return 0;
case DataType::FP32: return 4;
case DataType::FP32_FP32: return 8;
case DataType::FP16: return 2;
case DataType::FP16_FP16: return 4;
case DataType::BF16: return 2;
case DataType::BF16_BF16: return 4;
case DataType::FP8: return 1;
case DataType::BF8: return 1;
case DataType::FP64: return 8;
case DataType::INT32: return 4;
case DataType::I8: return 1;
case DataType::I8_I8: return 2;
case DataType::U8: return 1;
}
return 0; // Default case to ensure all control paths return a value
}
} // namespace ck_tile::builder::test

View File

@@ -72,7 +72,12 @@ inline bool is_xdl_supported()
is_gfx12_supported() || is_gfx11_supported();
}
template <typename ADataType, typename BDataType, index_t MPerXDL, index_t NPerXDL>
template <typename ADataType,
typename BDataType,
index_t MPerXDL64,
index_t NPerXDL64,
index_t MPerXDL32 = MPerXDL64,
index_t NPerXDL32 = NPerXDL64>
inline bool is_xdl_wmma_supported()
{
if(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
@@ -82,7 +87,7 @@ inline bool is_xdl_wmma_supported()
}
else if(is_gfx12_supported() || is_gfx11_supported())
{
if constexpr((MPerXDL != 16) || (NPerXDL != 16))
if constexpr((MPerXDL32 != 16) || (NPerXDL32 != 16))
{
return false;
}

View File

@@ -17,6 +17,7 @@
#endif
#endif
#include "ck/utility/get_id.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
@@ -96,6 +97,57 @@ static constexpr auto GetNXdlPerWave2()
IsWave64>(); \
}
template <index_t BlockSize_,
index_t MPerBlock_,
index_t NPerBlock_,
index_t MPerXDL_,
index_t NPerXDL_,
index_t MXdlPerWave_,
index_t CShuffleMXdlPerWavePerShuffle_,
index_t CShuffleNXdlPerWavePerShuffle_,
bool IsWave64>
static constexpr auto GetWarpTileConfig()
{
constexpr auto MXdlPerWave64 = MXdlPerWave_;
constexpr auto MXdlPerWave32 = MXdlPerWave_ * MPerXDL_ / 16;
constexpr auto CShuffleMXdlPerWavePerShuffle32 = CShuffleMXdlPerWavePerShuffle_ * MPerXDL_ / 16;
constexpr auto NXdlPerWave =
IsWave64
? GetNXdlPerWave2<BlockSize_,
MPerBlock_,
NPerBlock_,
MPerXDL_,
NPerXDL_,
MXdlPerWave_,
true>()
: GetNXdlPerWave2<BlockSize_, MPerBlock_, NPerBlock_, 16, 16, MXdlPerWave32, false>();
if constexpr(IsWave64 == false && NXdlPerWave != 0)
{
constexpr auto CShuffleNXdlPerWavePerShuffle32 =
NXdlPerWave >= CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16
? CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16
: CShuffleNXdlPerWavePerShuffle_;
static_assert(CShuffleNXdlPerWavePerShuffle32 > 0);
return Sequence<16,
16,
MXdlPerWave32,
NXdlPerWave,
CShuffleMXdlPerWavePerShuffle32,
CShuffleNXdlPerWavePerShuffle32>{};
}
else
{
return Sequence<MPerXDL_,
NPerXDL_,
MXdlPerWave64,
NXdlPerWave,
CShuffleMXdlPerWavePerShuffle_,
CShuffleNXdlPerWavePerShuffle_>{};
}
}
#define INVOKER_RUN_IMPL \
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) \
{ \

View File

@@ -166,11 +166,27 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
{
using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle;
GET_NXDL_PER_WAVE_IMPL
static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto WarpTileConfig64 = GetWarpTileConfig<BlockSize,
MPerBlock,
NPerBlock,
MPerXDL,
NPerXDL,
MXdlPerWave,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
true>();
static constexpr auto WarpTileConfig32 = GetWarpTileConfig<BlockSize,
MPerBlock,
NPerBlock,
MPerXDL,
NPerXDL,
MXdlPerWave,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
false>();
static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3);
static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3);
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -321,7 +337,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
// GridwiseGemm
template <index_t NXdlPerWave_>
template <typename WarpTileConfig>
using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
BDataType,
@@ -340,10 +356,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave_,
WarpTileConfig::At(0),
WarpTileConfig::At(1),
WarpTileConfig::At(2),
WarpTileConfig::At(3),
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
@@ -360,13 +376,13 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
WarpTileConfig::At(4),
WarpTileConfig::At(5),
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
using GridwiseGemm64 = GridwiseGemmBase<decltype(WarpTileConfig64)>;
using GridwiseGemm32 = GridwiseGemmBase<decltype(WarpTileConfig32)>;
// desc for blockwise copy
using AGridDesc_AK0_M_AK1 =
@@ -588,7 +604,12 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_wmma_supported<ComputeDataType, ComputeDataType, MPerXDL, NPerXDL>())
if(!ck::is_xdl_wmma_supported<ComputeDataType,
ComputeDataType,
MPerXDL,
NPerXDL,
WarpTileConfig32.At(0),
WarpTileConfig32.At(1)>())
{
return false;
}
@@ -783,6 +804,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< AK1 << ", "
<< BK1 << ", "
<< ABlockTransferSrcVectorDim << ", "

View File

@@ -366,6 +366,26 @@ struct amdgcn_compiler_target_state
#else
static constexpr bool CK_TILE_ARCH_GFX1010 = false;
#endif
#if defined(__gfx1011__)
static constexpr bool CK_TILE_ARCH_GFX1011 = true;
#else
static constexpr bool CK_TILE_ARCH_GFX1011 = false;
#endif
#if defined(__gfx1012__)
static constexpr bool CK_TILE_ARCH_GFX1012 = true;
#else
static constexpr bool CK_TILE_ARCH_GFX1012 = false;
#endif
#if defined(__gfx1013__)
static constexpr bool CK_TILE_ARCH_GFX1013 = true;
#else
static constexpr bool CK_TILE_ARCH_GFX1013 = false;
#endif
#if defined(__gfx10_1_generic__)
static constexpr bool CK_TILE_ARCH_GFX10_1_GENERIC = true;
#else
static constexpr bool CK_TILE_ARCH_GFX10_1_GENERIC = false;
#endif // __gfx10_1_generic__
#if defined(__gfx1030__)
static constexpr bool CK_TILE_ARCH_GFX1030 = true;
@@ -504,6 +524,10 @@ CK_TILE_HOST_DEVICE static constexpr uint32_t count_values_of(T search, Ts... se
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX942, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX950, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1010, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1011, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1012, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1013, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX10_1_GENERIC, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1030, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1031, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1032, \

View File

@@ -68,7 +68,7 @@ auto shuffle_bq(const ck_tile::HostTensor<T>* t, int block_bq_k)
}
template <typename GemmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
auto shuffle_b(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
@@ -78,10 +78,10 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
constexpr int divisor = 2;
constexpr int kABK1PerLane = 8;
constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
gemmConfig.N_Warp_Tile,
k_ / gemmConfig.K_Warp_Tile,
kABK0PerLane,
divisor,
kABK1PerLane});
@@ -98,18 +98,24 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
else
{
assert(is_wave32() == false);
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
divisor = gemmConfig.N_Warp_Tile == 32 ? 2 : 4;
}
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
gemmConfig.N_Warp_Tile,
k_ / gemmConfig.K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
gemmConfig.K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
}
template <typename GemmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
return shuffle_b(t, GemmConfig{});
}
template <typename GemmConfig, typename T>
auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
{
@@ -129,22 +135,22 @@ auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
}
template <typename GemmConfig, typename T>
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
int NRepeat = gemmConfig.N_Tile / gemmConfig.N_Warp_Tile / gemmConfig.N_Warp;
if(ck_tile::is_gfx12_supported())
{
constexpr int divisor = 2;
constexpr int kABK1PerLane = 8;
constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
GemmConfig::N_Warp,
GemmConfig::N_Warp_Tile,
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
gemmConfig.N_Warp,
gemmConfig.N_Warp_Tile,
NRepeat,
k_ / GemmConfig::K_Warp_Tile,
k_ / gemmConfig.K_Warp_Tile,
kABK0PerLane,
divisor,
kABK1PerLane});
@@ -161,17 +167,23 @@ auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
else
{
assert(is_wave32() == false);
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
divisor = gemmConfig.N_Warp_Tile == 32 ? 2 : 4;
}
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
GemmConfig::N_Warp,
GemmConfig::N_Warp_Tile,
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
gemmConfig.N_Warp,
gemmConfig.N_Warp_Tile,
NRepeat,
k_ / GemmConfig::K_Warp_Tile,
k_ / gemmConfig.K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
gemmConfig.K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
}
}
template <typename GemmConfig, typename T>
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
{
return shuffle_b_permuteN(t, GemmConfig{});
}
} // namespace ck_tile

View File

@@ -43,4 +43,26 @@ struct TileGemmShape
}
};
template <typename PrecType, index_t M_Warp_Tile, bool IsFlatMM = false>
constexpr index_t get_k_warp_tile()
{
#if CK_TILE_USE_WMMA
return 16;
#else
#if defined(CK_GFX950_SUPPORT)
constexpr bool is_8bit_float =
std::is_same_v<PrecType, fp8_t> || std::is_same_v<PrecType, bf8_t>;
if constexpr(M_Warp_Tile == 32)
return is_8bit_float ? 64 : 16;
else
return is_8bit_float ? 128 : 32;
#else
if constexpr(M_Warp_Tile == 32)
return (sizeof(PrecType) == 2 || IsFlatMM == false) ? 16 : 32;
else
return (sizeof(PrecType) == 2 || IsFlatMM == false) ? 32 : 64;
#endif
#endif
}
} // namespace ck_tile

View File

@@ -61,6 +61,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
@@ -156,9 +157,11 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
using CDataType = remove_cvref_t<typename Traits::CDataType>;
// BDataType gets converted from PkInt4 during loading
using OverrideBDataType =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using OverrideBDataType = std::conditional_t<
std::is_same_v<BDataType, pk_int4_t> &&
std::is_same_v<typename Traits::BLayout, tensor_layout::gemm::RowMajor>,
ADataType,
BDataType>;
using Base = BlockGemmBQuantBase<Problem_>;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;

View File

@@ -33,9 +33,17 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
// BDataType gets converted from PkInt4 during loading
using OverrideBDataType =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
std::conditional_t<std::is_same_v<BDataType, pk_int4_t> &&
std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>,
ADataType,
BDataType>;
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
using I0 = number<0>;
@@ -50,11 +58,6 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
static constexpr index_t BQPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BQDataType>>::PackedSize;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
static constexpr index_t BlockSize = Problem::kBlockSize;
@@ -184,6 +187,23 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
load_int4_tile<SrcDataType, DestDataType, UnaryOpSize>(b_block_tile, b_dram_window);
}
template <typename BBlockTile_, typename BDramWindow, typename BDramTileWindowStep>
CK_TILE_DEVICE void
BGlobalPrefetch(BBlockTile_& b_block_tile,
BDramWindow& b_copy_dram_window,
const BDramTileWindowStep& b_dram_tile_window_step) const
{
if constexpr(!std::is_same_v<BDataType, OverrideBDataType>)
{
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
}
else
{
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
}
}
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
@@ -262,7 +282,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
using ABlockTile =
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BBlockTile =
decltype(make_static_distributed_tensor<ADataType>(BBlockTileDistr{}));
decltype(make_static_distributed_tensor<OverrideBDataType>(BBlockTileDistr{}));
using BQBlockTile =
decltype(make_static_distributed_tensor<BQDataType>(BQBlockTileDistr{}));
@@ -289,8 +309,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
// DRAM prefetch (global read 0)
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
// B tile gets converted to A datatype during loading
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
BGlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
Base::GlobalPrefetch(
bq_block_tile[currIdx], bq_copy_dram_window, bq_dram_tile_window_step);
@@ -311,7 +330,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
// B datatype is converted to A datatype during loading
auto b_shuffle_tmp = make_static_distributed_tensor<ADataType>(
auto b_shuffle_tmp = make_static_distributed_tensor<OverrideBDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
@@ -322,8 +341,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
}
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
// B tile gets converted to A datatype during loading
BGlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
block_sync_lds();
@@ -366,8 +385,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
}
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
// B tile gets converted to A datatype during loading
BGlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
Base::GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2],
bq_copy_dram_window,
bq_dram_tile_window_step);

View File

@@ -11,25 +11,93 @@ list(APPEND TEST_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
# Typed Test Suite for GEMM Quantization - split into multiple files to reduce compile time
# AQuant tests
add_gtest_executable(test_tile_gemm_quant_aquant
test_gemm_quant_aquant.cpp
# AQuant tests - split into 6 files
add_gtest_executable(test_tile_gemm_quant_aquant_base_rcr
test_gemm_quant_aquant_base_rcr.cpp
)
target_compile_options(test_tile_gemm_quant_aquant PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
target_compile_options(test_tile_gemm_quant_aquant_base_rcr PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
# BQuant tests (without PreshuffleB)
add_gtest_executable(test_tile_gemm_quant_bquant
test_gemm_quant_bquant.cpp
add_gtest_executable(test_tile_gemm_quant_aquant_base_rrr_crr
test_gemm_quant_aquant_base_rrr_crr.cpp
)
target_compile_options(test_tile_gemm_quant_bquant PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
target_compile_options(test_tile_gemm_quant_aquant_base_rrr_crr PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
# BQuant tests (with PreshuffleB)
# disabling this test until it can be built within reasonable time!
# currently taking ~50 minutes on gfx12!
#add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle
# test_gemm_quant_bquant_preshuffle.cpp
#)
#target_compile_options(test_tile_gemm_quant_bquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_aquant_base_ccr
test_gemm_quant_aquant_base_ccr.cpp
)
target_compile_options(test_tile_gemm_quant_aquant_base_ccr PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_aquant_prefill
test_gemm_quant_aquant_prefill.cpp
)
target_compile_options(test_tile_gemm_quant_aquant_prefill PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_aquant_transpose_c
test_gemm_quant_aquant_transpose_c.cpp
)
target_compile_options(test_tile_gemm_quant_aquant_transpose_c PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_aquant_preshuffle
test_gemm_quant_aquant_preshuffle.cpp
)
target_compile_options(test_tile_gemm_quant_aquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
# BQuant tests (without PreshuffleB) - split into 6 files
add_gtest_executable(test_tile_gemm_quant_bquant_1d_128
test_gemm_quant_bquant_1d_128.cpp
)
target_compile_options(test_tile_gemm_quant_bquant_1d_128 PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_bquant_1d_64
test_gemm_quant_bquant_1d_64.cpp
)
target_compile_options(test_tile_gemm_quant_bquant_1d_64 PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_bquant_2d_small_n
test_gemm_quant_bquant_2d_small_n.cpp
)
target_compile_options(test_tile_gemm_quant_bquant_2d_small_n PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_bquant_2d_medium_n
test_gemm_quant_bquant_2d_medium_n.cpp
)
target_compile_options(test_tile_gemm_quant_bquant_2d_medium_n PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_bquant_2d_large_n
test_gemm_quant_bquant_2d_large_n.cpp
)
target_compile_options(test_tile_gemm_quant_bquant_2d_large_n PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_bquant_transpose
test_gemm_quant_bquant_transpose.cpp
)
target_compile_options(test_tile_gemm_quant_bquant_transpose PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
# BQuant tests (with PreshuffleB) - split into 5 files
add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle_decode_1d
test_gemm_quant_bquant_preshuffle_decode_1d.cpp
)
target_compile_options(test_tile_gemm_quant_bquant_preshuffle_decode_1d PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle_prefill_1d
test_gemm_quant_bquant_preshuffle_prefill_1d.cpp
)
target_compile_options(test_tile_gemm_quant_bquant_preshuffle_prefill_1d PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle_tiled_permute
test_gemm_quant_bquant_preshuffle_tiled_permute.cpp
)
target_compile_options(test_tile_gemm_quant_bquant_preshuffle_tiled_permute PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle_decode_2d
test_gemm_quant_bquant_preshuffle_decode_2d.cpp
)
target_compile_options(test_tile_gemm_quant_bquant_preshuffle_decode_2d PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle_prefill_2d
test_gemm_quant_bquant_preshuffle_prefill_2d.cpp
)
target_compile_options(test_tile_gemm_quant_bquant_preshuffle_prefill_2d PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
# RowColQuant tests
add_gtest_executable(test_tile_gemm_quant_rowcol
@@ -42,6 +110,34 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
test_gemm_quant_tensor.cpp
)
target_compile_options(test_tile_gemm_quant_tensor PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
# Umbrella target to build all gemm quant tests
add_custom_target(test_tile_gemm_quant_all)
add_dependencies(test_tile_gemm_quant_all
# AQuant tests
test_tile_gemm_quant_aquant_base_rcr
test_tile_gemm_quant_aquant_base_rrr_crr
test_tile_gemm_quant_aquant_base_ccr
test_tile_gemm_quant_aquant_prefill
test_tile_gemm_quant_aquant_transpose_c
test_tile_gemm_quant_aquant_preshuffle
# BQuant tests
test_tile_gemm_quant_bquant_1d_128
test_tile_gemm_quant_bquant_1d_64
test_tile_gemm_quant_bquant_2d_small_n
test_tile_gemm_quant_bquant_2d_medium_n
test_tile_gemm_quant_bquant_2d_large_n
test_tile_gemm_quant_bquant_transpose
# BQuant preshuffle tests
test_tile_gemm_quant_bquant_preshuffle_decode_1d
test_tile_gemm_quant_bquant_preshuffle_prefill_1d
test_tile_gemm_quant_bquant_preshuffle_tiled_permute
test_tile_gemm_quant_bquant_preshuffle_decode_2d
test_tile_gemm_quant_bquant_preshuffle_prefill_2d
# Other quant tests
test_tile_gemm_quant_rowcol
test_tile_gemm_quant_tensor
)
else()
message(DEBUG "Skipping ck_tile quant gemm tests for current target")
endif()

View File

@@ -1,95 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using AQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::AQuantGrouped>;
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
using RowColQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::RowColQuant>;
using TensorQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::TensorQuant>;
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using GroupSize64 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
// 2d block sizes for BQuant
using GroupSize2D8N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
using GroupSize2D16N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
using GroupSize2D32N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
using GroupSize2D64N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
// Type combinations for AQuant tests
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, QuantGroupSize>
// clang-format off
using AQuantTypes = ::testing::Types<
// PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ)
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
// RRR layout (RowMajor A, RowMajor B, RowMajor C with RowMajor AQ)
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
// CRR layout (ColumnMajor A, RowMajor B, RowMajor C with RowMajor AQ)
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
// CCR layout (ColumnMajor A, ColumnMajor B, RowMajor C with ColumnMajor AQ) - NEW layout support
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
// RCR layout - with the Prefill BlockTile Config.
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
// PreshuffleQuant = false && TransposeC = true (with RowMajor AQ)
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
// PreshuffleQuant = true && TransposeC = false (with RowMajor AQ - PreshuffleQuant only supports RowMajor)
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
// PreshuffleQuant = true && TransposeC = true (with RowMajor AQ - PreshuffleQuant only supports RowMajor)
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>
>;
// clang-format on
// Test suite for AQuant
TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantTypes);
// AQuant tests
TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}

View File

@@ -0,0 +1,42 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using AQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::AQuantGrouped>;
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// Type combinations for AQuant tests - CCR layout
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, QuantGroupSize>
// clang-format off
using AQuantBaseCCRTypes = ::testing::Types<
// CCR layout (ColumnMajor A, ColumnMajor B, RowMajor C with ColumnMajor AQ) - NEW layout support
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>
>;
// clang-format on
// Test suite for AQuant Base CCR
TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantBaseCCRTypes);
// AQuant tests
TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}

View File

@@ -0,0 +1,42 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using AQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::AQuantGrouped>;
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// Type combinations for AQuant tests - RCR layout base configuration
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, QuantGroupSize>
// clang-format off
using AQuantBaseRCRTypes = ::testing::Types<
// PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ)
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>
>;
// clang-format on
// Test suite for AQuant Base RCR
TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantBaseRCRTypes);
// AQuant tests
TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}

View File

@@ -0,0 +1,46 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using AQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::AQuantGrouped>;
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// Type combinations for AQuant tests - RRR and CRR layouts
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, QuantGroupSize>
// clang-format off
using AQuantBaseRRRCRRTypes = ::testing::Types<
// RRR layout (RowMajor A, RowMajor B, RowMajor C with RowMajor AQ)
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
// CRR layout (ColumnMajor A, RowMajor B, RowMajor C with RowMajor AQ)
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>
>;
// clang-format on
// Test suite for AQuant Base RRR/CRR
TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantBaseRRRCRRTypes);
// AQuant tests
TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}

View File

@@ -0,0 +1,41 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using AQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::AQuantGrouped>;
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// Type combinations for AQuant tests - Prefill Configuration
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, QuantGroupSize>
// clang-format off
using AQuantPrefillTypes = ::testing::Types<
// RCR layout - with the Prefill BlockTile Config.
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>
>;
// clang-format on
// Test suite for AQuant Prefill
TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantPrefillTypes);
// AQuant tests
TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}

View File

@@ -0,0 +1,48 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using AQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::AQuantGrouped>;
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// Type combinations for AQuant tests - PreshuffleQuant Configurations
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, QuantGroupSize>
// clang-format off
using AQuantPreshuffleTypes = ::testing::Types<
// PreshuffleQuant = true && TransposeC = false (with RowMajor AQ - PreshuffleQuant only supports RowMajor)
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
// PreshuffleQuant = true && TransposeC = true (with RowMajor AQ - PreshuffleQuant only supports RowMajor)
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>
>;
// clang-format on
// Test suite for AQuant Preshuffle
TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantPreshuffleTypes);
// AQuant tests
TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}

View File

@@ -0,0 +1,40 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using AQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::AQuantGrouped>;
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// Type combinations for AQuant tests - TransposeC Configuration
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, QuantGroupSize>
// clang-format off
using AQuantTransposeCTypes = ::testing::Types<
// PreshuffleQuant = false && TransposeC = true (with RowMajor AQ)
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>
>;
// clang-format on
// Test suite for AQuant TransposeC
TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantTransposeCTypes);
// AQuant tests
TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}

View File

@@ -1,99 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using BF16 = ck_tile::bf16_t;
using UInt8 = ck_tile::pk_fp4_raw_t;
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using GroupSize64 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
using GroupSize32 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 32>>;
// 2d block sizes for BQuant
using GroupSize2D8N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
using GroupSize2D16N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
using GroupSize2D32N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
using GroupSize2D64N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
// Type combinations for BQuant tests (without PreshuffleB)
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, QuantGroupSize>
// clang-format off
using BQuantTypes = ::testing::Types<
// 1d cases with grouping only on k axis
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, UInt8, UInt8, BF16, BQuantGrouped, GemmConfigMxFp4, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, UInt8, UInt8, BF16, BQuantGrouped, GemmConfigMxFp4, GroupSize32>,
// 2d cases with grouping also on the n axis
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>,
// some cases with transpose layouts
std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
// pkint4 + transpose cases
std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>
>;
// clang-format on
// Test suite for BQuant (without PreshuffleB)
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantTypes);
// BQuant tests
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}

View File

@@ -0,0 +1,41 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// Type combinations for BQuant tests - 1D GroupSize 128
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, QuantGroupSize>
// clang-format off
using BQuant1D128Types = ::testing::Types<
// 1d cases with grouping only on k axis
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize>
>;
// clang-format on
// Test suite for BQuant 1D 128
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D128Types);
// BQuant tests
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}

View File

@@ -0,0 +1,41 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
using GroupSize64 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
// Type combinations for BQuant tests - 1D GroupSize 64
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, QuantGroupSize>
// clang-format off
using BQuant1D64Types = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>
>;
// clang-format on
// Test suite for BQuant 1D 64
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D64Types);
// BQuant tests
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}

View File

@@ -0,0 +1,41 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
// Type combinations for BQuant tests - 2D Large N (128N)
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, QuantGroupSize>
// clang-format off
using BQuant2DLargeNTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>
>;
// clang-format on
// Test suite for BQuant 2D Large N
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant2DLargeNTypes);
// BQuant tests
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}

View File

@@ -0,0 +1,48 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
// 2d block sizes for BQuant
using GroupSize2D32N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
using GroupSize2D64N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
// Type combinations for BQuant tests - 2D Medium N (32N and 64N)
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, QuantGroupSize>
// clang-format off
using BQuant2DMediumNTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>
>;
// clang-format on
// Test suite for BQuant 2D Medium N
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant2DMediumNTypes);
// BQuant tests
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}

View File

@@ -0,0 +1,49 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
// 2d block sizes for BQuant
using GroupSize2D8N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
using GroupSize2D16N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
// Type combinations for BQuant tests - 2D Small N (8N and 16N)
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, QuantGroupSize>
// clang-format off
using BQuant2DSmallNTypes = ::testing::Types<
// 2d cases with grouping also on the n axis
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>
>;
// clang-format on
// Test suite for BQuant 2D Small N
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant2DSmallNTypes);
// BQuant tests
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}

View File

@@ -1,93 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// 2d block sizes for BQuant
using GroupSize2D8N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
using GroupSize2D16N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
using GroupSize2D32N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
using GroupSize2D64N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
// Type combinations for BQuant tests with PreshuffleB
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, QuantGroupSize>
// clang-format off
using BPreshuffleBQuantTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
// //2d cases with preshuffle B
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D64N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D64N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D64N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D64N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D64N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D64N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D64N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D64N>
>;
// clang-format on
// Test suite for BQuant with PreshuffleB
TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshuffleBQuantTypes);
// BQuant PreshuffleB tests
TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}

View File

@@ -0,0 +1,39 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// Type combinations for BQuant Preshuffle tests - Decode Config 1D
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, QuantGroupSize>
// clang-format off
using BPreshuffleDecode1DTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>
>;
// clang-format on
// Test suite for BQuant Preshuffle Decode 1D
TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshuffleDecode1DTypes);
// BQuant PreshuffleB tests
TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}

View File

@@ -0,0 +1,51 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
// 2d block sizes for BQuant
using GroupSize2D8N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
using GroupSize2D16N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
using GroupSize2D32N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
using GroupSize2D64N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
// Type combinations for BQuant Preshuffle tests - Decode 2D
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, QuantGroupSize>
// clang-format off
using BPreshuffleDecode2DTypes = ::testing::Types<
// 2d cases with preshuffle B
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D64N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D64N>
>;
// clang-format on
// Test suite for BQuant Preshuffle Decode 2D
TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshuffleDecode2DTypes);
// BQuant PreshuffleB tests
TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}

View File

@@ -0,0 +1,41 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// Type combinations for BQuant Preshuffle tests - Prefill Config 1D
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, QuantGroupSize>
// clang-format off
using BPreshufflePrefill1DTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>
>;
// clang-format on
// Test suite for BQuant Preshuffle Prefill 1D
TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshufflePrefill1DTypes);
// BQuant PreshuffleB tests
TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}

View File

@@ -0,0 +1,58 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
// 2d block sizes for BQuant
using GroupSize2D8N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
using GroupSize2D16N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
using GroupSize2D32N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
using GroupSize2D64N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
// Type combinations for BQuant Preshuffle tests - Prefill 2D
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, QuantGroupSize>
// clang-format off
using BPreshufflePrefill2DTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D64N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D64N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D64N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D64N>
>;
// clang-format on
// Test suite for BQuant Preshuffle Prefill 2D
TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshufflePrefill2DTypes);
// BQuant PreshuffleB tests
TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}

View File

@@ -0,0 +1,40 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// Type combinations for BQuant Preshuffle tests - TiledPermuteN Config
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, QuantGroupSize>
// clang-format off
using BPreshuffleTiledPermuteTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>
>;
// clang-format on
// Test suite for BQuant Preshuffle TiledPermuteN
TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshuffleTiledPermuteTypes);
// BQuant PreshuffleB tests
TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}

View File

@@ -0,0 +1,53 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
using GroupSize64 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
using GroupSize2D64N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
// Type combinations for BQuant tests - Transpose Layouts
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, QuantGroupSize>
// clang-format off
using BQuantTransposeTypes = ::testing::Types<
// some cases with transpose layouts
std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
// pkint4 + transpose cases
std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>
>;
// clang-format on
// Test suite for BQuant Transpose
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantTransposeTypes);
// BQuant tests
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}

View File

@@ -11,26 +11,6 @@
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
{
#if CK_TILE_USE_WMMA
return 16;
#else
#if defined(CK_GFX950_SUPPORT)
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 64;
else
return sizeof(PrecType) == 2 ? 32 : 128;
#else
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 32;
else
return sizeof(PrecType) == 2 ? 32 : 64;
#endif
#endif
}
template <typename Tuple>
class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
{
@@ -67,7 +47,8 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
static const ck_tile::index_t M_Warp_Tile = 16;
static const ck_tile::index_t N_Warp_Tile = 16;
static const ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<BDataType, M_Warp_Tile>();
static const ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<BDataType, M_Warp_Tile, true>();
static constexpr bool DoubleSmemBuffer = true; // preshuffle v2 uses ping-pong smem
static constexpr bool TransposeC = false; // transpose c is not supported
@@ -101,46 +82,6 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<>);
}
template <typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
if(ck_tile::is_gfx12_supported())
{
constexpr int divisor = 2;
constexpr int kABK1PerLane = 8;
constexpr int kABK0PerLane = K_Warp_Tile / divisor / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / N_Warp_Tile,
N_Warp_Tile,
k_ / K_Warp_Tile,
kABK0PerLane,
divisor,
kABK1PerLane});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
}
else
{
int divisor = 1;
if(ck_tile::is_gfx11_supported())
{
divisor = 1;
}
else
{
assert(is_wave32() == false);
divisor = N_Warp_Tile == 32 ? 2 : 4;
}
ck_tile::HostTensor<T> t_view(
{n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
}
template <typename ALayout, typename BLayout, typename CLayout>
void invoke_grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
@@ -340,6 +281,14 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
}
}
struct BShuffleGemmConfig
{
static constexpr ck_tile::index_t N_Warp_Tile =
TestCkTileGroupedGemmPreshuffle::N_Warp_Tile;
static constexpr ck_tile::index_t K_Warp_Tile =
TestCkTileGroupedGemmPreshuffle::K_Warp_Tile;
};
public:
void Run(const std::vector<int>& Ms,
const std::vector<int>& Ns,
@@ -424,7 +373,7 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
// Host-side preshuffle of B
auto b_shuffle_host = shuffle_b(b_k_n_tensors[i]);
auto b_shuffle_host = ck_tile::shuffle_b<BShuffleGemmConfig>(b_k_n_tensors[i]);
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
a_m_k_tensors[i].get_element_space_size_in_bytes()));

View File

@@ -6,18 +6,18 @@ if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
# Split into three separate test executables for faster parallel compilation
add_gtest_executable(test_ck_tile_grouped_gemm_quant_rowcol test_grouped_gemm_quant_rowcol.cpp)
target_compile_options(test_ck_tile_grouped_gemm_quant_rowcol PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
# if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
# # Split into three separate test executables for faster parallel compilation
# add_gtest_executable(test_ck_tile_grouped_gemm_quant_rowcol test_grouped_gemm_quant_rowcol.cpp)
# target_compile_options(test_ck_tile_grouped_gemm_quant_rowcol PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_ck_tile_grouped_gemm_quant_tensor test_grouped_gemm_quant_tensor.cpp)
target_compile_options(test_ck_tile_grouped_gemm_quant_tensor PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
# add_gtest_executable(test_ck_tile_grouped_gemm_quant_tensor test_grouped_gemm_quant_tensor.cpp)
# target_compile_options(test_ck_tile_grouped_gemm_quant_tensor PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_ck_tile_grouped_gemm_quant_aquant test_grouped_gemm_quant_aquant.cpp)
target_compile_options(test_ck_tile_grouped_gemm_quant_aquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
# add_gtest_executable(test_ck_tile_grouped_gemm_quant_aquant test_grouped_gemm_quant_aquant.cpp)
# target_compile_options(test_ck_tile_grouped_gemm_quant_aquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
# add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp)
# target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()
# endif()

View File

@@ -81,42 +81,3 @@ inline KernelTraits extract_traits_from_name(const std::string& kernel_name)
return traits;
}
template <typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t,
ck_tile::index_t N_Warp_Tile,
ck_tile::index_t K_Warp_Tile)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
int divisor = N_Warp_Tile == 32 ? 2 : 4;
ck_tile::HostTensor<T> t_view(
{n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
template <typename T>
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t,
ck_tile::index_t N_Warp_Tile,
ck_tile::index_t K_Warp_Tile,
ck_tile::index_t N_Tile,
ck_tile::index_t N_Warp)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
int divisor = N_Warp_Tile == 32 ? 2 : 4;
int NRepeat = N_Tile / N_Warp_Tile / N_Warp;
ck_tile::HostTensor<T> t_view({n_ / N_Tile,
N_Warp,
N_Warp_Tile,
NRepeat,
k_ / K_Warp_Tile,
divisor,
K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
}

View File

@@ -111,21 +111,30 @@ class GemmProfiler
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
struct GemmConfig
{
ck_tile::index_t N_Warp_Tile;
ck_tile::index_t K_Warp_Tile;
ck_tile::index_t N_Tile;
ck_tile::index_t N_Warp;
};
for(const auto& callable : callables)
{
ck_tile::index_t N_Warp_Tile = std::get<1>(config.warp_tile_dims);
ck_tile::index_t K_Warp_Tile = std::get<2>(config.warp_tile_dims);
ck_tile::index_t N_Tile = std::get<1>(config.tile_dims);
ck_tile::index_t N_Warp = std::get<1>(config.warp_dims);
GemmConfig gemmConfig = {};
gemmConfig.N_Warp_Tile = std::get<1>(config.warp_tile_dims);
gemmConfig.K_Warp_Tile = std::get<2>(config.warp_tile_dims);
gemmConfig.N_Tile = std::get<1>(config.tile_dims);
gemmConfig.N_Warp = std::get<1>(config.warp_dims);
ck_tile::HostTensor<BDataType> b_shuffle_host = [&]() {
if(config.permuteN)
{
return shuffle_b_permuteN(b_k_n, N_Warp_Tile, K_Warp_Tile, N_Tile, N_Warp);
return ck_tile::shuffle_b_permuteN(b_k_n, gemmConfig);
}
else
{
return shuffle_b(b_k_n, N_Warp_Tile, K_Warp_Tile);
return ck_tile::shuffle_b(b_k_n, gemmConfig);
}
}();