mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Merge branch 'develop' of github.com:ROCm/composable_kernel into ck_moe_bs_splitk_pr
This commit is contained in:
5
.gitignore
vendored
5
.gitignore
vendored
@@ -83,6 +83,11 @@ __pycache__/
|
||||
|
||||
.cache/
|
||||
|
||||
# Generated test data
|
||||
test_data/*
|
||||
!test_data/*.py
|
||||
!test_data/*.sh
|
||||
|
||||
# Exceptions to build* patterns above
|
||||
# The experimental/builder directory should be tracked despite matching build*
|
||||
!experimental/builder
|
||||
|
||||
@@ -766,6 +766,9 @@ if(CK_EXPERIMENTAL_BUILDER)
|
||||
${PROJECT_SOURCE_DIR}/experimental/builder/include/ck_tile/builder
|
||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck_tile
|
||||
)
|
||||
|
||||
set(CK_TILE_SRC_FOLDER ${CMAKE_SOURCE_DIR}/include/ck_tile/)
|
||||
rocm_install(DIRECTORY ${CK_TILE_SRC_FOLDER} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck_tile)
|
||||
endif()
|
||||
|
||||
set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE")
|
||||
|
||||
10
Jenkinsfile
vendored
10
Jenkinsfile
vendored
@@ -1476,15 +1476,19 @@ pipeline {
|
||||
setup_args = "NO_CK_BUILD"
|
||||
execute_args = """ cd ../build && \
|
||||
../script/cmake-ck-dev.sh ../ gfx90a && \
|
||||
make -j64 test_grouped_convnd_fwd_dataset_xdl && \
|
||||
make -j64 test_grouped_convnd_fwd_dataset_xdl \
|
||||
test_grouped_convnd_bwd_data_dataset_xdl \
|
||||
test_grouped_convnd_bwd_weight_dataset_xdl && \
|
||||
cd ../test_data && \
|
||||
# Dataset generation modes:
|
||||
# - small: ~60 test cases (minimal, quick testing - 3 models, 2 batch sizes, 2 image sizes)
|
||||
# - half: ~300 test cases (moderate coverage - 16 models, 3 batch sizes, 5 image sizes), ~ 17 hours testing time
|
||||
# - full: ~600 test cases (comprehensive - 16 models, 5 batch sizes, 9 image sizes), ~ 40 hours testing time
|
||||
./generate_test_dataset.sh half && \
|
||||
./generate_test_dataset.sh small && \
|
||||
cd ../build && \
|
||||
./bin/test_grouped_convnd_fwd_dataset_xdl"""
|
||||
./bin/test_grouped_convnd_fwd_dataset_xdl && \
|
||||
./bin/test_grouped_convnd_bwd_data_dataset_xdl && \
|
||||
./bin/test_grouped_convnd_bwd_weight_dataset_xdl"""
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, build_type: 'Release', execute_cmd: execute_args)
|
||||
|
||||
@@ -233,7 +233,20 @@ int run_contraction_bilinear_example(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
return ck::utils::check_err(e_ms_ns_device_result,
|
||||
e_ms_ns_host_result,
|
||||
"Error: Incorrect results!",
|
||||
1e-4,
|
||||
1e-4)
|
||||
? 0
|
||||
: 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
@@ -216,7 +216,20 @@ int run_contraction_scale_example(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
return ck::utils::check_err(e_ms_ns_device_result,
|
||||
e_ms_ns_host_result,
|
||||
"Error: Incorrect results!",
|
||||
1e-4,
|
||||
1e-4)
|
||||
? 0
|
||||
: 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
@@ -12,40 +12,6 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
// Picks the K dimension of the warp tile at compile time.
//   PrecType    - element type; only inspected when CK_GFX950_SUPPORT is defined.
//   M_Warp_Tile - M dimension of the warp tile; the value 32 selects the smaller K,
//                 every other value selects the larger K.
// Returns 16 (M_Warp_Tile == 32) or 32 by default. With CK_GFX950_SUPPORT defined,
// fp8_t / bf8_t instead get 64 (M_Warp_Tile == 32) or 128.
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
// Only the two 8-bit float types get the enlarged K; all other types keep 16 / 32.
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
// Without CK_GFX950_SUPPORT the result does not depend on PrecType.
if constexpr(M_Warp_Tile == 32)
|
||||
return 16;
|
||||
else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Picks the K dimension of the warp tile at compile time, keyed on sizeof(PrecType).
//   PrecType    - element type; only its size is inspected.
//   M_Warp_Tile - M dimension of the warp tile; the value 32 selects the smaller K,
//                 every other value selects the larger K.
// 2-byte types: 16 (M_Warp_Tile == 32) or 32, with or without CK_GFX950_SUPPORT.
// Any other size: 64 / 128 with CK_GFX950_SUPPORT, 32 / 64 without it.
// NOTE(review): the "other size" branch is taken for every non-2-byte type, not only
// 1-byte ones - confirm that wider types are never passed here.
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 64;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 128;
|
||||
#else
|
||||
// Same 2-byte results as above; only the non-2-byte K is halved.
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 32;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 64;
|
||||
#endif
|
||||
}
|
||||
|
||||
struct GemmConfigBase
|
||||
{
|
||||
static constexpr bool kPadM = false;
|
||||
@@ -122,7 +88,8 @@ struct GemmConfigComputeV3 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
@@ -141,7 +108,8 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
@@ -160,7 +128,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
@@ -204,7 +173,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
@@ -223,7 +193,8 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
@@ -242,7 +213,8 @@ struct GemmConfigComputeV5 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5;
|
||||
@@ -282,7 +254,8 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
@@ -306,7 +279,8 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
|
||||
@@ -63,25 +63,30 @@ struct UniversalInvoker
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation,
|
||||
GemmConfig::NumWaveGroups,
|
||||
false, /*FixedVectorSize_*/
|
||||
1, /*VectorSizeC_*/
|
||||
false, /*TiledMMAPermuteN_*/
|
||||
1, /*BlockedXDLN_PerWarp_*/
|
||||
GemmConfig::DoubleSmemBuffer /*DoubleSmemBuffer*/>>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
@@ -11,40 +11,6 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
// Picks the K dimension of the warp tile at compile time.
//   PrecType    - element type; only inspected when CK_GFX950_SUPPORT is defined.
//   M_Warp_Tile - M dimension of the warp tile; the value 32 selects the smaller K,
//                 every other value selects the larger K.
// Returns 16 (M_Warp_Tile == 32) or 32 by default. With CK_GFX950_SUPPORT defined,
// fp8_t / bf8_t instead get 64 (M_Warp_Tile == 32) or 128.
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
// Only the two 8-bit float types get the enlarged K; all other types keep 16 / 32.
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
// Without CK_GFX950_SUPPORT the result does not depend on PrecType.
if constexpr(M_Warp_Tile == 32)
|
||||
return 16;
|
||||
else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Picks the K dimension of the warp tile at compile time, keyed on sizeof(PrecType).
//   PrecType    - element type; only its size is inspected.
//   M_Warp_Tile - M dimension of the warp tile; the value 32 selects the smaller K,
//                 every other value selects the larger K.
// 2-byte types: 16 (M_Warp_Tile == 32) or 32, with or without CK_GFX950_SUPPORT.
// Any other size: 64 / 128 with CK_GFX950_SUPPORT, 32 / 64 without it.
// NOTE(review): the "other size" branch is taken for every non-2-byte type, not only
// 1-byte ones - confirm that wider types are never passed here.
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 64;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 128;
|
||||
#else
|
||||
// Same 2-byte results as above; only the non-2-byte K is halved.
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 32;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 64;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
struct GemmTypeConfig;
|
||||
|
||||
@@ -111,7 +77,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
@@ -134,7 +101,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
@@ -157,7 +125,8 @@ struct GemmConfigComputeV4_V2 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
@@ -178,7 +147,8 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr bool kPadK = true;
|
||||
|
||||
@@ -203,7 +173,8 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
|
||||
@@ -11,24 +11,6 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
// Picks the K dimension of the warp tile at compile time.
//   PrecType    - element type; only inspected when CK_GFX950_SUPPORT is defined.
//   M_Warp_Tile - M dimension of the warp tile; the value 32 selects the smaller K,
//                 every other value selects the larger K.
// Returns 16 (M_Warp_Tile == 32) or 32 by default. With CK_GFX950_SUPPORT defined,
// fp8_t / bf8_t instead get 64 (M_Warp_Tile == 32) or 128.
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
// Only the two 8-bit float types get the enlarged K; all other types keep 16 / 32.
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
// Without CK_GFX950_SUPPORT the result does not depend on PrecType.
if constexpr(M_Warp_Tile == 32)
|
||||
return 16;
|
||||
else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
|
||||
struct GemmConfigBase
|
||||
{
|
||||
static constexpr bool kPadM = false;
|
||||
|
||||
@@ -10,40 +10,6 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
// Picks the K dimension of the warp tile at compile time.
//   PrecType    - element type; only inspected when CK_GFX950_SUPPORT is defined.
//   M_Warp_Tile - M dimension of the warp tile; the value 32 selects the smaller K,
//                 every other value selects the larger K.
// Returns 16 (M_Warp_Tile == 32) or 32 by default. With CK_GFX950_SUPPORT defined,
// fp8_t / bf8_t instead get 64 (M_Warp_Tile == 32) or 128.
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
// Only the two 8-bit float types get the enlarged K; all other types keep 16 / 32.
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
// Without CK_GFX950_SUPPORT the result does not depend on PrecType.
if constexpr(M_Warp_Tile == 32)
|
||||
return 16;
|
||||
else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Picks the K dimension of the warp tile at compile time, keyed on sizeof(PrecType).
//   PrecType    - element type; only its size is inspected.
//   M_Warp_Tile - M dimension of the warp tile; the value 32 selects the smaller K,
//                 every other value selects the larger K.
// 2-byte types: 16 (M_Warp_Tile == 32) or 32, with or without CK_GFX950_SUPPORT.
// Any other size: 64 / 128 with CK_GFX950_SUPPORT, 32 / 64 without it.
// NOTE(review): the "other size" branch is taken for every non-2-byte type, not only
// 1-byte ones - confirm that wider types are never passed here.
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 64;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 128;
|
||||
#else
|
||||
// Same 2-byte results as above; only the non-2-byte K is halved.
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 32;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 64;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
struct GemmTypeConfig;
|
||||
|
||||
@@ -100,7 +66,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase<Persistent>
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
template <typename PrecType, bool Persistent>
|
||||
@@ -117,7 +84,7 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase<Persistent>
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr bool PreshuffleB = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
|
||||
@@ -24,39 +24,6 @@ inline size_t hash_multiple_strings(const std::vector<std::string>& inputs)
|
||||
return combined_hash;
|
||||
}
|
||||
|
||||
// Picks the K dimension of the warp tile at compile time.
//   PrecType    - element type; only inspected when CK_GFX950_SUPPORT is defined.
//   M_Warp_Tile - M dimension of the warp tile; the value 32 selects the smaller K,
//                 every other value selects the larger K.
// Returns 16 (M_Warp_Tile == 32) or 32 by default. With CK_GFX950_SUPPORT defined,
// fp8_t / bf8_t instead get 64 (M_Warp_Tile == 32) or 128.
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
// Only the two 8-bit float types get the enlarged K; all other types keep 16 / 32.
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
// Without CK_GFX950_SUPPORT the result does not depend on PrecType.
if constexpr(M_Warp_Tile == 32)
|
||||
return 16;
|
||||
else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
// Picks the K dimension of the warp tile at compile time, keyed on sizeof(PrecType).
//   PrecType    - element type; only its size is inspected.
//   M_Warp_Tile - M dimension of the warp tile; the value 32 selects the smaller K,
//                 every other value selects the larger K.
// 2-byte types: 16 (M_Warp_Tile == 32) or 32, with or without CK_GFX950_SUPPORT.
// Any other size: 64 / 128 with CK_GFX950_SUPPORT, 32 / 64 without it.
// NOTE(review): the "other size" branch is taken for every non-2-byte type, not only
// 1-byte ones - confirm that wider types are never passed here.
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 64;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 128;
|
||||
#else
|
||||
// Same 2-byte results as above; only the non-2-byte K is halved.
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 32;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 64;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
@@ -124,7 +91,8 @@ struct GemmConfigQuantDecode : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -140,7 +108,8 @@ struct GemmConfigRowColQuant : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -157,7 +126,7 @@ struct GemmConfigPreshuffleQuantDecode : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
};
|
||||
@@ -176,7 +145,7 @@ struct GemmConfigPreshuffleB_BQuant_Decode : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr bool PreshuffleB = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
@@ -206,7 +175,7 @@ struct GemmConfigPreshuffleB_BQuant_Prefill : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr bool PreshuffleB = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
@@ -236,7 +205,8 @@ struct GemmConfigQuantPrefill : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
|
||||
@@ -84,63 +84,46 @@ namespace ck_tile::builder::factory {
|
||||
|
||||
// CK Tile kernel
|
||||
template <typename T>
|
||||
consteval bool IsTileAlgorithm()
|
||||
{
|
||||
return ConvAlgorithmDescriptor<T> && SpecifiesTileThreadBlock<T> && SpecifiesTileTransfer<T> &&
|
||||
SpecifiesTileConvSpecialization<T> && SpecifiesTileBlockGemm<T> &&
|
||||
SpecifiesTileOptimizations<T>;
|
||||
}
|
||||
concept IsTileAlgorithm = ConvAlgorithmDescriptor<T> && SpecifiesTileThreadBlock<T> &&
|
||||
SpecifiesTileTransfer<T> && SpecifiesTileConvSpecialization<T> &&
|
||||
SpecifiesTileBlockGemm<T> && SpecifiesTileOptimizations<T>;
|
||||
|
||||
// XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline)
|
||||
template <typename T>
|
||||
consteval bool IsXdlV3Algorithm()
|
||||
{
|
||||
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
|
||||
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
|
||||
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesBlockGemm<T>;
|
||||
}
|
||||
concept IsXdlV3Algorithm =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
|
||||
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConvSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesBlockGemm<T>;
|
||||
|
||||
// Standard XDL-based kernel (uses XDLops hardware instructions for matrix multiply)
|
||||
template <typename T>
|
||||
consteval bool IsXdlAlgorithm()
|
||||
{
|
||||
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
|
||||
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
|
||||
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesNumPrefetchStages<T> && SpecifiesNumGroupsToMerge<T> &&
|
||||
SpecifiesLoopScheduler<T>;
|
||||
}
|
||||
concept IsXdlAlgorithm =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
|
||||
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConvSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> &&
|
||||
SpecifiesNumGroupsToMerge<T> && SpecifiesLoopScheduler<T>;
|
||||
|
||||
// WMMA-based kernel (uses Wavefront Matrix-Matrix Accumulate instructions)
|
||||
template <typename T>
|
||||
consteval bool IsWmmaAlgorithm()
|
||||
{
|
||||
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseWmmaGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
|
||||
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
|
||||
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T>;
|
||||
}
|
||||
concept IsWmmaAlgorithm =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseWmmaGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
|
||||
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConvSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T>;
|
||||
|
||||
// Specialized DL kernel for specific NHWC/KYXC/NHWK data layouts
|
||||
template <typename T>
|
||||
consteval bool IsDlAlgorithm()
|
||||
{
|
||||
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> &&
|
||||
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesDlThreadConfig<T> && SpecifiesDlThreadCluster<T> &&
|
||||
SpecifiesDlBlockTransfer<T> && SpecifiesDlEpilogue<T>;
|
||||
}
|
||||
concept IsDlAlgorithm =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesFwdConvSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesDlThreadConfig<T> && SpecifiesDlThreadCluster<T> &&
|
||||
SpecifiesDlBlockTransfer<T> && SpecifiesDlEpilogue<T>;
|
||||
|
||||
// XDL-based kernel with large tensor support
|
||||
template <typename T>
|
||||
consteval bool IsLargeTensorAlgorithm()
|
||||
{
|
||||
return IsXdlAlgorithm<decltype(T::base_algorithm)>() && SpecifiesLargeTensorSupport<T>;
|
||||
}
|
||||
concept IsLargeTensorAlgorithm =
|
||||
IsXdlAlgorithm<decltype(T::base_algorithm)> && SpecifiesLargeTensorSupport<T>;
|
||||
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
@@ -150,29 +133,29 @@ constexpr auto make_conv_instance()
|
||||
using AlgoType = std::remove_const_t<decltype(ALGORITHM)>;
|
||||
|
||||
// CK Tile supports common factory for each direction
|
||||
if constexpr(IsTileAlgorithm<AlgoType>())
|
||||
if constexpr(IsTileAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvTileFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(ConvDirectionIsForward<SIGNATURE>)
|
||||
{
|
||||
if constexpr(IsXdlV3Algorithm<AlgoType>())
|
||||
if constexpr(IsXdlV3Algorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvFwdXdlV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(IsXdlAlgorithm<AlgoType>())
|
||||
else if constexpr(IsXdlAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvFwdXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(IsWmmaAlgorithm<AlgoType>())
|
||||
else if constexpr(IsWmmaAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvFwdWmmaFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(IsDlAlgorithm<AlgoType>())
|
||||
else if constexpr(IsDlAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvFwdDlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(IsLargeTensorAlgorithm<AlgoType>())
|
||||
else if constexpr(IsLargeTensorAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvFwdLargeTensorFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
|
||||
@@ -31,13 +31,20 @@ constexpr size_t data_type_sizeof(DataType data_type)
|
||||
{
|
||||
case DataType::UNDEFINED_DATA_TYPE: return 0;
|
||||
case DataType::FP32: return 4;
|
||||
case DataType::FP32_FP32: return 8;
|
||||
case DataType::FP16: return 2;
|
||||
case DataType::FP16_FP16: return 4;
|
||||
case DataType::BF16: return 2;
|
||||
case DataType::BF16_BF16: return 4;
|
||||
case DataType::FP8: return 1;
|
||||
case DataType::BF8: return 1;
|
||||
case DataType::FP64: return 8;
|
||||
case DataType::INT32: return 4;
|
||||
case DataType::I8: return 1;
|
||||
case DataType::I8_I8: return 2;
|
||||
case DataType::U8: return 1;
|
||||
}
|
||||
return 0; // Default case to ensure all control paths return a value
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
@@ -72,7 +72,12 @@ inline bool is_xdl_supported()
|
||||
is_gfx12_supported() || is_gfx11_supported();
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, index_t MPerXDL, index_t NPerXDL>
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
index_t MPerXDL64,
|
||||
index_t NPerXDL64,
|
||||
index_t MPerXDL32 = MPerXDL64,
|
||||
index_t NPerXDL32 = NPerXDL64>
|
||||
inline bool is_xdl_wmma_supported()
|
||||
{
|
||||
if(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
||||
@@ -82,7 +87,7 @@ inline bool is_xdl_wmma_supported()
|
||||
}
|
||||
else if(is_gfx12_supported() || is_gfx11_supported())
|
||||
{
|
||||
if constexpr((MPerXDL != 16) || (NPerXDL != 16))
|
||||
if constexpr((MPerXDL32 != 16) || (NPerXDL32 != 16))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#endif
|
||||
#endif
|
||||
#include "ck/utility/get_id.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -96,6 +97,57 @@ static constexpr auto GetNXdlPerWave2()
|
||||
IsWave64>(); \
|
||||
}
|
||||
|
||||
// Computes, at compile time, the warp-tile configuration for one wave size and returns
// it as a Sequence of six values, in this order:
//   [0] MPerXDL, [1] NPerXDL, [2] MXdlPerWave, [3] NXdlPerWave,
//   [4] CShuffleMXdlPerWavePerShuffle, [5] CShuffleNXdlPerWavePerShuffle.
//
// Template parameters:
//   BlockSize_, MPerBlock_, NPerBlock_  - forwarded unchanged to GetNXdlPerWave2.
//   MPerXDL_, NPerXDL_                  - requested XDL tile shape.
//   MXdlPerWave_                        - requested number of XDL tiles per wave along M.
//   CShuffleMXdlPerWavePerShuffle_,
//   CShuffleNXdlPerWavePerShuffle_      - requested per-shuffle tile counts.
//   IsWave64                            - selects the wave-64 (true) or wave-32 (false) variant.
//
// IsWave64 == true: the inputs are returned as given, with NXdlPerWave taken from
// GetNXdlPerWave2. IsWave64 == false with a non-zero NXdlPerWave: the tile is forced to
// 16x16 and the M / shuffle counts are rescaled by (M|N)PerXDL_ / 16.
// IsWave64 == false with NXdlPerWave == 0: falls through to the unscaled branch, so the
// returned Sequence carries NXdlPerWave == 0.
// NOTE(review): presumably 0 marks "no valid wave-32 configuration" for callers - confirm.
template <index_t BlockSize_,
|
||||
index_t MPerBlock_,
|
||||
index_t NPerBlock_,
|
||||
index_t MPerXDL_,
|
||||
index_t NPerXDL_,
|
||||
index_t MXdlPerWave_,
|
||||
index_t CShuffleMXdlPerWavePerShuffle_,
|
||||
index_t CShuffleNXdlPerWavePerShuffle_,
|
||||
bool IsWave64>
|
||||
static constexpr auto GetWarpTileConfig()
|
||||
{
|
||||
constexpr auto MXdlPerWave64 = MXdlPerWave_;
|
||||
// Wave-32 counts are re-expressed in units of 16-wide tiles: count * MPerXDL_ / 16.
// NOTE(review): integer division - assumes count * MPerXDL_ is a multiple of 16; confirm.
constexpr auto MXdlPerWave32 = MXdlPerWave_ * MPerXDL_ / 16;
|
||||
constexpr auto CShuffleMXdlPerWavePerShuffle32 = CShuffleMXdlPerWavePerShuffle_ * MPerXDL_ / 16;
|
||||
|
||||
// Wave-64 queries GetNXdlPerWave2 with the requested tile shape; wave-32 queries it with
// a fixed 16x16 tile and the rescaled M count.
constexpr auto NXdlPerWave =
|
||||
IsWave64
|
||||
? GetNXdlPerWave2<BlockSize_,
|
||||
MPerBlock_,
|
||||
NPerBlock_,
|
||||
MPerXDL_,
|
||||
NPerXDL_,
|
||||
MXdlPerWave_,
|
||||
true>()
|
||||
: GetNXdlPerWave2<BlockSize_, MPerBlock_, NPerBlock_, 16, 16, MXdlPerWave32, false>();
|
||||
|
||||
if constexpr(IsWave64 == false && NXdlPerWave != 0)
|
||||
{
|
||||
// Rescale the N shuffle count only when NXdlPerWave is at least as large as the
// rescaled value; otherwise keep the caller's original count.
constexpr auto CShuffleNXdlPerWavePerShuffle32 =
|
||||
NXdlPerWave >= CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16
|
||||
? CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16
|
||||
: CShuffleNXdlPerWavePerShuffle_;
|
||||
static_assert(CShuffleNXdlPerWavePerShuffle32 > 0);
|
||||
return Sequence<16,
|
||||
16,
|
||||
MXdlPerWave32,
|
||||
NXdlPerWave,
|
||||
CShuffleMXdlPerWavePerShuffle32,
|
||||
CShuffleNXdlPerWavePerShuffle32>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
// Wave-64, or wave-32 with NXdlPerWave == 0: pass the requested values through.
return Sequence<MPerXDL_,
|
||||
NPerXDL_,
|
||||
MXdlPerWave64,
|
||||
NXdlPerWave,
|
||||
CShuffleMXdlPerWavePerShuffle_,
|
||||
CShuffleNXdlPerWavePerShuffle_>{};
|
||||
}
|
||||
}
|
||||
|
||||
#define INVOKER_RUN_IMPL \
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) \
|
||||
{ \
|
||||
|
||||
@@ -166,11 +166,27 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
{
|
||||
using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle;
|
||||
|
||||
GET_NXDL_PER_WAVE_IMPL
|
||||
static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
|
||||
static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
static constexpr auto WarpTileConfig64 = GetWarpTileConfig<BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
true>();
|
||||
static constexpr auto WarpTileConfig32 = GetWarpTileConfig<BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
false>();
|
||||
static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3);
|
||||
static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3);
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
@@ -321,7 +337,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
|
||||
|
||||
// GridwiseGemm
|
||||
template <index_t NXdlPerWave_>
|
||||
template <typename WarpTileConfig>
|
||||
using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
BDataType,
|
||||
@@ -340,10 +356,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave_,
|
||||
WarpTileConfig::At(0),
|
||||
WarpTileConfig::At(1),
|
||||
WarpTileConfig::At(2),
|
||||
WarpTileConfig::At(3),
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
@@ -360,13 +376,13 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
WarpTileConfig::At(4),
|
||||
WarpTileConfig::At(5),
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
|
||||
using GridwiseGemm64 = GridwiseGemmBase<decltype(WarpTileConfig64)>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<decltype(WarpTileConfig32)>;
|
||||
|
||||
// desc for blockwise copy
|
||||
using AGridDesc_AK0_M_AK1 =
|
||||
@@ -588,7 +604,12 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_wmma_supported<ComputeDataType, ComputeDataType, MPerXDL, NPerXDL>())
|
||||
if(!ck::is_xdl_wmma_supported<ComputeDataType,
|
||||
ComputeDataType,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
WarpTileConfig32.At(0),
|
||||
WarpTileConfig32.At(1)>())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -783,6 +804,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< MPerXDL << ", "
|
||||
<< NPerXDL << ", "
|
||||
<< AK1 << ", "
|
||||
<< BK1 << ", "
|
||||
<< ABlockTransferSrcVectorDim << ", "
|
||||
|
||||
@@ -620,7 +620,44 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
bool isWave64 = get_warp_size() == 64;
|
||||
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
|
||||
{
|
||||
const auto& a = arg.gemm_kernel_args_[i].karg_;
|
||||
const auto& a = arg.gemm_kernel_args_[i].karg_;
|
||||
|
||||
// Validate stride requirements for SplitK (k_batch > 1)
|
||||
// TODO: Enable splitK
|
||||
if(a.k_batch > 1)
|
||||
{
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if(a.StrideC != a.N)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "[" << __func__ << "] group id: " << i
|
||||
<< " SplitK (k_batch=" << a.k_batch
|
||||
<< ") requires contiguous output stride."
|
||||
<< " For RowMajor layout: StrideC must equal N."
|
||||
<< " Got StrideC=" << a.StrideC << ", N=" << a.N << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
if(a.StrideC != a.M)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "[" << __func__ << "] group id: " << i
|
||||
<< " SplitK (k_batch=" << a.k_batch
|
||||
<< ") requires contiguous output stride."
|
||||
<< " For ColumnMajor layout: StrideC must equal M."
|
||||
<< " Got StrideC=" << a.StrideC << ", M=" << a.M << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool group_arg_valid = false;
|
||||
if(isWave64)
|
||||
{
|
||||
|
||||
@@ -366,6 +366,26 @@ struct amdgcn_compiler_target_state
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1010 = false;
|
||||
#endif
|
||||
#if defined(__gfx1011__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1011 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1011 = false;
|
||||
#endif
|
||||
#if defined(__gfx1012__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1012 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1012 = false;
|
||||
#endif
|
||||
#if defined(__gfx1013__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1013 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1013 = false;
|
||||
#endif
|
||||
#if defined(__gfx10_1_generic__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX10_1_GENERIC = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX10_1_GENERIC = false;
|
||||
#endif // __gfx10_1_generic__
|
||||
|
||||
#if defined(__gfx1030__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1030 = true;
|
||||
@@ -504,6 +524,10 @@ CK_TILE_HOST_DEVICE static constexpr uint32_t count_values_of(T search, Ts... se
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX942, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX950, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1010, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1011, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1012, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1013, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX10_1_GENERIC, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1030, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1031, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1032, \
|
||||
|
||||
@@ -68,7 +68,7 @@ auto shuffle_bq(const ck_tile::HostTensor<T>* t, int block_bq_k)
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
@@ -78,10 +78,10 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
constexpr int divisor = 2;
|
||||
constexpr int kABK1PerLane = 8;
|
||||
constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
|
||||
gemmConfig.N_Warp_Tile,
|
||||
k_ / gemmConfig.K_Warp_Tile,
|
||||
kABK0PerLane,
|
||||
divisor,
|
||||
kABK1PerLane});
|
||||
@@ -98,18 +98,24 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
else
|
||||
{
|
||||
assert(is_wave32() == false);
|
||||
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
divisor = gemmConfig.N_Warp_Tile == 32 ? 2 : 4;
|
||||
}
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
|
||||
gemmConfig.N_Warp_Tile,
|
||||
k_ / gemmConfig.K_Warp_Tile,
|
||||
divisor,
|
||||
GemmConfig::K_Warp_Tile / divisor});
|
||||
gemmConfig.K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
return shuffle_b(t, GemmConfig{});
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
|
||||
{
|
||||
@@ -129,22 +135,22 @@ auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
|
||||
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
int NRepeat = gemmConfig.N_Tile / gemmConfig.N_Warp_Tile / gemmConfig.N_Warp;
|
||||
if(ck_tile::is_gfx12_supported())
|
||||
{
|
||||
constexpr int divisor = 2;
|
||||
constexpr int kABK1PerLane = 8;
|
||||
constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
|
||||
gemmConfig.N_Warp,
|
||||
gemmConfig.N_Warp_Tile,
|
||||
NRepeat,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
k_ / gemmConfig.K_Warp_Tile,
|
||||
kABK0PerLane,
|
||||
divisor,
|
||||
kABK1PerLane});
|
||||
@@ -161,17 +167,23 @@ auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
|
||||
else
|
||||
{
|
||||
assert(is_wave32() == false);
|
||||
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
divisor = gemmConfig.N_Warp_Tile == 32 ? 2 : 4;
|
||||
}
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
|
||||
gemmConfig.N_Warp,
|
||||
gemmConfig.N_Warp_Tile,
|
||||
NRepeat,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
k_ / gemmConfig.K_Warp_Tile,
|
||||
divisor,
|
||||
GemmConfig::K_Warp_Tile / divisor});
|
||||
gemmConfig.K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
return shuffle_b_permuteN(t, GemmConfig{});
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -35,7 +35,8 @@ template <typename AsDataType_,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeC_ = 1,
|
||||
bool TiledMMAPermuteN_ = false,
|
||||
index_t BlockedXDLN_PerWarp_ = 1> // The number of continuous xdl_output per warp
|
||||
index_t BlockedXDLN_PerWarp_ = 1, // The number of continuous xdl_output per warp
|
||||
bool DoubleSmemBuffer_ = false>
|
||||
struct CShuffleEpilogueProblem
|
||||
{
|
||||
using AsDataType = remove_cvref_t<AsDataType_>;
|
||||
@@ -59,6 +60,7 @@ struct CShuffleEpilogueProblem
|
||||
static constexpr bool FixedVectorSize = FixedVectorSize_;
|
||||
static constexpr index_t VectorSizeC = VectorSizeC_;
|
||||
static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
|
||||
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
|
||||
static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
|
||||
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
@@ -118,6 +120,7 @@ struct CShuffleEpilogue
|
||||
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
|
||||
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
|
||||
static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp;
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
|
||||
static constexpr index_t MPerIteration = MPerXdl * MWave;
|
||||
static constexpr index_t NPerIteration = NPerXdl * NWave;
|
||||
@@ -204,6 +207,26 @@ struct CShuffleEpilogue
|
||||
}
|
||||
return max_vector_size / sizeof(DiDataType);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Shuffle tile configuration parameters check and aligment
|
||||
*
|
||||
* @details Return tuple(1, 1) if shuffle_tile values are too large for SMEM.
|
||||
*/
|
||||
template <index_t m_shuffle_tile, index_t n_shuffle_tile>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto AlignShuffleTileWithSmem()
|
||||
{
|
||||
constexpr index_t m_val = MPerXdl * MWave * m_shuffle_tile;
|
||||
constexpr index_t n_val = NPerXdl * NWave * n_shuffle_tile;
|
||||
|
||||
constexpr auto shuffle_tile =
|
||||
m_val * n_val * sizeof(ODataType) > get_smem_capacity() || DoubleSmemBuffer
|
||||
? std::make_tuple(1, 1)
|
||||
: std::make_tuple(m_shuffle_tile, n_shuffle_tile);
|
||||
|
||||
return shuffle_tile;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Shuffle tile configuration parameters
|
||||
*
|
||||
@@ -214,20 +237,23 @@ struct CShuffleEpilogue
|
||||
*/
|
||||
static constexpr auto shuffle_tile_tuple = [] {
|
||||
constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size();
|
||||
if constexpr(elem_per_thread >= GetVectorSizeC())
|
||||
if constexpr(elem_per_thread <= GetVectorSizeC())
|
||||
{
|
||||
return std::make_tuple(1, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread;
|
||||
constexpr index_t num_xdl_shuffles = elem_per_thread / GetVectorSizeC();
|
||||
static_assert(elem_per_thread % GetVectorSizeC() == 0);
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
|
||||
(kMPerBlock % num_xdl_shuffles == 0),
|
||||
"kMPerBlock must be divisible by MPerXdl*MWave and "
|
||||
"num_xdl_shuffles for CShuffleEpilogue");
|
||||
return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1);
|
||||
return AlignShuffleTileWithSmem<min(num_xdl_shuffles,
|
||||
kMPerBlock / (MPerXdl * MWave)),
|
||||
1>();
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -235,7 +261,9 @@ struct CShuffleEpilogue
|
||||
(kNPerBlock % num_xdl_shuffles == 0),
|
||||
"kNPerBlock must be divisible by NPerXdl*NWave and "
|
||||
"num_xdl_shuffles for CShuffleEpilogue");
|
||||
return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (NPerXdl * NWave)));
|
||||
return AlignShuffleTileWithSmem<1,
|
||||
min(num_xdl_shuffles,
|
||||
kNPerBlock / (NPerXdl * NWave))>();
|
||||
}
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -232,7 +232,7 @@ struct BatchedGemmKernel
|
||||
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr1[GetSmemSize()];
|
||||
__shared__ char smem_ptr1[GemmPipeline::GetSmemSize()];
|
||||
UniversalGemmKernel::RunGemm2LDS({a_ptr},
|
||||
{b_ptr},
|
||||
{/*ds_ptr*/},
|
||||
|
||||
@@ -310,7 +310,7 @@ struct GroupedGemmKernel
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
RunGemmWithPipelineSelection2LDS(a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
|
||||
@@ -1084,7 +1084,7 @@ struct UniversalGemmKernel
|
||||
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
@@ -1169,7 +1169,7 @@ struct UniversalGemmKernel
|
||||
// Run the GEMM
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation ==
|
||||
memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
|
||||
@@ -43,4 +43,26 @@ struct TileGemmShape
|
||||
}
|
||||
};
|
||||
|
||||
template <typename PrecType, index_t M_Warp_Tile, bool IsFlatMM = false>
|
||||
constexpr index_t get_k_warp_tile()
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return 16;
|
||||
#else
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, fp8_t> || std::is_same_v<PrecType, bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return (sizeof(PrecType) == 2 || IsFlatMM == false) ? 16 : 32;
|
||||
else
|
||||
return (sizeof(PrecType) == 2 || IsFlatMM == false) ? 32 : 64;
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -61,6 +61,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
@@ -156,9 +157,11 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
using CDataType = remove_cvref_t<typename Traits::CDataType>;
|
||||
|
||||
// BDataType gets converted from PkInt4 during loading
|
||||
using OverrideBDataType =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
|
||||
using OverrideBDataType = std::conditional_t<
|
||||
std::is_same_v<BDataType, pk_int4_t> &&
|
||||
std::is_same_v<typename Traits::BLayout, tensor_layout::gemm::RowMajor>,
|
||||
ADataType,
|
||||
BDataType>;
|
||||
using Base = BlockGemmBQuantBase<Problem_>;
|
||||
|
||||
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
|
||||
|
||||
@@ -1324,7 +1324,7 @@ struct QuantGemmKernel
|
||||
assert(kargs.k_batch == 1);
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
|
||||
RunGemm2LDS(a_ptr,
|
||||
b_ptr,
|
||||
|
||||
@@ -325,7 +325,7 @@ struct QuantGroupedGemmKernel
|
||||
kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
RunGemmWithPipelineSelection2LDS(a_ptr,
|
||||
b_ptr,
|
||||
aq_ptr,
|
||||
|
||||
@@ -33,9 +33,17 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
// BDataType gets converted from PkInt4 during loading
|
||||
using OverrideBDataType =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t> &&
|
||||
std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>,
|
||||
ADataType,
|
||||
BDataType>;
|
||||
|
||||
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
|
||||
using I0 = number<0>;
|
||||
@@ -50,11 +58,6 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
static constexpr index_t BQPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BQDataType>>::PackedSize;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
@@ -184,6 +187,23 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
load_int4_tile<SrcDataType, DestDataType, UnaryOpSize>(b_block_tile, b_dram_window);
|
||||
}
|
||||
|
||||
template <typename BBlockTile_, typename BDramWindow, typename BDramTileWindowStep>
|
||||
CK_TILE_DEVICE void
|
||||
BGlobalPrefetch(BBlockTile_& b_block_tile,
|
||||
BDramWindow& b_copy_dram_window,
|
||||
const BDramTileWindowStep& b_dram_tile_window_step) const
|
||||
{
|
||||
if constexpr(!std::is_same_v<BDataType, OverrideBDataType>)
|
||||
{
|
||||
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename ADramBlockWindowTmp,
|
||||
@@ -262,7 +282,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
using ABlockTile =
|
||||
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
using BBlockTile =
|
||||
decltype(make_static_distributed_tensor<ADataType>(BBlockTileDistr{}));
|
||||
decltype(make_static_distributed_tensor<OverrideBDataType>(BBlockTileDistr{}));
|
||||
using BQBlockTile =
|
||||
decltype(make_static_distributed_tensor<BQDataType>(BQBlockTileDistr{}));
|
||||
|
||||
@@ -289,8 +309,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
// DRAM prefetch (global read 0)
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
// B tile gets converted to A datatype during loading
|
||||
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
BGlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(
|
||||
bq_block_tile[currIdx], bq_copy_dram_window, bq_dram_tile_window_step);
|
||||
|
||||
@@ -311,7 +330,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
// B datatype is converted to A datatype during loading
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<OverrideBDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
@@ -322,8 +341,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
}
|
||||
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
// B tile gets converted to A datatype during loading
|
||||
BGlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
@@ -366,8 +385,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
}
|
||||
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
// B tile gets converted to A datatype during loading
|
||||
BGlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2],
|
||||
bq_copy_dram_window,
|
||||
bq_dram_tile_window_step);
|
||||
|
||||
@@ -1048,7 +1048,7 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
|
||||
@@ -1005,7 +1005,7 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation ==
|
||||
memory_operation_enum::atomic_add &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
|
||||
@@ -1184,7 +1184,7 @@ struct GroupedConvolutionForwardKernel
|
||||
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation ==
|
||||
memory_operation_enum::atomic_add &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
|
||||
@@ -202,7 +202,13 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
memory_operation,
|
||||
1, /*kNumWaveGroups_*/
|
||||
false, /*FixedVectorSize_*/
|
||||
1, /*VectorSizeC_*/
|
||||
false, /*TiledMMAPermuteN_*/
|
||||
1, /*BlockedXDLN_PerWarp_*/
|
||||
DoubleSmemBuffer /*DoubleSmemBuffer*/>>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
@@ -11,25 +11,93 @@ list(APPEND TEST_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
# Typed Test Suite for GEMM Quantization - split into multiple files to reduce compile time
|
||||
|
||||
# AQuant tests
|
||||
add_gtest_executable(test_tile_gemm_quant_aquant
|
||||
test_gemm_quant_aquant.cpp
|
||||
# AQuant tests - split into 6 files
|
||||
add_gtest_executable(test_tile_gemm_quant_aquant_base_rcr
|
||||
test_gemm_quant_aquant_base_rcr.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_aquant PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_tile_gemm_quant_aquant_base_rcr PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# BQuant tests (without PreshuffleB)
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant
|
||||
test_gemm_quant_bquant.cpp
|
||||
add_gtest_executable(test_tile_gemm_quant_aquant_base_rrr_crr
|
||||
test_gemm_quant_aquant_base_rrr_crr.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_tile_gemm_quant_aquant_base_rrr_crr PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# BQuant tests (with PreshuffleB)
|
||||
# disabling this test until it can be built within reasonable time!
|
||||
# currently taking ~50 minutes on gfx12!
|
||||
#add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle
|
||||
# test_gemm_quant_bquant_preshuffle.cpp
|
||||
#)
|
||||
#target_compile_options(test_tile_gemm_quant_bquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
add_gtest_executable(test_tile_gemm_quant_aquant_base_ccr
|
||||
test_gemm_quant_aquant_base_ccr.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_aquant_base_ccr PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_aquant_prefill
|
||||
test_gemm_quant_aquant_prefill.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_aquant_prefill PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_aquant_transpose_c
|
||||
test_gemm_quant_aquant_transpose_c.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_aquant_transpose_c PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_aquant_preshuffle
|
||||
test_gemm_quant_aquant_preshuffle.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_aquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# BQuant tests (without PreshuffleB) - split into 6 files
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant_1d_128
|
||||
test_gemm_quant_bquant_1d_128.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_1d_128 PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant_1d_64
|
||||
test_gemm_quant_bquant_1d_64.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_1d_64 PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant_2d_small_n
|
||||
test_gemm_quant_bquant_2d_small_n.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_2d_small_n PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant_2d_medium_n
|
||||
test_gemm_quant_bquant_2d_medium_n.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_2d_medium_n PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant_2d_large_n
|
||||
test_gemm_quant_bquant_2d_large_n.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_2d_large_n PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant_transpose
|
||||
test_gemm_quant_bquant_transpose.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_transpose PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# BQuant tests (with PreshuffleB) - split into 5 files
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle_decode_1d
|
||||
test_gemm_quant_bquant_preshuffle_decode_1d.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_preshuffle_decode_1d PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle_prefill_1d
|
||||
test_gemm_quant_bquant_preshuffle_prefill_1d.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_preshuffle_prefill_1d PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle_tiled_permute
|
||||
test_gemm_quant_bquant_preshuffle_tiled_permute.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_preshuffle_tiled_permute PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle_decode_2d
|
||||
test_gemm_quant_bquant_preshuffle_decode_2d.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_preshuffle_decode_2d PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle_prefill_2d
|
||||
test_gemm_quant_bquant_preshuffle_prefill_2d.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_preshuffle_prefill_2d PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# RowColQuant tests
|
||||
add_gtest_executable(test_tile_gemm_quant_rowcol
|
||||
@@ -42,6 +110,34 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
test_gemm_quant_tensor.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_tensor PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# Umbrella target to build all gemm quant tests.
# The target has no commands of its own; building it builds every test
# executable attached below as a dependency.
add_custom_target(test_tile_gemm_quant_all)
foreach(quant_test_target
        # AQuant tests
        test_tile_gemm_quant_aquant_base_rcr
        test_tile_gemm_quant_aquant_base_rrr_crr
        test_tile_gemm_quant_aquant_base_ccr
        test_tile_gemm_quant_aquant_prefill
        test_tile_gemm_quant_aquant_transpose_c
        test_tile_gemm_quant_aquant_preshuffle
        # BQuant tests
        test_tile_gemm_quant_bquant_1d_128
        test_tile_gemm_quant_bquant_1d_64
        test_tile_gemm_quant_bquant_2d_small_n
        test_tile_gemm_quant_bquant_2d_medium_n
        test_tile_gemm_quant_bquant_2d_large_n
        test_tile_gemm_quant_bquant_transpose
        # BQuant preshuffle tests
        test_tile_gemm_quant_bquant_preshuffle_decode_1d
        test_tile_gemm_quant_bquant_preshuffle_prefill_1d
        test_tile_gemm_quant_bquant_preshuffle_tiled_permute
        test_tile_gemm_quant_bquant_preshuffle_decode_2d
        test_tile_gemm_quant_bquant_preshuffle_prefill_2d
        # Other quant tests
        test_tile_gemm_quant_rowcol
        test_tile_gemm_quant_tensor)
    add_dependencies(test_tile_gemm_quant_all ${quant_test_target})
endforeach()
|
||||
else()
|
||||
message(DEBUG "Skipping ck_tile quant gemm tests for current target")
|
||||
endif()
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using AQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::AQuantGrouped>;
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
using RowColQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::RowColQuant>;
|
||||
using TensorQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::TensorQuant>;
|
||||
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using GroupSize64 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
|
||||
|
||||
// 2d block sizes for BQuant
|
||||
using GroupSize2D8N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
using GroupSize2D16N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
using GroupSize2D32N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
using GroupSize2D64N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
|
||||
// Type combinations for AQuant tests
|
||||
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using AQuantTypes = ::testing::Types<
|
||||
// PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ)
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
|
||||
// RRR layout (RowMajor A, RowMajor B, RowMajor C with RowMajor AQ)
|
||||
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
|
||||
// CRR layout (ColumnMajor A, RowMajor B, RowMajor C with RowMajor AQ)
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
|
||||
// CCR layout (ColumnMajor A, ColumnMajor B, RowMajor C with ColumnMajor AQ) - NEW layout support
|
||||
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
|
||||
// RCR layout - with the Prefill BlockTile Config.
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
|
||||
|
||||
// PreshuffleQuant = false && TransposeC = true (with RowMajor AQ)
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
|
||||
|
||||
// PreshuffleQuant = true && TransposeC = false (with RowMajor AQ - PreshuffleQuant only supports RowMajor)
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
|
||||
|
||||
// PreshuffleQuant = true && TransposeC = true (with RowMajor AQ - PreshuffleQuant only supports RowMajor)
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for AQuant
|
||||
TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantTypes);
|
||||
|
||||
// AQuant tests
|
||||
TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using AQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::AQuantGrouped>;
|
||||
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
// Type combinations for AQuant tests - CCR layout
|
||||
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using AQuantBaseCCRTypes = ::testing::Types<
|
||||
// CCR layout (ColumnMajor A, ColumnMajor B, RowMajor C with ColumnMajor AQ) - NEW layout support
|
||||
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for AQuant Base CCR
|
||||
TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantBaseCCRTypes);
|
||||
|
||||
// AQuant tests
|
||||
TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using AQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::AQuantGrouped>;
|
||||
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
// Type combinations for AQuant tests - RCR layout base configuration
|
||||
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using AQuantBaseRCRTypes = ::testing::Types<
|
||||
// PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ)
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for AQuant Base RCR
|
||||
TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantBaseRCRTypes);
|
||||
|
||||
// AQuant tests
|
||||
TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using AQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::AQuantGrouped>;
|
||||
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
// Type combinations for AQuant tests - RRR and CRR layouts
|
||||
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using AQuantBaseRRRCRRTypes = ::testing::Types<
|
||||
// RRR layout (RowMajor A, RowMajor B, RowMajor C with RowMajor AQ)
|
||||
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
|
||||
// CRR layout (ColumnMajor A, RowMajor B, RowMajor C with RowMajor AQ)
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for AQuant Base RRR/CRR
|
||||
TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantBaseRRRCRRTypes);
|
||||
|
||||
// AQuant tests
|
||||
TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using AQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::AQuantGrouped>;
|
||||
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
// Type combinations for AQuant tests - Prefill Configuration
|
||||
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using AQuantPrefillTypes = ::testing::Types<
|
||||
// RCR layout - with the Prefill BlockTile Config.
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for AQuant Prefill
|
||||
TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantPrefillTypes);
|
||||
|
||||
// AQuant tests
|
||||
TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using AQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::AQuantGrouped>;
|
||||
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
// Type combinations for AQuant tests - PreshuffleQuant Configurations
|
||||
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using AQuantPreshuffleTypes = ::testing::Types<
|
||||
// PreshuffleQuant = true && TransposeC = false (with RowMajor AQ - PreshuffleQuant only supports RowMajor)
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
|
||||
|
||||
// PreshuffleQuant = true && TransposeC = true (with RowMajor AQ - PreshuffleQuant only supports RowMajor)
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for AQuant Preshuffle
|
||||
TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantPreshuffleTypes);
|
||||
|
||||
// AQuant tests
|
||||
TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using AQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::AQuantGrouped>;
|
||||
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
// Type combinations for AQuant tests - TransposeC Configuration
|
||||
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using AQuantTransposeCTypes = ::testing::Types<
|
||||
// PreshuffleQuant = false && TransposeC = true (with RowMajor AQ)
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for AQuant TransposeC
|
||||
TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantTransposeCTypes);
|
||||
|
||||
// AQuant tests
|
||||
TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -1,99 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using UInt8 = ck_tile::pk_fp4_raw_t;
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using GroupSize64 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
|
||||
using GroupSize32 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 32>>;
|
||||
|
||||
// 2d block sizes for BQuant
|
||||
using GroupSize2D8N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
using GroupSize2D16N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
using GroupSize2D32N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
using GroupSize2D64N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
|
||||
// Type combinations for BQuant tests (without PreshuffleB)
|
||||
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BQuantTypes = ::testing::Types<
|
||||
// 1d cases with grouping only on k axis
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, UInt8, UInt8, BF16, BQuantGrouped, GemmConfigMxFp4, GroupSize64>,
|
||||
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, UInt8, UInt8, BF16, BQuantGrouped, GemmConfigMxFp4, GroupSize32>,
|
||||
|
||||
// 2d cases with grouping also on the n axis
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>,
|
||||
|
||||
// some cases with transpose layouts
|
||||
std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
|
||||
// pkint4 + transpose cases
|
||||
std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for BQuant (without PreshuffleB)
|
||||
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantTypes);
|
||||
|
||||
// BQuant tests
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
// Type combinations for BQuant tests - 1D GroupSize 128
|
||||
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BQuant1D128Types = ::testing::Types<
|
||||
// 1d cases with grouping only on k axis
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for BQuant 1D 128
|
||||
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D128Types);
|
||||
|
||||
// BQuant tests
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability: short names for the ck_tile layouts, element
// types and quantization tags that appear as std::tuple elements further down.
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
// QuantType::BQuantGrouped wrapped in std::integral_constant so the enum value
// can be passed as a type (it is one element of each test tuple).
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
// Quantization group shape <1, 1, 64>. Presumably the entries are <M, N, K>
// group extents, i.e. groups of 64 along K only - confirm against QuantGroupShape.
using GroupSize64 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
|
||||
|
||||
// Type combinations for BQuant tests - 1D GroupSize 64
// Every entry uses RowMajor A/C, ColumnMajor B/BQ, Half output, GemmConfigBase
// and GroupSize64; the four entries differ only in their A/B/Q element types.
|
||||
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BQuant1D64Types = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for BQuant 1D 64
// Instantiates the TestCkTileGemmBQuant fixture once per tuple in BQuant1D64Types.
|
||||
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D64Types);
|
||||
|
||||
// BQuant tests
// Calls the fixture's run_test_with_validation(1024, 1024, 1024) for each
// instantiated type. The three arguments are presumably the M, N, K problem
// sizes - confirm against the fixture declaration.
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
|
||||
{
|
||||
// `this->` is required: inside a typed test the fixture is a dependent base class.
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability: short names for the ck_tile layouts, element
// types and quantization tags that appear as std::tuple elements further down.
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
// QuantType::BQuantGrouped wrapped in std::integral_constant so the enum value
// can be passed as a type (it is one element of each test tuple).
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
// Quantization group shape <1, 128, 128>. Presumably <M, N, K> group extents,
// i.e. 128 x 128 groups over N and K - confirm against QuantGroupShape.
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
|
||||
// Type combinations for BQuant tests - 2D Large N (128N)
// Every entry uses RowMajor A/C, ColumnMajor B/BQ, Half output, GemmConfigBase
// and GroupSize2D128N; the four entries differ only in their A/B/Q element types.
|
||||
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BQuant2DLargeNTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for BQuant 2D Large N
// Instantiates the TestCkTileGemmBQuant fixture once per tuple in BQuant2DLargeNTypes.
|
||||
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant2DLargeNTypes);
|
||||
|
||||
// BQuant tests
// Calls the fixture's run_test_with_validation(1024, 1024, 1024) for each
// instantiated type. The three arguments are presumably the M, N, K problem
// sizes - confirm against the fixture declaration.
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
|
||||
{
|
||||
// `this->` is required: inside a typed test the fixture is a dependent base class.
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability: short names for the ck_tile layouts, element
// types and quantization tags that appear as std::tuple elements further down.
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
// QuantType::BQuantGrouped wrapped in std::integral_constant so the enum value
// can be passed as a type (it is one element of each test tuple).
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
|
||||
// 2d block sizes for BQuant: group shapes <1, 32, 128> and <1, 64, 128>.
// Presumably <M, N, K> group extents, so the middle entry is the N group size
// named in each alias - confirm against QuantGroupShape.
|
||||
using GroupSize2D32N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
using GroupSize2D64N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
|
||||
// Type combinations for BQuant tests - 2D Medium N (32N and 64N)
// Four A/B/Q element-type combinations for GroupSize2D32N followed by the same
// four for GroupSize2D64N; layouts, Half output and GemmConfigBase are constant.
|
||||
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BQuant2DMediumNTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for BQuant 2D Medium N
// Instantiates the TestCkTileGemmBQuant fixture once per tuple in BQuant2DMediumNTypes.
|
||||
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant2DMediumNTypes);
|
||||
|
||||
// BQuant tests
// Calls the fixture's run_test_with_validation(1024, 1024, 1024) for each
// instantiated type. The three arguments are presumably the M, N, K problem
// sizes - confirm against the fixture declaration.
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
|
||||
{
|
||||
// `this->` is required: inside a typed test the fixture is a dependent base class.
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability: short names for the ck_tile layouts, element
// types and quantization tags that appear as std::tuple elements further down.
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
// QuantType::BQuantGrouped wrapped in std::integral_constant so the enum value
// can be passed as a type (it is one element of each test tuple).
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
|
||||
// 2d block sizes for BQuant: group shapes <1, 8, 128> and <1, 16, 128>. The
// type list below describes these as grouping also on the n axis.
|
||||
using GroupSize2D8N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
using GroupSize2D16N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
|
||||
// Type combinations for BQuant tests - 2D Small N (8N and 16N)
// Four A/B/Q element-type combinations for GroupSize2D8N followed by the same
// four for GroupSize2D16N; layouts, Half output and GemmConfigBase are constant.
|
||||
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BQuant2DSmallNTypes = ::testing::Types<
|
||||
// 2d cases with grouping also on the n axis
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for BQuant 2D Small N
// Instantiates the TestCkTileGemmBQuant fixture once per tuple in BQuant2DSmallNTypes.
|
||||
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant2DSmallNTypes);
|
||||
|
||||
// BQuant tests
// Calls the fixture's run_test_with_validation(1024, 1024, 1024) for each
// instantiated type. The three arguments are presumably the M, N, K problem
// sizes - confirm against the fixture declaration.
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
|
||||
{
|
||||
// `this->` is required: inside a typed test the fixture is a dependent base class.
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -1,93 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability: short names for the ck_tile layouts, element
// types and quantization tags that appear as std::tuple elements further down.
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
// QuantType::BQuantGrouped wrapped in std::integral_constant so the enum value
// can be passed as a type (it is one element of each test tuple).
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
// Quantization group shape <1, 1, 128>. Presumably <M, N, K> group extents,
// i.e. groups of 128 along K only - confirm against QuantGroupShape.
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
// 2d block sizes for BQuant: group shapes <1, N, 128> with N = 8, 16, 32, 64
// (the N in each alias name matches the middle sequence entry).
|
||||
using GroupSize2D8N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
using GroupSize2D16N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
using GroupSize2D32N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
using GroupSize2D64N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
|
||||
// Type combinations for BQuant tests with PreshuffleB
// Layout of the list: the 1D GroupSize entries for the Decode, Prefill and
// PrefillTiledPermuteN configs come first, then the 2D group shapes (8N, 16N,
// 32N, 64N) for Decode and for Prefill. Each group lists the same four A/B/Q
// element-type combinations.
|
||||
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BPreshuffleBQuantTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
|
||||
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
|
||||
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
|
||||
|
||||
// 2d cases with preshuffle B
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D64N>,
|
||||
|
||||
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D64N>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for BQuant with PreshuffleB
// Instantiates the TestCkTileGemmPreshuffleBBQuant fixture once per tuple in
// BPreshuffleBQuantTypes.
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshuffleBQuantTypes);
|
||||
|
||||
// BQuant PreshuffleB tests
// Calls the fixture's run_test_with_validation(1024, 1024, 1024) for each
// instantiated type. The three arguments are presumably the M, N, K problem
// sizes - confirm against the fixture declaration.
|
||||
TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest)
|
||||
{
|
||||
// `this->` is required: inside a typed test the fixture is a dependent base class.
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability: short names for the ck_tile layouts, element
// types and quantization tags that appear as std::tuple elements further down.
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
// QuantType::BQuantGrouped wrapped in std::integral_constant so the enum value
// can be passed as a type (it is one element of each test tuple).
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
// Quantization group shape <1, 1, 128>. Presumably <M, N, K> group extents,
// i.e. groups of 128 along K only - confirm against QuantGroupShape.
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
// Type combinations for BQuant Preshuffle tests - Decode Config 1D
// Both entries use GemmConfigPreshuffleBDecode with GroupSize and differ only in
// the B/Q element types (FP8/float vs PkInt4/FP8).
// NOTE(review): only FP8 A-type entries are listed; BF8 is aliased in this file
// but unused here - confirm the missing BF8 coverage is intentional.
|
||||
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BPreshuffleDecode1DTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for BQuant Preshuffle Decode 1D
// Instantiates the TestCkTileGemmPreshuffleBBQuant fixture once per tuple in
// BPreshuffleDecode1DTypes.
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshuffleDecode1DTypes);
|
||||
|
||||
// BQuant PreshuffleB tests
// Calls the fixture's run_test_with_validation(1024, 1024, 1024) for each
// instantiated type. The three arguments are presumably the M, N, K problem
// sizes - confirm against the fixture declaration.
|
||||
TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest)
|
||||
{
|
||||
// `this->` is required: inside a typed test the fixture is a dependent base class.
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability: short names for the ck_tile layouts, element
// types and quantization tags that appear as std::tuple elements further down.
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
// QuantType::BQuantGrouped wrapped in std::integral_constant so the enum value
// can be passed as a type (it is one element of each test tuple).
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
|
||||
// 2d block sizes for BQuant: group shapes <1, N, 128> with N = 8, 16, 32, 64
// (the N in each alias name matches the middle sequence entry).
|
||||
using GroupSize2D8N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
using GroupSize2D16N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
using GroupSize2D32N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
using GroupSize2D64N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
|
||||
// Type combinations for BQuant Preshuffle tests - Decode 2D
// Two entries (B/Q = FP8/float and PkInt4/FP8) for each of the four 2D group
// shapes, all with GemmConfigPreshuffleBDecode.
// NOTE(review): only FP8 A-type entries are listed; BF8 is aliased in this file
// but unused here - confirm the missing BF8 coverage is intentional.
|
||||
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BPreshuffleDecode2DTypes = ::testing::Types<
|
||||
// 2d cases with preshuffle B
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D64N>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for BQuant Preshuffle Decode 2D
// Instantiates the TestCkTileGemmPreshuffleBBQuant fixture once per tuple in
// BPreshuffleDecode2DTypes.
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshuffleDecode2DTypes);
|
||||
|
||||
// BQuant PreshuffleB tests
// Calls the fixture's run_test_with_validation(1024, 1024, 1024) for each
// instantiated type. The three arguments are presumably the M, N, K problem
// sizes - confirm against the fixture declaration.
|
||||
TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest)
|
||||
{
|
||||
// `this->` is required: inside a typed test the fixture is a dependent base class.
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability: short names for the ck_tile layouts, element
// types and quantization tags that appear as std::tuple elements further down.
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
// QuantType::BQuantGrouped wrapped in std::integral_constant so the enum value
// can be passed as a type (it is one element of each test tuple).
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
// Quantization group shape <1, 1, 128>. Presumably <M, N, K> group extents,
// i.e. groups of 128 along K only - confirm against QuantGroupShape.
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
// Type combinations for BQuant Preshuffle tests - Prefill Config 1D
// Every entry uses GemmConfigPreshuffleBPrefill with GroupSize; the four entries
// differ only in their A/B/Q element types.
|
||||
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BPreshufflePrefill1DTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for BQuant Preshuffle Prefill 1D
// Instantiates the TestCkTileGemmPreshuffleBBQuant fixture once per tuple in
// BPreshufflePrefill1DTypes.
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshufflePrefill1DTypes);
|
||||
|
||||
// BQuant PreshuffleB tests
// Calls the fixture's run_test_with_validation(1024, 1024, 1024) for each
// instantiated type. The three arguments are presumably the M, N, K problem
// sizes - confirm against the fixture declaration.
|
||||
TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest)
|
||||
{
|
||||
// `this->` is required: inside a typed test the fixture is a dependent base class.
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability: short names for the ck_tile layouts, element
// types and quantization tags that appear as std::tuple elements further down.
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
// QuantType::BQuantGrouped wrapped in std::integral_constant so the enum value
// can be passed as a type (it is one element of each test tuple).
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
|
||||
// 2d block sizes for BQuant: group shapes <1, N, 128> with N = 8, 16, 32, 64
// (the N in each alias name matches the middle sequence entry).
|
||||
using GroupSize2D8N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
using GroupSize2D16N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
using GroupSize2D32N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
using GroupSize2D64N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
|
||||
// Type combinations for BQuant Preshuffle tests - Prefill 2D
// Four A/B/Q element-type combinations for each of the four 2D group shapes
// (8N, 16N, 32N, 64N), all with GemmConfigPreshuffleBPrefill.
|
||||
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BPreshufflePrefill2DTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D64N>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for BQuant Preshuffle Prefill 2D
// Instantiates the TestCkTileGemmPreshuffleBBQuant fixture once per tuple in
// BPreshufflePrefill2DTypes.
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshufflePrefill2DTypes);
|
||||
|
||||
// BQuant PreshuffleB tests
// Calls the fixture's run_test_with_validation(1024, 1024, 1024) for each
// instantiated type. The three arguments are presumably the M, N, K problem
// sizes - confirm against the fixture declaration.
|
||||
TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest)
|
||||
{
|
||||
// `this->` is required: inside a typed test the fixture is a dependent base class.
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability: short names for the ck_tile layouts, element
// types and quantization tags that appear as std::tuple elements further down.
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
// QuantType::BQuantGrouped wrapped in std::integral_constant so the enum value
// can be passed as a type (it is one element of each test tuple).
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
// Quantization group shape <1, 1, 128>. Presumably <M, N, K> group extents,
// i.e. groups of 128 along K only - confirm against QuantGroupShape.
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
// Type combinations for BQuant Preshuffle tests - TiledPermuteN Config
// Every entry uses GemmConfigPreshuffleBPrefillTiledPermuteN with GroupSize;
// only the A/B/Q element types vary between entries.
// NOTE(review): there is no <BF8, BF8, float> entry in this list - confirm the
// omission is intentional.
|
||||
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BPreshuffleTiledPermuteTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for BQuant Preshuffle TiledPermuteN
// Instantiates the TestCkTileGemmPreshuffleBBQuant fixture once per tuple in
// BPreshuffleTiledPermuteTypes.
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshuffleTiledPermuteTypes);
|
||||
|
||||
// BQuant PreshuffleB tests
// Calls the fixture's run_test_with_validation(1024, 1024, 1024) for each
// instantiated type. The three arguments are presumably the M, N, K problem
// sizes - confirm against the fixture declaration.
|
||||
TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest)
|
||||
{
|
||||
// `this->` is required: inside a typed test the fixture is a dependent base class.
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability: short names for the ck_tile layouts, element
// types and quantization tags that appear as std::tuple elements further down.
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
// QuantType::BQuantGrouped wrapped in std::integral_constant so the enum value
// can be passed as a type (it is one element of each test tuple).
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
// One 1D group shape <1, 1, 64> and one 2D group shape <1, 64, 128>.
// Presumably <M, N, K> group extents - confirm against QuantGroupShape.
using GroupSize64 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
|
||||
using GroupSize2D64N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
|
||||
// Type combinations for BQuant tests - Transpose Layouts
|
||||
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BQuantTransposeTypes = ::testing::Types<
|
||||
// some cases with transpose layouts
|
||||
std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
|
||||
// pkint4 + transpose cases
|
||||
std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for BQuant Transpose
|
||||
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantTransposeTypes);
|
||||
|
||||
// BQuant tests
|
||||
// Instantiated once per tuple in BQuantTransposeTypes (registered through
// TYPED_TEST_SUITE above). Each instantiation makes a single
// run_test_with_validation call with 1024 for all three arguments (presumably the
// M, N and K problem sizes - TODO confirm against the fixture).
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -11,26 +11,6 @@
|
||||
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return 16;
|
||||
#else
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 64;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 128;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 32;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 64;
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
{
|
||||
@@ -67,7 +47,8 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
|
||||
static const ck_tile::index_t M_Warp_Tile = 16;
|
||||
static const ck_tile::index_t N_Warp_Tile = 16;
|
||||
static const ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<BDataType, M_Warp_Tile>();
|
||||
static const ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<BDataType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true; // preshuffle v2 uses ping-pong smem
|
||||
static constexpr bool TransposeC = false; // transpose c is not supported
|
||||
@@ -101,46 +82,6 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<>);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
|
||||
if(ck_tile::is_gfx12_supported())
|
||||
{
|
||||
constexpr int divisor = 2;
|
||||
constexpr int kABK1PerLane = 8;
|
||||
constexpr int kABK0PerLane = K_Warp_Tile / divisor / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / N_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
k_ / K_Warp_Tile,
|
||||
kABK0PerLane,
|
||||
divisor,
|
||||
kABK1PerLane});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
|
||||
}
|
||||
else
|
||||
{
|
||||
int divisor = 1;
|
||||
if(ck_tile::is_gfx11_supported())
|
||||
{
|
||||
divisor = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(is_wave32() == false);
|
||||
divisor = N_Warp_Tile == 32 ? 2 : 4;
|
||||
}
|
||||
ck_tile::HostTensor<T> t_view(
|
||||
{n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
void invoke_grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
@@ -340,6 +281,14 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
}
|
||||
}
|
||||
|
||||
// Small config type that re-exports this fixture's N_Warp_Tile and K_Warp_Tile as
// its own static members. It is passed as the template argument to
// ck_tile::shuffle_b when B is preshuffled on the host in Run().
struct BShuffleGemmConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t N_Warp_Tile =
|
||||
TestCkTileGroupedGemmPreshuffle::N_Warp_Tile;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
TestCkTileGroupedGemmPreshuffle::K_Warp_Tile;
|
||||
};
|
||||
|
||||
public:
|
||||
void Run(const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
@@ -424,7 +373,7 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
|
||||
|
||||
// Host-side preshuffle of B
|
||||
auto b_shuffle_host = shuffle_b(b_k_n_tensors[i]);
|
||||
auto b_shuffle_host = ck_tile::shuffle_b<BShuffleGemmConfig>(b_k_n_tensors[i]);
|
||||
|
||||
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
a_m_k_tensors[i].get_element_space_size_in_bytes()));
|
||||
|
||||
@@ -6,18 +6,18 @@ if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
# Split into three separate test executables for faster parallel compilation
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_quant_rowcol test_grouped_gemm_quant_rowcol.cpp)
|
||||
target_compile_options(test_ck_tile_grouped_gemm_quant_rowcol PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
# if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
# # Split into three separate test executables for faster parallel compilation
|
||||
# add_gtest_executable(test_ck_tile_grouped_gemm_quant_rowcol test_grouped_gemm_quant_rowcol.cpp)
|
||||
# target_compile_options(test_ck_tile_grouped_gemm_quant_rowcol PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_quant_tensor test_grouped_gemm_quant_tensor.cpp)
|
||||
target_compile_options(test_ck_tile_grouped_gemm_quant_tensor PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
# add_gtest_executable(test_ck_tile_grouped_gemm_quant_tensor test_grouped_gemm_quant_tensor.cpp)
|
||||
# target_compile_options(test_ck_tile_grouped_gemm_quant_tensor PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_quant_aquant test_grouped_gemm_quant_aquant.cpp)
|
||||
target_compile_options(test_ck_tile_grouped_gemm_quant_aquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
# add_gtest_executable(test_ck_tile_grouped_gemm_quant_aquant test_grouped_gemm_quant_aquant.cpp)
|
||||
# target_compile_options(test_ck_tile_grouped_gemm_quant_aquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp)
|
||||
# target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
# endif()
|
||||
|
||||
|
||||
246
test/common/csv_test_loader.hpp
Normal file
246
test/common/csv_test_loader.hpp
Normal file
@@ -0,0 +1,246 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <fstream>
|
||||
#include <filesystem>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace test {
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
// Helper function to find test_data directory relative to the test binary
|
||||
// Locates the `test_data` directory by walking upward from the directory that
// contains the running test executable.
//
// Returns the path of the nearest ancestor entry named `test_data` that is a
// directory, or an empty string (after printing an error to stderr) when no such
// directory exists or the filesystem cannot be queried. The filesystem root itself
// is not probed.
//
// The executable is resolved through the `/proc/self/exe` symlink, so on systems
// without it the function reports an error and returns an empty string.
//
// Declared `inline` (not `static`) so that this header-defined function is shared
// across translation units like the other helpers in this header.
inline std::string GetTestDataPath()
{
    try
    {
        // Resolve the path of the running binary. This throws filesystem_error when
        // the symlink cannot be read; it is handled below because callers run while
        // test parameters are being generated, where an escaping exception would
        // abort the whole test binary.
        const std::filesystem::path exe_path = std::filesystem::read_symlink("/proc/self/exe");

        // Probe each ancestor directory in turn, nearest first, so the lookup works
        // regardless of how deep the build directory is.
        for(std::filesystem::path dir = exe_path.parent_path(); dir != dir.root_path();
            dir                       = dir.parent_path())
        {
            const std::filesystem::path candidate = dir / "test_data";
            // is_directory() is false for a non-existent path, so no separate
            // exists() check is needed.
            if(std::filesystem::is_directory(candidate))
            {
                return candidate.string();
            }
        }
    }
    catch(const std::filesystem::filesystem_error& e)
    {
        std::cerr << "ERROR: Filesystem error while locating test_data directory: " << e.what()
                  << std::endl;
        return "";
    }

    // Reached the root without a match.
    std::cerr << "ERROR: Could not find test_data directory relative to executable" << std::endl;
    return "";
}
|
||||
|
||||
// CSV Reader Function for Loading Test Cases
|
||||
// Reads convolution parameters from CSV file and returns vector of ConvParam structures
|
||||
inline std::vector<ck::utils::conv::ConvParam> load_csv_test_cases(const std::string& filename)
|
||||
{
|
||||
std::vector<ck::utils::conv::ConvParam> conv_params; // Return vector
|
||||
std::ifstream file(filename); // Open CSV file
|
||||
|
||||
if(!file.is_open())
|
||||
{
|
||||
std::cerr << "ERROR: Cannot open CSV file: " << filename << std::endl;
|
||||
return conv_params; // Return empty vector on error
|
||||
}
|
||||
|
||||
std::string line;
|
||||
int line_number = 0;
|
||||
|
||||
// Read file line by line
|
||||
while(std::getline(file, line))
|
||||
{
|
||||
line_number++;
|
||||
// Skip comment lines (starting with #) and empty lines
|
||||
if(line.empty() || line[0] == '#')
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip header line (contains column names)
|
||||
if(line.find("NDim,Groups,BatchSize") != std::string::npos)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// Parse CSV line using stringstream
|
||||
std::stringstream ss(line);
|
||||
std::string cell;
|
||||
std::vector<std::string> row;
|
||||
|
||||
// Split line by commas
|
||||
while(std::getline(ss, cell, ','))
|
||||
{
|
||||
row.push_back(cell);
|
||||
}
|
||||
|
||||
// Validate row has correct number of columns
|
||||
if(row.size() < 19)
|
||||
{ // Need at least 19 columns for 2D (excluding TestName)
|
||||
std::cerr << "WARNING: Line " << line_number << " has insufficient columns ("
|
||||
<< row.size() << "), skipping" << std::endl;
|
||||
continue;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
// Parse CSV data into ConvParam structure
|
||||
// CSV Format:
|
||||
// NDim,Groups,BatchSize,OutChannels,InChannels,KernelH,KernelW,InputH,InputW,OutputH,OutputW,StrideH,StrideW,DilationH,DilationW,LeftPadH,LeftPadW,RightPadH,RightPadW,TestName
|
||||
int NDim = std::stoi(row[0]);
|
||||
int Groups = std::stoi(row[1]);
|
||||
int BatchSize = std::stoi(row[2]);
|
||||
int OutChannels = std::stoi(row[3]);
|
||||
int InChannels = std::stoi(row[4]);
|
||||
|
||||
if(NDim == 1)
|
||||
{
|
||||
// 1D Convolution: Need fewer columns for 1D parameters
|
||||
if(row.size() < 13)
|
||||
{
|
||||
std::cerr << "WARNING: 1D convolution on line " << line_number
|
||||
<< " needs 13+ columns, has " << row.size() << ", skipping"
|
||||
<< std::endl;
|
||||
continue;
|
||||
}
|
||||
// 1D Convolution: {NDim, Groups, BatchSize, OutChannels, InChannels,
|
||||
// {KernelW}, {InputW}, {StrideW}, {DilationW}, {LeftPadW}, {RightPadW}}
|
||||
ck::utils::conv::ConvParam param = {
|
||||
NDim, // NDim = 1
|
||||
Groups, // Groups
|
||||
BatchSize, // Batch size
|
||||
OutChannels, // Output channels
|
||||
InChannels, // Input channels
|
||||
{std::stoi(row[5])}, // Kernel: {W}
|
||||
{std::stoi(row[7])}, // Input: {W}
|
||||
{std::stoi(row[11])}, // Stride: {W}
|
||||
{std::stoi(row[13])}, // Dilation: {W}
|
||||
{std::stoi(row[15])}, // Left pad: {W}
|
||||
{std::stoi(row[17])} // Right pad: {W}
|
||||
};
|
||||
conv_params.push_back(param);
|
||||
}
|
||||
else if(NDim == 2)
|
||||
{
|
||||
// 2D Convolution: {NDim, Groups, BatchSize, OutChannels, InChannels,
|
||||
// {KernelH,KernelW}, {InputH,InputW}, {StrideH,StrideW}, {DilationH,DilationW},
|
||||
// {LeftPadH,LeftPadW}, {RightPadH,RightPadW}}
|
||||
ck::utils::conv::ConvParam param = {
|
||||
NDim, // NDim = 2
|
||||
Groups, // Groups
|
||||
BatchSize, // Batch size
|
||||
OutChannels, // Output channels
|
||||
InChannels, // Input channels
|
||||
{std::stoi(row[5]), std::stoi(row[6])}, // Kernel: {H, W}
|
||||
{std::stoi(row[7]), std::stoi(row[8])}, // Input: {H, W}
|
||||
{std::stoi(row[11]), std::stoi(row[12])}, // Stride: {H, W}
|
||||
{std::stoi(row[13]), std::stoi(row[14])}, // Dilation: {H, W}
|
||||
{std::stoi(row[15]), std::stoi(row[16])}, // Left pad: {H, W}
|
||||
{std::stoi(row[17]), std::stoi(row[18])} // Right pad: {H, W}
|
||||
};
|
||||
conv_params.push_back(param);
|
||||
}
|
||||
else if(NDim == 3)
|
||||
{
|
||||
// 3D Convolution: Need more columns for 3D parameters
|
||||
if(row.size() < 26)
|
||||
{
|
||||
std::cerr << "WARNING: 3D convolution on line " << line_number
|
||||
<< " needs 26+ columns, has " << row.size() << ", skipping"
|
||||
<< std::endl;
|
||||
continue;
|
||||
}
|
||||
// 3D Convolution: {NDim, Groups, BatchSize, OutChannels, InChannels,
|
||||
// {KernelD,KernelH,KernelW}, {InputD,InputH,InputW}, {OutputD,OutputH,OutputW},
|
||||
// {StrideD,StrideH,StrideW}, {DilationD,DilationH,DilationW},
|
||||
// {LeftPadD,LeftPadH,LeftPadW}, {RightPadD,RightPadH,RightPadW}}
|
||||
ck::utils::conv::ConvParam param = {
|
||||
NDim, // NDim = 3
|
||||
Groups, // Groups
|
||||
BatchSize, // Batch size
|
||||
OutChannels, // Output channels
|
||||
InChannels, // Input channels
|
||||
{std::stoi(row[5]), std::stoi(row[6]), std::stoi(row[7])}, // Kernel: {D, H, W}
|
||||
{std::stoi(row[8]), std::stoi(row[9]), std::stoi(row[10])}, // Input: {D, H, W}
|
||||
{std::stoi(row[14]),
|
||||
std::stoi(row[15]),
|
||||
std::stoi(row[16])}, // Stride: {D, H, W}
|
||||
{std::stoi(row[17]),
|
||||
std::stoi(row[18]),
|
||||
std::stoi(row[19])}, // Dilation: {D, H, W}
|
||||
{std::stoi(row[20]),
|
||||
std::stoi(row[21]),
|
||||
std::stoi(row[22])}, // Left pad: {D, H, W}
|
||||
{std::stoi(row[23]),
|
||||
std::stoi(row[24]),
|
||||
std::stoi(row[25])} // Right pad: {D, H, W}
|
||||
};
|
||||
conv_params.push_back(param);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "WARNING: Unsupported NDim=" << NDim << " on line " << line_number
|
||||
<< ", skipping" << std::endl;
|
||||
}
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "ERROR: Failed to parse line " << line_number << ": " << e.what()
|
||||
<< std::endl;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
file.close();
|
||||
std::cout << "Loaded " << conv_params.size() << " test cases from " << filename << std::endl;
|
||||
return conv_params;
|
||||
}
|
||||
|
||||
// Helper function to load CSV test cases and populate conv_params vector
|
||||
// Returns true if loading succeeded, false otherwise
|
||||
inline bool load_and_populate_test_cases(const std::vector<std::string>& csv_paths,
|
||||
std::vector<ck::utils::conv::ConvParam>& conv_params,
|
||||
const std::string& dimension_label)
|
||||
{
|
||||
for(const auto& csv_path : csv_paths)
|
||||
{
|
||||
auto csv_cases = load_csv_test_cases(csv_path);
|
||||
if(!csv_cases.empty())
|
||||
{
|
||||
// Successfully loaded CSV data - add all test cases to conv_params
|
||||
for(const auto& test_case : csv_cases)
|
||||
{
|
||||
conv_params.push_back(test_case);
|
||||
}
|
||||
std::cout << "Loaded " << csv_cases.size() << " " << dimension_label
|
||||
<< " test cases from " << csv_path << std::endl;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Failed to load from any path
|
||||
std::cerr << "ERROR: Failed to load CSV test data from any of these locations:" << std::endl;
|
||||
for(const auto& path : csv_paths)
|
||||
{
|
||||
std::cerr << " - " << path << std::endl;
|
||||
}
|
||||
std::cerr << "\nPlease ensure CSV test data exists in one of these locations." << std::endl;
|
||||
std::cerr << "Run generate_test_dataset.sh in test_data/ to create test datasets." << std::endl;
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace ck
|
||||
@@ -9,6 +9,10 @@ if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
add_executable(test_grouped_convnd_bwd_data_xdl_large_cases test_grouped_convnd_bwd_data_xdl_large_cases.cpp)
|
||||
target_compile_options(test_grouped_convnd_bwd_data_xdl_large_cases PRIVATE -Wno-global-constructors -Wno-undef)
|
||||
target_link_libraries(test_grouped_convnd_bwd_data_xdl_large_cases PRIVATE gtest_main getopt::getopt utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance)
|
||||
|
||||
add_executable(test_grouped_convnd_bwd_data_dataset_xdl test_grouped_convnd_bwd_data_dataset_xdl.cpp)
|
||||
target_compile_options(test_grouped_convnd_bwd_data_dataset_xdl PRIVATE -Wno-global-constructors -Wno-undef)
|
||||
target_link_libraries(test_grouped_convnd_bwd_data_dataset_xdl PRIVATE gtest_main getopt::getopt utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance)
|
||||
endif()
|
||||
add_gtest_executable(test_grouped_convnd_bwd_data_wmma test_grouped_convnd_bwd_data_wmma.cpp)
|
||||
if(result EQUAL 0)
|
||||
|
||||
@@ -0,0 +1,317 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib> // Standard C library (exit codes, malloc)
|
||||
#include <iostream> // C++ I/O streams (cout, cerr)
|
||||
#include <initializer_list> // C++ initializer list support (unused here)
|
||||
#include <vector> // C++ vector container - stores test cases
|
||||
#include <string> // String operations
|
||||
#include <gtest/gtest.h> // Google Test framework - provides TEST_P, INSTANTIATE_TEST_SUITE_P
|
||||
|
||||
#include "profiler/profile_grouped_conv_bwd_data_impl.hpp" // The actual GPU profiler that does convolution work
|
||||
#include "../common/csv_test_loader.hpp" // Shared CSV test case loader
|
||||
|
||||
using namespace ck::tensor_layout::convolution; // Import tensor layout names (GNHWK, GKYXC, etc.)
|
||||
|
||||
// Load CSV data for 2D tests
|
||||
static std::vector<ck::utils::conv::ConvParam> Get2DTestCases()
|
||||
{
|
||||
static std::vector<ck::utils::conv::ConvParam> test_cases;
|
||||
if(test_cases.empty())
|
||||
{
|
||||
std::string test_data_dir = ck::test::GetTestDataPath();
|
||||
if(test_data_dir.empty())
|
||||
{
|
||||
std::cerr << "FATAL: test_data directory not found" << std::endl;
|
||||
return test_cases;
|
||||
}
|
||||
|
||||
std::vector<std::string> csv_paths = {test_data_dir + "/conv_test_set_2d_dataset.csv"};
|
||||
bool loaded = ck::test::load_and_populate_test_cases(csv_paths, test_cases, "2D");
|
||||
if(!loaded)
|
||||
{
|
||||
std::cerr << "FATAL: Failed to load 2D test cases from " << csv_paths[0] << std::endl;
|
||||
}
|
||||
}
|
||||
return test_cases;
|
||||
}
|
||||
|
||||
// Load CSV data for 3D tests
|
||||
static std::vector<ck::utils::conv::ConvParam> Get3DTestCases()
|
||||
{
|
||||
static std::vector<ck::utils::conv::ConvParam> test_cases;
|
||||
if(test_cases.empty())
|
||||
{
|
||||
std::string test_data_dir = ck::test::GetTestDataPath();
|
||||
if(test_data_dir.empty())
|
||||
{
|
||||
std::cerr << "FATAL: test_data directory not found" << std::endl;
|
||||
return test_cases;
|
||||
}
|
||||
|
||||
std::vector<std::string> csv_paths = {test_data_dir + "/conv_test_set_3d_dataset.csv"};
|
||||
bool loaded = ck::test::load_and_populate_test_cases(csv_paths, test_cases, "3D");
|
||||
if(!loaded)
|
||||
{
|
||||
std::cerr << "FATAL: Failed to load 3D test cases from " << csv_paths[0] << std::endl;
|
||||
}
|
||||
}
|
||||
return test_cases;
|
||||
}
|
||||
|
||||
// Helper template to run a single backward data convolution test with split_k
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename OutLayout,
|
||||
typename WeiLayout,
|
||||
typename InLayout,
|
||||
typename DataType>
|
||||
bool RunConvBwdDataTest(const ck::utils::conv::ConvParam& param, ck::index_t split_k)
|
||||
{
|
||||
return ck::profiler::profile_grouped_conv_bwd_data_impl<NDimSpatial,
|
||||
OutLayout,
|
||||
WeiLayout,
|
||||
InLayout,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType>(true, // do_verification
|
||||
1, // init_method
|
||||
false, // do_log
|
||||
false, // time_kernel
|
||||
param, // ConvParam
|
||||
split_k, // Split-K value
|
||||
-1); // instance_index
|
||||
}
|
||||
|
||||
// 2D Tests - GNHWK layout - Float - SplitK=1
|
||||
class TestGroupedConvndBwdData2dGNHWKFloatSplitK1
|
||||
: public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndBwdData2dGNHWKFloatSplitK1, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvBwdDataTest<2, GNHWK, GKYXC, GNHWC, float>(GetParam(), 1)));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndBwdData2dGNHWKFloatSplitK1,
|
||||
::testing::ValuesIn(Get2DTestCases()));
|
||||
|
||||
// 2D Tests - GNHWK layout - Float - SplitK=2
|
||||
class TestGroupedConvndBwdData2dGNHWKFloatSplitK2
|
||||
: public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndBwdData2dGNHWKFloatSplitK2, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvBwdDataTest<2, GNHWK, GKYXC, GNHWC, float>(GetParam(), 2)));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndBwdData2dGNHWKFloatSplitK2,
|
||||
::testing::ValuesIn(Get2DTestCases()));
|
||||
|
||||
// 2D Tests - GNHWK layout - Half - SplitK=1
|
||||
class TestGroupedConvndBwdData2dGNHWKHalfSplitK1
|
||||
: public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndBwdData2dGNHWKHalfSplitK1, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvBwdDataTest<2, GNHWK, GKYXC, GNHWC, ck::half_t>(GetParam(), 1)));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndBwdData2dGNHWKHalfSplitK1,
|
||||
::testing::ValuesIn(Get2DTestCases()));
|
||||
|
||||
// 2D Tests - GNHWK layout - Half - SplitK=2
|
||||
class TestGroupedConvndBwdData2dGNHWKHalfSplitK2
|
||||
: public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndBwdData2dGNHWKHalfSplitK2, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvBwdDataTest<2, GNHWK, GKYXC, GNHWC, ck::half_t>(GetParam(), 2)));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndBwdData2dGNHWKHalfSplitK2,
|
||||
::testing::ValuesIn(Get2DTestCases()));
|
||||
|
||||
// 2D Tests - GNHWK layout - BFloat16 - SplitK=1
|
||||
class TestGroupedConvndBwdData2dGNHWKBFloat16SplitK1
|
||||
: public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndBwdData2dGNHWKBFloat16SplitK1, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvBwdDataTest<2, GNHWK, GKYXC, GNHWC, ck::bhalf_t>(GetParam(), 1)));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndBwdData2dGNHWKBFloat16SplitK1,
|
||||
::testing::ValuesIn(Get2DTestCases()));
|
||||
|
||||
// 2D Tests - GNHWK layout - BFloat16 - SplitK=2
|
||||
class TestGroupedConvndBwdData2dGNHWKBFloat16SplitK2
|
||||
: public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndBwdData2dGNHWKBFloat16SplitK2, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvBwdDataTest<2, GNHWK, GKYXC, GNHWC, ck::bhalf_t>(GetParam(), 2)));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndBwdData2dGNHWKBFloat16SplitK2,
|
||||
::testing::ValuesIn(Get2DTestCases()));
|
||||
|
||||
// 2D Tests - NHWGK layout - Float - SplitK=1
|
||||
class TestGroupedConvndBwdData2dNHWGKFloatSplitK1
|
||||
: public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndBwdData2dNHWGKFloatSplitK1, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvBwdDataTest<2, NHWGK, GKYXC, NHWGC, float>(GetParam(), 1)));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndBwdData2dNHWGKFloatSplitK1,
|
||||
::testing::ValuesIn(Get2DTestCases()));
|
||||
|
||||
// 2D Tests - NHWGK layout - Float - SplitK=2
|
||||
class TestGroupedConvndBwdData2dNHWGKFloatSplitK2
|
||||
: public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndBwdData2dNHWGKFloatSplitK2, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvBwdDataTest<2, NHWGK, GKYXC, NHWGC, float>(GetParam(), 2)));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndBwdData2dNHWGKFloatSplitK2,
|
||||
::testing::ValuesIn(Get2DTestCases()));
|
||||
|
||||
// 2D Tests - NHWGK layout - Half - SplitK=1
|
||||
class TestGroupedConvndBwdData2dNHWGKHalfSplitK1
|
||||
: public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndBwdData2dNHWGKHalfSplitK1, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvBwdDataTest<2, NHWGK, GKYXC, NHWGC, ck::half_t>(GetParam(), 1)));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndBwdData2dNHWGKHalfSplitK1,
|
||||
::testing::ValuesIn(Get2DTestCases()));
|
||||
|
||||
// 2D Tests - NHWGK layout - Half - SplitK=2
|
||||
class TestGroupedConvndBwdData2dNHWGKHalfSplitK2
|
||||
: public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndBwdData2dNHWGKHalfSplitK2, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvBwdDataTest<2, NHWGK, GKYXC, NHWGC, ck::half_t>(GetParam(), 2)));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndBwdData2dNHWGKHalfSplitK2,
|
||||
::testing::ValuesIn(Get2DTestCases()));
|
||||
|
||||
// 2D Tests - NHWGK layout - BFloat16 - SplitK=1
|
||||
class TestGroupedConvndBwdData2dNHWGKBFloat16SplitK1
|
||||
: public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndBwdData2dNHWGKBFloat16SplitK1, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvBwdDataTest<2, NHWGK, GKYXC, NHWGC, ck::bhalf_t>(GetParam(), 1)));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndBwdData2dNHWGKBFloat16SplitK1,
|
||||
::testing::ValuesIn(Get2DTestCases()));
|
||||
|
||||
// 2D Tests - NHWGK layout - BFloat16 - SplitK=2
|
||||
class TestGroupedConvndBwdData2dNHWGKBFloat16SplitK2
|
||||
: public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndBwdData2dNHWGKBFloat16SplitK2, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvBwdDataTest<2, NHWGK, GKYXC, NHWGC, ck::bhalf_t>(GetParam(), 2)));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndBwdData2dNHWGKBFloat16SplitK2,
|
||||
::testing::ValuesIn(Get2DTestCases()));
|
||||
|
||||
// 3D Tests - NDHWGK layout - Float - SplitK=1
|
||||
class TestGroupedConvndBwdData3dNDHWGKFloatSplitK1
|
||||
: public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndBwdData3dNDHWGKFloatSplitK1, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvBwdDataTest<3, NDHWGK, GKZYXC, NDHWGC, float>(GetParam(), 1)));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndBwdData3dNDHWGKFloatSplitK1,
|
||||
::testing::ValuesIn(Get3DTestCases()));
|
||||
|
||||
// 3D Tests - NDHWGK layout - Float - SplitK=2
|
||||
class TestGroupedConvndBwdData3dNDHWGKFloatSplitK2
|
||||
: public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndBwdData3dNDHWGKFloatSplitK2, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvBwdDataTest<3, NDHWGK, GKZYXC, NDHWGC, float>(GetParam(), 2)));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndBwdData3dNDHWGKFloatSplitK2,
|
||||
::testing::ValuesIn(Get3DTestCases()));
|
||||
|
||||
// 3D Tests - NDHWGK layout - Half - SplitK=1
|
||||
class TestGroupedConvndBwdData3dNDHWGKHalfSplitK1
|
||||
: public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndBwdData3dNDHWGKHalfSplitK1, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvBwdDataTest<3, NDHWGK, GKZYXC, NDHWGC, ck::half_t>(GetParam(), 1)));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndBwdData3dNDHWGKHalfSplitK1,
|
||||
::testing::ValuesIn(Get3DTestCases()));
|
||||
|
||||
// 3D Tests - NDHWGK layout - Half - SplitK=2
|
||||
class TestGroupedConvndBwdData3dNDHWGKHalfSplitK2
|
||||
: public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndBwdData3dNDHWGKHalfSplitK2, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvBwdDataTest<3, NDHWGK, GKZYXC, NDHWGC, ck::half_t>(GetParam(), 2)));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndBwdData3dNDHWGKHalfSplitK2,
|
||||
::testing::ValuesIn(Get3DTestCases()));
|
||||
|
||||
// 3D Tests - NDHWGK layout - BFloat16 - SplitK=1
|
||||
class TestGroupedConvndBwdData3dNDHWGKBFloat16SplitK1
|
||||
: public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndBwdData3dNDHWGKBFloat16SplitK1, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvBwdDataTest<3, NDHWGK, GKZYXC, NDHWGC, ck::bhalf_t>(GetParam(), 1)));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndBwdData3dNDHWGKBFloat16SplitK1,
|
||||
::testing::ValuesIn(Get3DTestCases()));
|
||||
|
||||
// 3D Tests - NDHWGK layout - BFloat16 - SplitK=2
|
||||
class TestGroupedConvndBwdData3dNDHWGKBFloat16SplitK2
|
||||
: public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndBwdData3dNDHWGKBFloat16SplitK2, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvBwdDataTest<3, NDHWGK, GKZYXC, NDHWGC, ck::bhalf_t>(GetParam(), 2)));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndBwdData3dNDHWGKBFloat16SplitK2,
|
||||
::testing::ValuesIn(Get3DTestCases()));
|
||||
@@ -4,6 +4,10 @@
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp)
|
||||
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance device_grouped_convnd_bwd_weight_instance)
|
||||
|
||||
add_executable(test_grouped_convnd_bwd_weight_dataset_xdl test_grouped_convnd_bwd_weight_dataset_xdl.cpp)
|
||||
target_compile_options(test_grouped_convnd_bwd_weight_dataset_xdl PRIVATE -Wno-global-constructors -Wno-undef)
|
||||
target_link_libraries(test_grouped_convnd_bwd_weight_dataset_xdl PRIVATE gtest_main getopt::getopt utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance device_grouped_convnd_bwd_weight_instance)
|
||||
elseif(DL_KERNELS)
|
||||
add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp)
|
||||
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance)
|
||||
|
||||
@@ -0,0 +1,258 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib> // Standard C library (exit codes, malloc)
|
||||
#include <iostream> // C++ I/O streams (cout, cerr)
|
||||
#include <initializer_list> // C++ initializer list support (unused here)
|
||||
#include <vector> // C++ vector container - stores test cases
|
||||
#include <string> // String operations
|
||||
#include <gtest/gtest.h> // Google Test framework - provides TEST_P, INSTANTIATE_TEST_SUITE_P
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
#include "profiler/profile_grouped_conv_bwd_weight_impl.hpp" // The actual GPU profiler that does convolution work
|
||||
#include "../common/csv_test_loader.hpp" // Shared CSV test case loader
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
// Load CSV data for 2D tests
|
||||
static std::vector<ck::utils::conv::ConvParam> Get2DTestCases()
|
||||
{
|
||||
static std::vector<ck::utils::conv::ConvParam> test_cases;
|
||||
if(test_cases.empty())
|
||||
{
|
||||
std::string test_data_dir = ck::test::GetTestDataPath();
|
||||
if(test_data_dir.empty())
|
||||
{
|
||||
std::cerr << "FATAL: test_data directory not found" << std::endl;
|
||||
return test_cases;
|
||||
}
|
||||
|
||||
std::vector<std::string> csv_paths = {test_data_dir + "/conv_test_set_2d_dataset.csv"};
|
||||
bool loaded = ck::test::load_and_populate_test_cases(csv_paths, test_cases, "2D");
|
||||
if(!loaded)
|
||||
{
|
||||
std::cerr << "FATAL: Failed to load 2D test cases from " << csv_paths[0] << std::endl;
|
||||
}
|
||||
}
|
||||
return test_cases;
|
||||
}
|
||||
|
||||
// Load CSV data for 3D tests
|
||||
static std::vector<ck::utils::conv::ConvParam> Get3DTestCases()
|
||||
{
|
||||
static std::vector<ck::utils::conv::ConvParam> test_cases;
|
||||
if(test_cases.empty())
|
||||
{
|
||||
std::string test_data_dir = ck::test::GetTestDataPath();
|
||||
if(test_data_dir.empty())
|
||||
{
|
||||
std::cerr << "FATAL: test_data directory not found" << std::endl;
|
||||
return test_cases;
|
||||
}
|
||||
|
||||
std::vector<std::string> csv_paths = {test_data_dir + "/conv_test_set_3d_dataset.csv"};
|
||||
bool loaded = ck::test::load_and_populate_test_cases(csv_paths, test_cases, "3D");
|
||||
if(!loaded)
|
||||
{
|
||||
std::cerr << "FATAL: Failed to load 3D test cases from " << csv_paths[0] << std::endl;
|
||||
}
|
||||
}
|
||||
return test_cases;
|
||||
}
|
||||
|
||||
// Helper template to run a single backward weight convolution test
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType>
|
||||
bool RunConvBwdWeightTest(const ck::utils::conv::ConvParam& param, ck::index_t split_k)
|
||||
{
|
||||
return ck::profiler::profile_grouped_conv_bwd_weight_impl<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType>(
|
||||
true, // do_verification
|
||||
1, // init_method
|
||||
false, // do_log
|
||||
false, // time_kernel
|
||||
param, // ConvParam
|
||||
std::to_string(split_k), // Split-K value as string
|
||||
-1); // instance_index
|
||||
}
|
||||
|
||||
// 2D Tests - NHWGK layout - Float - SplitK=1
|
||||
class TestGroupedConvndBwdWeight2dNHWGKFloatSplitK1
|
||||
: public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndBwdWeight2dNHWGKFloatSplitK1, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvBwdWeightTest<2, NHWGC, GKYXC, NHWGK, float, float, float>(GetParam(), 1)));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndBwdWeight2dNHWGKFloatSplitK1,
|
||||
::testing::ValuesIn(Get2DTestCases()));
|
||||
|
||||
// 2D Tests - NHWGK layout - Float - SplitK=2
|
||||
class TestGroupedConvndBwdWeight2dNHWGKFloatSplitK2
|
||||
: public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndBwdWeight2dNHWGKFloatSplitK2, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvBwdWeightTest<2, NHWGC, GKYXC, NHWGK, float, float, float>(GetParam(), 2)));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndBwdWeight2dNHWGKFloatSplitK2,
|
||||
::testing::ValuesIn(Get2DTestCases()));
|
||||
|
||||
// 2D Tests - NHWGK layout - Half - SplitK=1
|
||||
class TestGroupedConvndBwdWeight2dNHWGKHalfSplitK1
|
||||
: public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndBwdWeight2dNHWGKHalfSplitK1, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvBwdWeightTest<2, NHWGC, GKYXC, NHWGK, ck::half_t, ck::half_t, ck::half_t>(
|
||||
GetParam(), 1)));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndBwdWeight2dNHWGKHalfSplitK1,
|
||||
::testing::ValuesIn(Get2DTestCases()));
|
||||
|
||||
// 2D Tests - NHWGK layout - Half - SplitK=2
|
||||
class TestGroupedConvndBwdWeight2dNHWGKHalfSplitK2
|
||||
: public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndBwdWeight2dNHWGKHalfSplitK2, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvBwdWeightTest<2, NHWGC, GKYXC, NHWGK, ck::half_t, ck::half_t, ck::half_t>(
|
||||
GetParam(), 2)));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndBwdWeight2dNHWGKHalfSplitK2,
|
||||
::testing::ValuesIn(Get2DTestCases()));
|
||||
|
||||
// 2D Tests - NHWGK layout - BFloat16 - SplitK=1
|
||||
class TestGroupedConvndBwdWeight2dNHWGKBFloat16SplitK1
|
||||
: public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndBwdWeight2dNHWGKBFloat16SplitK1, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvBwdWeightTest<2, NHWGC, GKYXC, NHWGK, ck::bhalf_t, float, ck::bhalf_t>(
|
||||
GetParam(), 1)));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndBwdWeight2dNHWGKBFloat16SplitK1,
|
||||
::testing::ValuesIn(Get2DTestCases()));
|
||||
|
||||
// 2D Tests - NHWGK layout - BFloat16 - SplitK=2
|
||||
class TestGroupedConvndBwdWeight2dNHWGKBFloat16SplitK2
|
||||
: public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndBwdWeight2dNHWGKBFloat16SplitK2, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvBwdWeightTest<2, NHWGC, GKYXC, NHWGK, ck::bhalf_t, float, ck::bhalf_t>(
|
||||
GetParam(), 2)));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndBwdWeight2dNHWGKBFloat16SplitK2,
|
||||
::testing::ValuesIn(Get2DTestCases()));
|
||||
|
||||
// 3D Tests - NDHWGK layout - Float - SplitK=1
|
||||
class TestGroupedConvndBwdWeight3dNDHWGKFloatSplitK1
|
||||
: public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndBwdWeight3dNDHWGKFloatSplitK1, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE(
|
||||
(RunConvBwdWeightTest<3, NDHWGC, GKZYXC, NDHWGK, float, float, float>(GetParam(), 1)));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndBwdWeight3dNDHWGKFloatSplitK1,
|
||||
::testing::ValuesIn(Get3DTestCases()));
|
||||
|
||||
// 3D Tests - NDHWGK layout - Float - SplitK=2
|
||||
class TestGroupedConvndBwdWeight3dNDHWGKFloatSplitK2
|
||||
: public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndBwdWeight3dNDHWGKFloatSplitK2, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE(
|
||||
(RunConvBwdWeightTest<3, NDHWGC, GKZYXC, NDHWGK, float, float, float>(GetParam(), 2)));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndBwdWeight3dNDHWGKFloatSplitK2,
|
||||
::testing::ValuesIn(Get3DTestCases()));
|
||||
|
||||
// 3D Tests - NDHWGK layout - Half - SplitK=1
|
||||
class TestGroupedConvndBwdWeight3dNDHWGKHalfSplitK1
|
||||
: public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndBwdWeight3dNDHWGKHalfSplitK1, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE(
|
||||
(RunConvBwdWeightTest<3, NDHWGC, GKZYXC, NDHWGK, ck::half_t, ck::half_t, ck::half_t>(
|
||||
GetParam(), 1)));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndBwdWeight3dNDHWGKHalfSplitK1,
|
||||
::testing::ValuesIn(Get3DTestCases()));
|
||||
|
||||
// 3D Tests - NDHWGK layout - Half - SplitK=2
|
||||
class TestGroupedConvndBwdWeight3dNDHWGKHalfSplitK2
|
||||
: public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndBwdWeight3dNDHWGKHalfSplitK2, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE(
|
||||
(RunConvBwdWeightTest<3, NDHWGC, GKZYXC, NDHWGK, ck::half_t, ck::half_t, ck::half_t>(
|
||||
GetParam(), 2)));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndBwdWeight3dNDHWGKHalfSplitK2,
|
||||
::testing::ValuesIn(Get3DTestCases()));
|
||||
|
||||
// 3D Tests - NDHWGK layout - BFloat16 - SplitK=1
|
||||
class TestGroupedConvndBwdWeight3dNDHWGKBFloat16SplitK1
|
||||
: public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndBwdWeight3dNDHWGKBFloat16SplitK1, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvBwdWeightTest<3, NDHWGC, GKZYXC, NDHWGK, ck::bhalf_t, float, ck::bhalf_t>(
|
||||
GetParam(), 1)));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndBwdWeight3dNDHWGKBFloat16SplitK1,
|
||||
::testing::ValuesIn(Get3DTestCases()));
|
||||
|
||||
// 3D Tests - NDHWGK layout - BFloat16 - SplitK=2
|
||||
class TestGroupedConvndBwdWeight3dNDHWGKBFloat16SplitK2
|
||||
: public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndBwdWeight3dNDHWGKBFloat16SplitK2, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvBwdWeightTest<3, NDHWGC, GKZYXC, NDHWGK, ck::bhalf_t, float, ck::bhalf_t>(
|
||||
GetParam(), 2)));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndBwdWeight3dNDHWGKBFloat16SplitK2,
|
||||
::testing::ValuesIn(Get3DTestCases()));
|
||||
@@ -5,330 +5,165 @@
|
||||
#include <iostream> // C++ I/O streams (cout, cerr)
|
||||
#include <initializer_list> // C++ initializer list support (unused here)
|
||||
#include <vector> // C++ vector container - stores test cases
|
||||
#include <fstream> // File I/O for CSV reading
|
||||
#include <sstream> // String stream for CSV parsing
|
||||
#include <string> // String operations
|
||||
#include <gtest/gtest.h> // Google Test framework - provides TYPED_TEST, EXPECT_TRUE
|
||||
#include <gtest/gtest.h> // Google Test framework - provides TEST_P, INSTANTIATE_TEST_SUITE_P
|
||||
|
||||
#include "profiler/profile_grouped_conv_fwd_impl.hpp" // The actual GPU profiler that does convolution work
|
||||
|
||||
// CSV Reader Function for Loading Test Cases
|
||||
// Reads convolution parameters from CSV file and returns vector of ConvParam structures
|
||||
std::vector<ck::utils::conv::ConvParam> load_csv_test_cases(const std::string& filename)
|
||||
{
|
||||
std::vector<ck::utils::conv::ConvParam> conv_params; // Return vector
|
||||
std::ifstream file(filename); // Open CSV file
|
||||
|
||||
if(!file.is_open())
|
||||
{
|
||||
std::cerr << "ERROR: Cannot open CSV file: " << filename << std::endl;
|
||||
return conv_params; // Return empty vector on error
|
||||
}
|
||||
|
||||
std::string line;
|
||||
int line_number = 0;
|
||||
|
||||
// Read file line by line
|
||||
while(std::getline(file, line))
|
||||
{
|
||||
line_number++;
|
||||
// Skip comment lines (starting with #) and empty lines
|
||||
if(line.empty() || line[0] == '#')
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip header line (contains column names)
|
||||
if(line.find("NDim,Groups,BatchSize") != std::string::npos)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// Parse CSV line using stringstream
|
||||
std::stringstream ss(line);
|
||||
std::string cell;
|
||||
std::vector<std::string> row;
|
||||
|
||||
// Split line by commas
|
||||
while(std::getline(ss, cell, ','))
|
||||
{
|
||||
row.push_back(cell);
|
||||
}
|
||||
|
||||
// Validate row has correct number of columns
|
||||
if(row.size() < 19)
|
||||
{ // Need at least 19 columns for 2D (excluding TestName)
|
||||
std::cerr << "WARNING: Line " << line_number << " has insufficient columns ("
|
||||
<< row.size() << "), skipping" << std::endl;
|
||||
continue;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
// Parse CSV data into ConvParam structure
|
||||
// CSV Format:
|
||||
// NDim,Groups,BatchSize,OutChannels,InChannels,KernelH,KernelW,InputH,InputW,OutputH,OutputW,StrideH,StrideW,DilationH,DilationW,LeftPadH,LeftPadW,RightPadH,RightPadW,TestName
|
||||
int NDim = std::stoi(row[0]);
|
||||
int Groups = std::stoi(row[1]);
|
||||
int BatchSize = std::stoi(row[2]);
|
||||
int OutChannels = std::stoi(row[3]);
|
||||
int InChannels = std::stoi(row[4]);
|
||||
|
||||
if(NDim == 2)
|
||||
{
|
||||
// 2D Convolution: {NDim, Groups, BatchSize, OutChannels, InChannels,
|
||||
// {KernelH,KernelW}, {InputH,InputW}, {StrideH,StrideW}, {DilationH,DilationW},
|
||||
// {LeftPadH,LeftPadW}, {RightPadH,RightPadW}}
|
||||
ck::utils::conv::ConvParam param = {
|
||||
NDim, // NDim = 2
|
||||
Groups, // Groups
|
||||
BatchSize, // Batch size
|
||||
OutChannels, // Output channels
|
||||
InChannels, // Input channels
|
||||
{std::stoi(row[5]), std::stoi(row[6])}, // Kernel: {H, W}
|
||||
{std::stoi(row[7]), std::stoi(row[8])}, // Input: {H, W}
|
||||
{std::stoi(row[11]), std::stoi(row[12])}, // Stride: {H, W}
|
||||
{std::stoi(row[13]), std::stoi(row[14])}, // Dilation: {H, W}
|
||||
{std::stoi(row[15]), std::stoi(row[16])}, // Left pad: {H, W}
|
||||
{std::stoi(row[17]), std::stoi(row[18])} // Right pad: {H, W}
|
||||
};
|
||||
conv_params.push_back(param);
|
||||
}
|
||||
else if(NDim == 3)
|
||||
{
|
||||
// 3D Convolution: Need more columns for 3D parameters
|
||||
if(row.size() < 26)
|
||||
{
|
||||
std::cerr << "WARNING: 3D convolution on line " << line_number
|
||||
<< " needs 26+ columns, has " << row.size() << ", skipping"
|
||||
<< std::endl;
|
||||
continue;
|
||||
}
|
||||
// 3D Convolution: {NDim, Groups, BatchSize, OutChannels, InChannels,
|
||||
// {KernelD,KernelH,KernelW}, {InputD,InputH,InputW}, {OutputD,OutputH,OutputW},
|
||||
// {StrideD,StrideH,StrideW}, {DilationD,DilationH,DilationW},
|
||||
// {LeftPadD,LeftPadH,LeftPadW}, {RightPadD,RightPadH,RightPadW}}
|
||||
ck::utils::conv::ConvParam param = {
|
||||
NDim, // NDim = 3
|
||||
Groups, // Groups
|
||||
BatchSize, // Batch size
|
||||
OutChannels, // Output channels
|
||||
InChannels, // Input channels
|
||||
{std::stoi(row[5]), std::stoi(row[6]), std::stoi(row[7])}, // Kernel: {D, H, W}
|
||||
{std::stoi(row[8]), std::stoi(row[9]), std::stoi(row[10])}, // Input: {D, H, W}
|
||||
{std::stoi(row[14]),
|
||||
std::stoi(row[15]),
|
||||
std::stoi(row[16])}, // Stride: {D, H, W}
|
||||
{std::stoi(row[17]),
|
||||
std::stoi(row[18]),
|
||||
std::stoi(row[19])}, // Dilation: {D, H, W}
|
||||
{std::stoi(row[20]),
|
||||
std::stoi(row[21]),
|
||||
std::stoi(row[22])}, // Left pad: {D, H, W}
|
||||
{std::stoi(row[23]),
|
||||
std::stoi(row[24]),
|
||||
std::stoi(row[25])} // Right pad: {D, H, W}
|
||||
};
|
||||
conv_params.push_back(param);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "WARNING: Unsupported NDim=" << NDim << " on line " << line_number
|
||||
<< ", skipping" << std::endl;
|
||||
}
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "ERROR: Failed to parse line " << line_number << ": " << e.what()
|
||||
<< std::endl;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
file.close();
|
||||
std::cout << "Loaded " << conv_params.size() << " test cases from " << filename << std::endl;
|
||||
return conv_params;
|
||||
}
|
||||
|
||||
// Template class that works with different data types and tensor layouts
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndFwd : public ::testing::Test // Inherit from Google Test base class
|
||||
{
|
||||
protected:
|
||||
using DataType =
|
||||
std::tuple_element_t<0, Tuple>; // Extract data type from tuple (fp32, fp16, bf16, int8)
|
||||
using InLayout =
|
||||
std::tuple_element_t<1, Tuple>; // Extract input tensor layout (NHWGC, NDHWGC, etc.)
|
||||
using WeiLayout =
|
||||
std::tuple_element_t<2, Tuple>; // Extract weight tensor layout (GKYXC, GKZYXC, etc.)
|
||||
using OutLayout =
|
||||
std::tuple_element_t<3, Tuple>; // Extract output tensor layout (NHWGK, NDHWGK, etc.)
|
||||
using IndexType = ck::long_index_t; // 64-bit integer type for tensor dimensions
|
||||
|
||||
// THE KEY CONTAINER: This stores all test case parameters
|
||||
// Each test will push_back() ConvParam structures here
|
||||
std::vector<ck::utils::conv::ConvParam> conv_params;
|
||||
|
||||
// Template function to run tests for N-dimensional spatial convolution (2D or 3D)
|
||||
template <ck::index_t NDimSpatial>
|
||||
void Run()
|
||||
{
|
||||
EXPECT_FALSE(conv_params.empty()); // Google Test assertion: ensure we have test cases
|
||||
bool pass = true; // Track overall pass/fail across all test cases
|
||||
|
||||
// MAIN LOOP: Execute every test case that was added to conv_params
|
||||
for(auto& param : conv_params)
|
||||
{
|
||||
// CALL THE ACTUAL GPU PROFILER - This is where convolution happens!
|
||||
pass = pass &&
|
||||
ck::profiler::profile_grouped_conv_fwd_impl<NDimSpatial,
|
||||
InLayout, // Input tensor layout
|
||||
WeiLayout, // Weight tensor layout
|
||||
OutLayout, // Output tensor layout
|
||||
DataType, // Input data type
|
||||
DataType, // Weight data type
|
||||
DataType, // Output data type
|
||||
DataType, // Accumulation type
|
||||
DataType, // Bias type
|
||||
IndexType>( // Index type (int64)
|
||||
true, // do_verification: Compare GPU result with CPU reference
|
||||
1, // init_method: How to initialize random test data (1 = uniform -5 to 5)
|
||||
false, // do_log: Don't print detailed tensor values
|
||||
false, // time_kernel: Don't do performance timing (just correctness)
|
||||
param); // ConvParam: {NDim, Groups, Batch, OutChannels, InChannels,
|
||||
// KernelSize, InputSize, ...}
|
||||
}
|
||||
EXPECT_TRUE(pass); // Google Test assertion: ALL test cases must pass
|
||||
}
|
||||
};
|
||||
#include "../common/csv_test_loader.hpp" // Shared CSV test case loader
|
||||
|
||||
using namespace ck::tensor_layout::convolution; // Import tensor layout names (NHWGC, GKYXC, etc.)
|
||||
|
||||
// GOOGLE TEST TYPE COMBINATIONS: Define what data types and layouts to test
|
||||
// This creates 4 separate test instances for 2D convolution:
|
||||
using KernelTypes2d =
|
||||
::testing::Types<std::tuple<float, NHWGC, GKYXC, NHWGK>, // fp32 test
|
||||
std::tuple<ck::half_t, NHWGC, GKYXC, NHWGK>, // fp16 test
|
||||
std::tuple<ck::bhalf_t, NHWGC, GKYXC, NHWGK>, // bfloat16 test
|
||||
std::tuple<int8_t, NHWGC, GKYXC, NHWGK>>; // int8 test
|
||||
|
||||
// This creates 3 separate test instances for 3D convolution (no int8 support for 3D):
|
||||
using KernelTypes3d =
|
||||
::testing::Types<std::tuple<float, NDHWGC, GKZYXC, NDHWGK>, // fp32 3D test
|
||||
std::tuple<ck::half_t, NDHWGC, GKZYXC, NDHWGK>, // fp16 3D test
|
||||
std::tuple<ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK>>; // bfloat16 3D test
|
||||
|
||||
// Create specialized test classes that inherit from the base template class
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndFwd2d : public TestGroupedConvndFwd<Tuple> // 2D convolution test class
|
||||
// Load CSV data for 2D tests
|
||||
static std::vector<ck::utils::conv::ConvParam> Get2DTestCases()
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndFwd3d : public TestGroupedConvndFwd<Tuple> // 3D convolution test class
|
||||
{
|
||||
};
|
||||
|
||||
// GOOGLE TEST MAGIC: Create test suites
|
||||
// This tells Google Test to create 4 test instances for 2D (fp32, fp16, bf16, int8)
|
||||
TYPED_TEST_SUITE(TestGroupedConvndFwd2d, KernelTypes2d);
|
||||
// This tells Google Test to create 3 test instances for 3D (fp32, fp16, bf16)
|
||||
TYPED_TEST_SUITE(TestGroupedConvndFwd3d, KernelTypes3d);
|
||||
|
||||
// THE ACTUAL 2D TEST - This runs 4 times (once for each data type: fp32, fp16, bf16, int8)
|
||||
TYPED_TEST(TestGroupedConvndFwd2d, Test2D)
|
||||
{
|
||||
// LOAD TEST CASES FROM CSV FILE instead of hardcoded cases
|
||||
// Try different locations for the CSV file (build directory vs source directory)
|
||||
std::vector<std::string> csv_paths = {
|
||||
"../test_data/conv_test_set_2d_dataset.csv", // From build directory to source
|
||||
};
|
||||
|
||||
bool loaded = false;
|
||||
for(const auto& csv_path : csv_paths)
|
||||
static std::vector<ck::utils::conv::ConvParam> test_cases;
|
||||
if(test_cases.empty())
|
||||
{
|
||||
auto csv_cases = load_csv_test_cases(csv_path);
|
||||
if(!csv_cases.empty())
|
||||
std::string test_data_dir = ck::test::GetTestDataPath();
|
||||
if(test_data_dir.empty())
|
||||
{
|
||||
// Successfully loaded CSV data - add all test cases to conv_params
|
||||
for(const auto& test_case : csv_cases)
|
||||
{
|
||||
this->conv_params.push_back(test_case);
|
||||
}
|
||||
std::cout << "Loaded " << csv_cases.size() << " 2D test cases from " << csv_path
|
||||
<< std::endl;
|
||||
loaded = true;
|
||||
break;
|
||||
std::cerr << "FATAL: test_data directory not found" << std::endl;
|
||||
return test_cases;
|
||||
}
|
||||
|
||||
std::vector<std::string> csv_paths = {test_data_dir + "/conv_test_set_2d_dataset.csv"};
|
||||
bool loaded = ck::test::load_and_populate_test_cases(csv_paths, test_cases, "2D");
|
||||
if(!loaded)
|
||||
{
|
||||
std::cerr << "FATAL: Failed to load 2D test cases from " << csv_paths[0] << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// FAIL if CSV loading fails - no fallback!
|
||||
if(!loaded)
|
||||
{
|
||||
std::cerr << "ERROR: Failed to load CSV test data from any of these locations:"
|
||||
<< std::endl;
|
||||
for(const auto& path : csv_paths)
|
||||
{
|
||||
std::cerr << " - " << path << std::endl;
|
||||
}
|
||||
std::cerr << "\nPlease ensure CSV test data exists in one of these locations." << std::endl;
|
||||
std::cerr << "Run generate_test_dataset.sh in test_data/ to create test datasets."
|
||||
<< std::endl;
|
||||
|
||||
// Force test failure - no test cases means test should fail
|
||||
EXPECT_TRUE(loaded) << "CSV test data loading failed";
|
||||
}
|
||||
|
||||
// Execute all test cases with 2D convolution
|
||||
// This calls Run<2>() which loops through conv_params and calls GPU profiler for each
|
||||
this->template Run<2>();
|
||||
return test_cases;
|
||||
}
|
||||
|
||||
// THE ACTUAL 3D TEST - This runs 3 times (once for each data type: fp32, fp16, bf16)
|
||||
TYPED_TEST(TestGroupedConvndFwd3d, Test3D)
|
||||
// Load CSV data for 3D tests
|
||||
static std::vector<ck::utils::conv::ConvParam> Get3DTestCases()
|
||||
{
|
||||
// LOAD TEST CASES FROM CSV FILE instead of hardcoded cases
|
||||
// Try different locations for the CSV file (build directory vs source directory)
|
||||
std::vector<std::string> csv_paths = {
|
||||
"../test_data/conv_test_set_3d_dataset.csv", // From build directory to source
|
||||
};
|
||||
|
||||
bool loaded = false;
|
||||
for(const auto& csv_path : csv_paths)
|
||||
static std::vector<ck::utils::conv::ConvParam> test_cases;
|
||||
if(test_cases.empty())
|
||||
{
|
||||
auto csv_cases = load_csv_test_cases(csv_path);
|
||||
if(!csv_cases.empty())
|
||||
std::string test_data_dir = ck::test::GetTestDataPath();
|
||||
if(test_data_dir.empty())
|
||||
{
|
||||
// Successfully loaded CSV data - add all test cases to conv_params
|
||||
for(const auto& test_case : csv_cases)
|
||||
{
|
||||
this->conv_params.push_back(test_case);
|
||||
}
|
||||
std::cout << "Loaded " << csv_cases.size() << " 3D test cases from " << csv_path
|
||||
<< std::endl;
|
||||
loaded = true;
|
||||
break;
|
||||
std::cerr << "FATAL: test_data directory not found" << std::endl;
|
||||
return test_cases;
|
||||
}
|
||||
|
||||
std::vector<std::string> csv_paths = {test_data_dir + "/conv_test_set_3d_dataset.csv"};
|
||||
bool loaded = ck::test::load_and_populate_test_cases(csv_paths, test_cases, "3D");
|
||||
if(!loaded)
|
||||
{
|
||||
std::cerr << "FATAL: Failed to load 3D test cases from " << csv_paths[0] << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// FAIL if CSV loading fails - no fallback!
|
||||
if(!loaded)
|
||||
{
|
||||
std::cerr << "ERROR: Failed to load CSV test data from any of these locations:"
|
||||
<< std::endl;
|
||||
for(const auto& path : csv_paths)
|
||||
{
|
||||
std::cerr << " - " << path << std::endl;
|
||||
}
|
||||
std::cerr << "\nPlease ensure CSV test data exists in one of these locations." << std::endl;
|
||||
std::cerr << "Run generate_test_dataset.sh in test_data/ to create test datasets."
|
||||
<< std::endl;
|
||||
|
||||
// Force test failure - no test cases means test should fail
|
||||
EXPECT_TRUE(loaded) << "CSV test data loading failed";
|
||||
}
|
||||
|
||||
// Execute all test cases with 3D convolution
|
||||
// This calls Run<3>() which loops through conv_params and calls GPU profiler for each
|
||||
this->template Run<3>();
|
||||
return test_cases;
|
||||
}
|
||||
|
||||
// Helper template to run a single convolution test
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename DataType>
|
||||
bool RunConvTest(const ck::utils::conv::ConvParam& param)
|
||||
{
|
||||
using IndexType = ck::long_index_t;
|
||||
return ck::profiler::profile_grouped_conv_fwd_impl<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
IndexType>(true, // do_verification
|
||||
1, // init_method
|
||||
false, // do_log
|
||||
false, // time_kernel
|
||||
param);
|
||||
}
|
||||
|
||||
// 2D Tests - Float
|
||||
class TestGroupedConvndFwd2dFloat : public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndFwd2dFloat, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvTest<2, NHWGC, GKYXC, NHWGK, float>(GetParam())));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndFwd2dFloat,
|
||||
::testing::ValuesIn(Get2DTestCases()));
|
||||
|
||||
// 2D Tests - Half
|
||||
class TestGroupedConvndFwd2dHalf : public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndFwd2dHalf, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvTest<2, NHWGC, GKYXC, NHWGK, ck::half_t>(GetParam())));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndFwd2dHalf,
|
||||
::testing::ValuesIn(Get2DTestCases()));
|
||||
|
||||
// 2D Tests - BFloat16
|
||||
class TestGroupedConvndFwd2dBFloat16 : public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndFwd2dBFloat16, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvTest<2, NHWGC, GKYXC, NHWGK, ck::bhalf_t>(GetParam())));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndFwd2dBFloat16,
|
||||
::testing::ValuesIn(Get2DTestCases()));
|
||||
|
||||
// 2D Tests - Int8
|
||||
class TestGroupedConvndFwd2dInt8 : public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndFwd2dInt8, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvTest<2, NHWGC, GKYXC, NHWGK, int8_t>(GetParam())));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndFwd2dInt8,
|
||||
::testing::ValuesIn(Get2DTestCases()));
|
||||
|
||||
// 3D Tests - Float
|
||||
class TestGroupedConvndFwd3dFloat : public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndFwd3dFloat, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvTest<3, NDHWGC, GKZYXC, NDHWGK, float>(GetParam())));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndFwd3dFloat,
|
||||
::testing::ValuesIn(Get3DTestCases()));
|
||||
|
||||
// 3D Tests - Half
|
||||
class TestGroupedConvndFwd3dHalf : public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndFwd3dHalf, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvTest<3, NDHWGC, GKZYXC, NDHWGK, ck::half_t>(GetParam())));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndFwd3dHalf,
|
||||
::testing::ValuesIn(Get3DTestCases()));
|
||||
|
||||
// 3D Tests - BFloat16
|
||||
class TestGroupedConvndFwd3dBFloat16 : public ::testing::TestWithParam<ck::utils::conv::ConvParam>
|
||||
{
|
||||
};
|
||||
TEST_P(TestGroupedConvndFwd3dBFloat16, ConvTest)
|
||||
{
|
||||
EXPECT_TRUE((RunConvTest<3, NDHWGC, GKZYXC, NDHWGK, ck::bhalf_t>(GetParam())));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Dataset,
|
||||
TestGroupedConvndFwd3dBFloat16,
|
||||
::testing::ValuesIn(Get3DTestCases()));
|
||||
|
||||
@@ -8,6 +8,20 @@
|
||||
set -e # Exit on error
|
||||
set +x # Disable command echo (even if called with bash -x)
|
||||
|
||||
# Trap handler: stop any background jobs started by this shell, then exit.
#   $1 - exit status to report. Defaults to 1 so a bare `cleanup` call still
#        aborts with a failure status.
cleanup() {
    local exit_status="${1:-1}"

    # Disarm the traps first. Without this, the `exit` at the end of a
    # signal-triggered cleanup fires the EXIT trap and runs the handler twice.
    trap - SIGINT SIGTERM EXIT

    echo ""
    echo "Cleaning up background processes..."
    # Kill every background job known to this shell; ignore jobs already gone.
    jobs -p | xargs -r kill 2>/dev/null || true
    wait 2>/dev/null || true
    echo "Cleanup complete."
    exit "$exit_status"
}

