Merge commit '1b1c46e508c1fd40a03f54114b6b78629032fb4f' into develop

This commit is contained in:
assistant-librarian[bot]
2025-11-11 17:12:49 +00:00
parent 0b000816a4
commit db12c41b56
65 changed files with 845 additions and 455 deletions

View File

@@ -13,7 +13,7 @@ using CDataType = ck::bhalf_t;
using ComputeTypeA = ck::f8_t;
using ComputeTypeB = ck::f8_t;
using ALayout = Row;
using ALayout = Col;
using BLayout = Col;
using CLayout = Row;
@@ -30,13 +30,13 @@ using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuf
PassThrough, PassThrough, PassThrough, GemmDefault,
128,
128, 64, 64,
8, 8,
16, 16, // AK1, BK1
16, 16,
4, 2,
S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 4, 16, 0,
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
2, 16, 16, 0,
1, 1, S<1, 32, 1, 4>, 8,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1,
ComputeTypeA, ComputeTypeB>;

View File

@@ -5,7 +5,7 @@ endif()
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
add_executable(tile_example_gemm_quant_basic EXCLUDE_FROM_ALL gemm_quant_basic.cpp)
target_compile_options(tile_example_gemm_quant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
else()

View File

@@ -419,6 +419,10 @@ int dispatch_group_size_ct(int m, int n, int k, F&& f)
int main(int argc, char* argv[])
{
#if CK_TILE_USE_WMMA
return !run_gemm_example<GemmConfigBQuantPrefill_Wmma>(argc, argv);
#else
// Use non-preshuffled GemmConfig for 2D block scale support
return !run_gemm_example<GemmConfigBQuantPrefill>(argc, argv);
#endif
}

View File

@@ -216,6 +216,14 @@ struct GemmConfigBQuantPrefill : public GemmConfigBase
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
};
template <typename PrecType>
struct GemmConfigBQuantPrefill_Wmma : public GemmConfigBQuantPrefill<PrecType>
{
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
};
template <typename ADataType_,
typename BDataType_ = ADataType_,
typename CDataType_ = ADataType_,