[CK_TILE] Fix gemm_quant (#3186)

[ROCm/composable_kernel commit: 1b1c46e508]
This commit is contained in:
linqunAMD
2025-11-12 00:23:57 +08:00
committed by GitHub
parent c1b5372db3
commit 13cf0bd17f
13 changed files with 135 additions and 49 deletions

View File

@@ -5,7 +5,7 @@ endif()
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
add_executable(tile_example_gemm_quant_basic EXCLUDE_FROM_ALL gemm_quant_basic.cpp)
target_compile_options(tile_example_gemm_quant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
else()

View File

@@ -419,6 +419,10 @@ int dispatch_group_size_ct(int m, int n, int k, F&& f)
int main(int argc, char* argv[])
{
#if CK_TILE_USE_WMMA
return !run_gemm_example<GemmConfigBQuantPrefill_Wmma>(argc, argv);
#else
// Use non-preshuffled GemmConfig for 2D block scale support
return !run_gemm_example<GemmConfigBQuantPrefill>(argc, argv);
#endif
}

View File

@@ -216,6 +216,14 @@ struct GemmConfigBQuantPrefill : public GemmConfigBase
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
};
template <typename PrecType>
struct GemmConfigBQuantPrefill_Wmma : public GemmConfigBQuantPrefill<PrecType>
{
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
};
template <typename ADataType_,
typename BDataType_ = ADataType_,
typename CDataType_ = ADataType_,