Merge commit '47e2ed838e3547bba1b48d3f559f20f46fd07b87' into develop

This commit is contained in:
assistant-librarian[bot]
2025-11-20 02:43:03 +00:00
parent ca48bf3b98
commit 809c1ead72
183 changed files with 987 additions and 863 deletions

View File

@@ -3,7 +3,7 @@ if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
if(GPU_TARGETS MATCHES "gfx94|gfx95")
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
# Split into three separate test executables for faster parallel compilation
add_gtest_executable(test_ck_tile_grouped_gemm_quant_rowcol test_grouped_gemm_quant_rowcol.cpp)
target_compile_options(test_ck_tile_grouped_gemm_quant_rowcol PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})

View File

@@ -74,6 +74,17 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
M_Warp_Tile>();
};
/// Kernel tile parameters used by this test when CK_TILE_USE_WMMA is enabled.
///
/// Inherits every setting from GroupedGemKernelParam_Mfma and redeclares only
/// the six tile-size constants below; a member declared here hides any
/// same-named member of the base, so all remaining parameters come from the
/// MFMA configuration unchanged.
///
/// The members are `static constexpr` (implicitly inline since C++17) so that
/// odr-using them, e.g. binding one to a `const&` parameter, does not require
/// an out-of-class definition.
struct GroupedGemKernelParam_Wmma : public GroupedGemKernelParam_Mfma
{
    // Tile extents along M, N and K.
    // NOTE(review): names suggest these are the per-workgroup tile sizes - confirm.
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 128;
    static constexpr ck_tile::index_t K_Tile = 128;

    // Warp-level tile extents along M, N and K.
    // NOTE(review): presumably chosen to match a 16x16x16 WMMA instruction shape - confirm.
    static constexpr ck_tile::index_t M_Warp_Tile = 16;
    static constexpr ck_tile::index_t N_Warp_Tile = 16;
    static constexpr ck_tile::index_t K_Warp_Tile = 16;
};
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
{
@@ -373,8 +384,13 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
if constexpr(PreshuffleB && QuantType == ck_tile::QuantType::BQuantGrouped)
{
#if CK_TILE_USE_WMMA
auto b_shuffle_host =
ck_tile::shuffle_b<GroupedGemKernelParam_Wmma>(b_k_n_tensors[i]);
#else
auto b_shuffle_host =
ck_tile::shuffle_b<GroupedGemKernelParam_Mfma>(b_k_n_tensors[i]);
#endif
b_k_n_dev_buf[i]->ToDevice(b_shuffle_host.data());
}
else
@@ -446,8 +462,13 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg),
hipMemcpyHostToDevice,
stream.stream_id_));
#if CK_TILE_USE_WMMA
invoke_grouped_gemm_persistent<GroupedGemKernelParam_Wmma, ALayout, BLayout, CLayout>(
stream, group_count, kargs_ptr);
#else
invoke_grouped_gemm_persistent<GroupedGemKernelParam_Mfma, ALayout, BLayout, CLayout>(
stream, group_count, kargs_ptr);
#endif
}
else
{