# Set up traps for common termination paths.
# EXIT forwards the script's real status, so a successful run still returns 0
# (an unconditional `exit 1` here would make every run look like a failure).
# Signals report the conventional 128+signal status.
trap 'cleanup $?' EXIT
trap 'cleanup 130' SIGINT
trap 'cleanup 143' SIGTERM
|
||||
|
||||
echo "=========================================="
|
||||
echo "CK Convolution Test Dataset Generator"
|
||||
echo "=========================================="
|
||||
@@ -18,7 +32,7 @@ if ! python3 -c "import torch" 2>/dev/null; then
|
||||
echo "PyTorch not found. Creating virtual environment..."
|
||||
|
||||
# Create a virtual environment in the current directory
|
||||
VENV_DIR="./pytorch_venv"
|
||||
VENV_DIR="./.venv"
|
||||
if [ ! -d "$VENV_DIR" ]; then
|
||||
python3 -m venv $VENV_DIR || {
|
||||
echo "ERROR: Failed to create virtual environment."
|
||||
@@ -66,11 +80,71 @@ if ! $PYTHON_CMD -c "import torch; import sys; sys.exit(0 if torch.cuda.is_avail
|
||||
echo "Continuing anyway to generate placeholder data..."
|
||||
fi
|
||||
|
||||
# Parse command line arguments
# Accepted forms: a mode word (small|half|full), "-j N" or "-jN", and
# "--gpus N", in any order. Anything else prints the usage text and exits 1.
|
||||
CONFIG_MODE="full" # Default configuration mode: 'small', 'half' or 'full'
|
||||
MAX_PARALLEL_JOBS=1 # Default number of parallel jobs
|
||||
NUM_GPUS=1 # Number of GPUs to use (0 means no GPU assignment)
|
||||
|
||||
# Process arguments
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
# "-j N": the value is the following argument.
# NOTE(review): if the value is missing, `shift 2` fails and, with `set -e`
# active, the script aborts without printing the usage text.
-j)
|
||||
MAX_PARALLEL_JOBS="$2"
|
||||
shift 2
|
||||
;;
|
||||
# "-jN": the value is whatever follows the "-j" prefix in the same argument.
-j*)
|
||||
MAX_PARALLEL_JOBS="${1#-j}"
|
||||
shift
|
||||
;;
|
||||
# "--gpus N": the value is the following argument (same missing-value caveat
# as "-j N").
--gpus)
|
||||
NUM_GPUS="$2"
|
||||
shift 2
|
||||
;;
|
||||
# Positional mode word; if given more than once, the last one wins.
small|half|full)
|
||||
CONFIG_MODE="$1"
|
||||
shift
|
||||
;;
|
||||
# Unrecognized argument: print usage and stop.
*)
|
||||
echo "Usage: $0 [small|half|full] [-j <num_jobs>] [--gpus <num_gpus>]"
|
||||
echo " Configuration modes: small, half, full (default: full)"
|
||||
echo " -j <num_jobs>: Number of parallel jobs (default: 1)"
|
||||
echo " --gpus <num_gpus>: Number of GPUs to use (e.g., 8 for GPUs 0-7)"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
# NOTE(review): MAX_PARALLEL_JOBS and NUM_GPUS are not checked to be
# non-negative integers here; later numeric tests assume they are.
|
||||
|
||||
# Setup GPU array if GPUs are requested
# Result: GPU_ARRAY holds the device indices 0..NUM_GPUS-1 (empty when
# NUM_GPUS is 0), and NUM_GPUS is clamped to the detected device count.
|
||||
if [ $NUM_GPUS -gt 0 ]; then
|
||||
# Auto-detect available GPUs
# Counts the lines of `rocm-smi --showid` output that contain "GPU[<digits>".
# If rocm-smi is missing or fails, the pipeline yields 0.
# NOTE(review): this counts matching lines, not unique GPU indices; if
# rocm-smi prints several lines per GPU the count is inflated -- confirm
# against real rocm-smi output (a `sort -u` before `wc -l` would dedupe).
|
||||
AVAILABLE_GPUS_COUNT=$(rocm-smi --showid 2>/dev/null | grep -oP 'GPU\[\K[0-9]+' | wc -l)
|
||||
if [ "$AVAILABLE_GPUS_COUNT" -gt 0 ]; then
|
||||
MAX_AVAILABLE=$AVAILABLE_GPUS_COUNT
|
||||
else
|
||||
MAX_AVAILABLE=0
|
||||
fi
|
||||
|
||||
# Validate requested GPU count
# Clamps the request to what was detected; when nothing was detected this
# sets NUM_GPUS to 0, which later code treats as "no GPU assignment".
|
||||
if [ $NUM_GPUS -gt $MAX_AVAILABLE ]; then
|
||||
echo "WARNING: Requested $NUM_GPUS GPUs but only $MAX_AVAILABLE available. Using $MAX_AVAILABLE GPUs."
|
||||
NUM_GPUS=$MAX_AVAILABLE
|
||||
fi
|
||||
|
||||
# Build GPU array (0 to NUM_GPUS-1)
|
||||
GPU_ARRAY=()
|
||||
for ((i=0; i<NUM_GPUS; i++)); do
|
||||
GPU_ARRAY+=($i)
|
||||
done
|
||||
|
||||
# Note: this line is also printed when NUM_GPUS was clamped to 0 above.
echo "Using $NUM_GPUS GPU(s): ${GPU_ARRAY[*]}"
|
||||
else
|
||||
echo "No GPU assignment specified, using default GPU behavior"
|
||||
GPU_ARRAY=()
|
||||
fi
|
||||
|
||||
# Configuration
|
||||
OUTPUT_DIR="generated_datasets"
|
||||
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
|
||||
# Get configuration mode from command line argument (default: full)
|
||||
CONFIG_MODE="${1:-full}" # Configuration mode: 'small', 'half' or 'full'
|
||||
|
||||
# Colors
|
||||
RED='\033[0;31m'
|
||||
@@ -128,7 +202,8 @@ rocm-smi --showdriverversion || true
|
||||
echo ""
|
||||
echo "Step 2: Running 2D/3D models and capturing MIOpen commands"
|
||||
echo "-----------------------------------------"
|
||||
|
||||
echo "Using up to $MAX_PARALLEL_JOBS parallel jobs"
|
||||
echo ""
|
||||
|
||||
# Process 2D models from CSV configuration file
|
||||
echo "Processing 2D models from $OUTPUT_DIR/model_configs_2d.csv..."
|
||||
@@ -140,6 +215,11 @@ CURRENT_CONFIG=0
|
||||
echo "Total configurations to process: $TOTAL_CONFIGS"
|
||||
echo ""
|
||||
|
||||
# Array to track background job PIDs
|
||||
declare -a job_pids=()
|
||||
# Counter for round-robin GPU assignment
|
||||
GPU_COUNTER=0
|
||||
|
||||
# Read 2D configurations from CSV (skip comments and header)
|
||||
while IFS=',' read -r config_name model batch_size channels height width precision; do
|
||||
# Skip comments and empty lines
|
||||
@@ -150,21 +230,57 @@ while IFS=',' read -r config_name model batch_size channels height width precisi
|
||||
# Increment counter
|
||||
CURRENT_CONFIG=$((CURRENT_CONFIG + 1))
|
||||
|
||||
|
||||
# Build configuration command
|
||||
CONFIG="--model $model --batch-size $batch_size --channels $channels --height $height --width $width --precision $precision"
|
||||
CONFIG_NAME="$config_name"
|
||||
|
||||
echo -e "${GREEN}[${CURRENT_CONFIG}/${TOTAL_CONFIGS}]${NC} ${CYAN}2D${NC} ${YELLOW}$CONFIG_NAME${NC}"
|
||||
# Assign GPU in round-robin fashion if GPUs are specified
|
||||
if [ $NUM_GPUS -gt 0 ]; then
|
||||
GPU_ID=${GPU_ARRAY[$((GPU_COUNTER % NUM_GPUS))]}
|
||||
GPU_COUNTER=$((GPU_COUNTER + 1))
|
||||
echo -e "${GREEN}[${CURRENT_CONFIG}/${TOTAL_CONFIGS}]${NC} ${CYAN}2D${NC} ${YELLOW}$CONFIG_NAME${NC} ${PURPLE}[GPU ${GPU_ID}]${NC} - Starting in background"
|
||||
else
|
||||
GPU_ID=""
|
||||
echo -e "${GREEN}[${CURRENT_CONFIG}/${TOTAL_CONFIGS}]${NC} ${CYAN}2D${NC} ${YELLOW}$CONFIG_NAME${NC} - Starting in background"
|
||||
fi
|
||||
|
||||
# Actual run with logging (suppress stdout, only capture stderr with MIOpen commands)
|
||||
MIOPEN_ENABLE_LOGGING_CMD=1 $PYTHON_CMD run_model_with_miopen.py \
|
||||
--model $model --batch-size $batch_size --channels $channels --height $height --width $width --precision $precision \
|
||||
> /dev/null 2>> $OUTPUT_DIR/${model}_miopen_log_2d.txt || true
|
||||
|
||||
# Run in background
# The subshell isolates the HIP_VISIBLE_DEVICES export so it applies to this
# job only and does not leak into the parent shell or sibling jobs.
|
||||
(
|
||||
# Set HIP_VISIBLE_DEVICES if GPU was assigned
|
||||
if [ -n "$GPU_ID" ]; then
|
||||
export HIP_VISIBLE_DEVICES=$GPU_ID
|
||||
fi
|
||||
|
||||
# stdout is discarded; stderr is appended to the per-model log. `|| true`
# makes a failing model run non-fatal for the job.
# NOTE(review): the log file name depends only on $model, so parallel jobs for
# the same model append to the same file concurrently -- confirm the later
# parsing step tolerates interleaved lines.
MIOPEN_ENABLE_LOGGING_CMD=1 $PYTHON_CMD run_model_with_miopen.py \
|
||||
--model $model --batch-size $batch_size --channels $channels --height $height --width $width --precision $precision \
|
||||
> /dev/null 2>> $OUTPUT_DIR/${model}_miopen_log_2d.txt || true
|
||||
echo -e "${GREEN}[DONE]${NC} ${CYAN}2D${NC} ${YELLOW}$CONFIG_NAME${NC}"
|
||||
) &
|
||||
|
||||
# Remember the PID of the subshell that was just started.
job_pids+=($!)
|
||||
|
||||
# Limit number of parallel jobs
# Once the number of tracked jobs reaches MAX_PARALLEL_JOBS, block until one
# finishes before the loop launches another.
|
||||
if [ ${#job_pids[@]} -ge $MAX_PARALLEL_JOBS ]; then
|
||||
# Wait for any job to complete
# (`wait -n` returns as soon as any one background job exits.)
|
||||
wait -n
|
||||
# Remove completed jobs from array
# `kill -0` sends no signal; it only tests whether the PID still exists.
|
||||
for i in "${!job_pids[@]}"; do
|
||||
if ! kill -0 "${job_pids[$i]}" 2>/dev/null; then
|
||||
unset 'job_pids[$i]'
|
||||
fi
|
||||
done
|
||||
job_pids=("${job_pids[@]}") # Re-index array
|
||||
fi
|
||||
|
||||
done < $OUTPUT_DIR/model_configs_2d.csv
|
||||
|
||||
# Wait for all remaining 2D jobs to complete
|
||||
echo "Waiting for remaining 2D jobs to complete..."
|
||||
wait
|
||||
|
||||
echo "All 2D models processed!"
|
||||
echo ""
|
||||
|
||||
# Process 3D models from CSV configuration file
|
||||
echo "Processing 3D models from $OUTPUT_DIR/model_configs_3d.csv..."
|
||||
|
||||
@@ -175,6 +291,10 @@ CURRENT_3D_CONFIG=0
|
||||
echo "Total 3D configurations to process: $TOTAL_3D_CONFIGS"
|
||||
echo ""
|
||||
|
||||
# Reset job tracking array
|
||||
declare -a job_pids=()
|
||||
# GPU counter continues from 2D models for round-robin assignment
|
||||
|
||||
# Read 3D configurations from CSV (skip comments and header)
|
||||
while IFS=',' read -r config_name model batch_size channels temporal_size height width precision; do
|
||||
# Skip comments and empty lines
|
||||
@@ -185,21 +305,59 @@ while IFS=',' read -r config_name model batch_size channels temporal_size height
|
||||
# Increment counter
|
||||
CURRENT_3D_CONFIG=$((CURRENT_3D_CONFIG + 1))
|
||||
|
||||
|
||||
# Build configuration command for 3D models
|
||||
CONFIG="--model $model --batch-size $batch_size --channels $channels --temporal-size $temporal_size --height $height --width $width --precision $precision"
|
||||
CONFIG_NAME="$config_name"
|
||||
|
||||
echo -e "${GREEN}[${CURRENT_3D_CONFIG}/${TOTAL_3D_CONFIGS}]${NC} ${CYAN}3D${NC} ${YELLOW}$CONFIG_NAME${NC}"
|
||||
# Assign GPU in round-robin fashion if GPUs are specified
|
||||
if [ $NUM_GPUS -gt 0 ]; then
|
||||
GPU_ID=${GPU_ARRAY[$((GPU_COUNTER % NUM_GPUS))]}
|
||||
GPU_COUNTER=$((GPU_COUNTER + 1))
|
||||
echo -e "${GREEN}[${CURRENT_3D_CONFIG}/${TOTAL_3D_CONFIGS}]${NC} ${CYAN}3D${NC} ${YELLOW}$CONFIG_NAME${NC} ${PURPLE}[GPU ${GPU_ID}]${NC} - Starting in background"
|
||||
else
|
||||
GPU_ID=""
|
||||
echo -e "${GREEN}[${CURRENT_3D_CONFIG}/${TOTAL_3D_CONFIGS}]${NC} ${CYAN}3D${NC} ${YELLOW}$CONFIG_NAME${NC} - Starting in background"
|
||||
fi
|
||||
|
||||
# Run in background
# The subshell isolates the HIP_VISIBLE_DEVICES export so it applies to this
# job only and does not leak into the parent shell or sibling jobs.
|
||||
(
|
||||
# Set HIP_VISIBLE_DEVICES if GPU was assigned
|
||||
if [ -n "$GPU_ID" ]; then
|
||||
export HIP_VISIBLE_DEVICES=$GPU_ID
|
||||
fi
|
||||
|
||||
# stdout is discarded; stderr is appended to the per-model 3D log. `|| true`
# makes a failing model run non-fatal for the job.
# NOTE(review): the log file name depends only on $model, so parallel jobs for
# the same model append to the same file concurrently -- confirm the later
# parsing step tolerates interleaved lines.
MIOPEN_ENABLE_LOGGING_CMD=1 $PYTHON_CMD run_model_with_miopen.py \
|
||||
--model $model --batch-size $batch_size --channels $channels --temporal-size $temporal_size --height $height --width $width --precision $precision \
|
||||
> /dev/null 2>> $OUTPUT_DIR/${model}_miopen_log_3d.txt || true
|
||||
echo -e "${GREEN}[DONE]${NC} ${CYAN}3D${NC} ${YELLOW}$CONFIG_NAME${NC}"
|
||||
) &
|
||||
|
||||
# Actual run with logging (suppress stdout, only capture stderr with MIOpen commands)
|
||||
MIOPEN_ENABLE_LOGGING_CMD=1 $PYTHON_CMD run_model_with_miopen.py \
|
||||
--model $model --batch-size $batch_size --channels $channels --temporal-size $temporal_size --height $height --width $width --precision $precision \
|
||||
> /dev/null 2>> $OUTPUT_DIR/${model}_miopen_log_3d.txt || true
|
||||
job_pids+=($!)
|
||||
|
||||
# Limit number of parallel jobs
# Once the number of tracked jobs reaches MAX_PARALLEL_JOBS, block until one
# finishes before the loop launches another.
|
||||
if [ ${#job_pids[@]} -ge $MAX_PARALLEL_JOBS ]; then
|
||||
# Wait for any job to complete
# (`wait -n` returns as soon as any one background job exits.)
|
||||
wait -n
|
||||
# Remove completed jobs from array
# `kill -0` sends no signal; it only tests whether the PID still exists.
|
||||
for i in "${!job_pids[@]}"; do
|
||||
if ! kill -0 "${job_pids[$i]}" 2>/dev/null; then
|
||||
unset 'job_pids[$i]'
|
||||
fi
|
||||
done
|
||||
job_pids=("${job_pids[@]}") # Re-index array
|
||||
fi
|
||||
|
||||
done < $OUTPUT_DIR/model_configs_3d.csv
|
||||
|
||||
# Wait for all remaining 3D jobs to complete
|
||||
echo "Waiting for remaining 3D jobs to complete..."
|
||||
wait
|
||||
|
||||
echo "All 3D models processed!"
|
||||
echo ""
|
||||
|
||||
# Disable trap on successful completion
|
||||
trap - SIGINT SIGTERM EXIT
|
||||
|
||||
echo ""
|
||||
echo "Step 3: Converting MIOpen commands to CSV test cases"
|
||||
@@ -311,7 +469,7 @@ if [ $COUNT_3D -gt 0 ]; then
|
||||
fi
|
||||
echo " - Intermediate files in: $OUTPUT_DIR/"
|
||||
echo ""
|
||||
echo "To use these datasets:"
|
||||
echo " 1. Build the test: cd ../script && make -j64 test_grouped_convnd_fwd_dataset_xdl"
|
||||
echo " 2. Run the test: ./bin/test_grouped_convnd_fwd_dataset_xdl"
|
||||
echo "To use these datasets for direction (bwd_data, bwd_weight, or fwd):"
|
||||
echo " 1. Build the test: cd ../script && make -j64 test_grouped_convnd_<direction>_dataset_xdl"
|
||||
echo " 2. Run the test: ./bin/test_grouped_convnd_<direction>_dataset_xdl"
|
||||
echo ""
|
||||
|
||||
1187
test_data/gtest_parallel.py
Normal file
1187
test_data/gtest_parallel.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -81,42 +81,3 @@ inline KernelTraits extract_traits_from_name(const std::string& kernel_name)
|
||||
|
||||
return traits;
|
||||
}
|
||||
|
||||
// Rearranges a 2-D host tensor into a warp-tile-blocked order.
//
// The input elements are copied, in iteration order, into a 5-D tensor of
// lengths {n / N_Warp_Tile, N_Warp_Tile, k / K_Warp_Tile, divisor,
// K_Warp_Tile / divisor}, and that tensor is then permuted with the dimension
// order {0, 2, 3, 1, 4}. The permuted tensor is returned.
//
// t           - 2-D tensor; lengths[0] is read as k, lengths[1] as n.
// N_Warp_Tile - warp tile size along n; also selects divisor (2 when it is
//               32, otherwise 4).
// K_Warp_Tile - warp tile size along k.
//
// NOTE(review): all quotients use integer division, so this presumably expects
// n % N_Warp_Tile == 0, k % K_Warp_Tile == 0 and K_Warp_Tile % divisor == 0 --
// confirm with callers.
template <typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t,
|
||||
ck_tile::index_t N_Warp_Tile,
|
||||
ck_tile::index_t K_Warp_Tile)
|
||||
{
|
||||
// Only rank-2 inputs are supported (checked in debug builds only).
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
int divisor = N_Warp_Tile == 32 ? 2 : 4;
|
||||
ck_tile::HostTensor<T> t_view(
|
||||
{n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor});
|
||||
// Reinterpret the same element sequence under the 5-D shape.
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
|
||||
// Rearranges a 2-D host tensor into a warp-tile-blocked order that also
// splits the n dimension by block tile, warp count and per-warp repeat.
//
// The input elements are copied, in iteration order, into a 7-D tensor of
// lengths {n / N_Tile, N_Warp, N_Warp_Tile, NRepeat, k / K_Warp_Tile, divisor,
// K_Warp_Tile / divisor}, where NRepeat = N_Tile / N_Warp_Tile / N_Warp. That
// tensor is then permuted with the dimension order {0, 3, 1, 4, 5, 2, 6} and
// returned.
//
// t           - 2-D tensor; lengths[0] is read as k, lengths[1] as n.
// N_Warp_Tile - warp tile size along n; also selects divisor (2 when it is
//               32, otherwise 4).
// K_Warp_Tile - warp tile size along k.
// N_Tile      - block tile size along n.
// N_Warp      - number of warps along n.
//
// NOTE(review): all quotients use integer division, so the sizes are
// presumably expected to divide evenly -- confirm with callers.
template <typename T>
|
||||
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t,
|
||||
ck_tile::index_t N_Warp_Tile,
|
||||
ck_tile::index_t K_Warp_Tile,
|
||||
ck_tile::index_t N_Tile,
|
||||
ck_tile::index_t N_Warp)
|
||||
{
|
||||
// Only rank-2 inputs are supported (checked in debug builds only).
assert(t.get_lengths().size() == 2);
|
||||
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
int divisor = N_Warp_Tile == 32 ? 2 : 4;
|
||||
int NRepeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
ck_tile::HostTensor<T> t_view({n_ / N_Tile,
|
||||
N_Warp,
|
||||
N_Warp_Tile,
|
||||
NRepeat,
|
||||
k_ / K_Warp_Tile,
|
||||
divisor,
|
||||
K_Warp_Tile / divisor});
|
||||
// Reinterpret the same element sequence under the 7-D shape.
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
|
||||
}
|
||||
|
||||
@@ -111,21 +111,30 @@ class GemmProfiler
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
// Local aggregate bundling the four tile parameters that are passed to the
// ck_tile shuffle helpers as a single argument.
// NOTE(review): the helpers presumably read these members by name, so the
// member names are part of the contract -- confirm against the ck_tile
// shuffle_b / shuffle_b_permuteN declarations.
struct GemmConfig
|
||||
{
|
||||
// Warp tile size along N.
ck_tile::index_t N_Warp_Tile;
|
||||
// Warp tile size along K.
ck_tile::index_t K_Warp_Tile;
|
||||
// Block tile size along N.
ck_tile::index_t N_Tile;
|
||||
// Number of warps along N.
ck_tile::index_t N_Warp;
|
||||
};
|
||||
|
||||
for(const auto& callable : callables)
|
||||
{
|
||||
ck_tile::index_t N_Warp_Tile = std::get<1>(config.warp_tile_dims);
|
||||
ck_tile::index_t K_Warp_Tile = std::get<2>(config.warp_tile_dims);
|
||||
ck_tile::index_t N_Tile = std::get<1>(config.tile_dims);
|
||||
ck_tile::index_t N_Warp = std::get<1>(config.warp_dims);
|
||||
GemmConfig gemmConfig = {};
|
||||
gemmConfig.N_Warp_Tile = std::get<1>(config.warp_tile_dims);
|
||||
gemmConfig.K_Warp_Tile = std::get<2>(config.warp_tile_dims);
|
||||
gemmConfig.N_Tile = std::get<1>(config.tile_dims);
|
||||
gemmConfig.N_Warp = std::get<1>(config.warp_dims);
|
||||
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host = [&]() {
|
||||
if(config.permuteN)
|
||||
{
|
||||
return shuffle_b_permuteN(b_k_n, N_Warp_Tile, K_Warp_Tile, N_Tile, N_Warp);
|
||||
return ck_tile::shuffle_b_permuteN(b_k_n, gemmConfig);
|
||||
}
|
||||
else
|
||||
{
|
||||
return shuffle_b(b_k_n, N_Warp_Tile, K_Warp_Tile);
|
||||
return ck_tile::shuffle_b(b_k_n, gemmConfig);
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user