Merge commit '7c44a763fa9719ba1b18d3b6a37b6138c78d97fd' into develop

This commit is contained in:
assistant-librarian[bot]
2025-08-01 17:13:48 +00:00
parent da195f27ce
commit 29f8d7250c
3 changed files with 30 additions and 12 deletions

View File

@@ -265,17 +265,25 @@ using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl<WarpGemmAtrributeMfmaIter
WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfma_f32_16x16x128_fp8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfma<
WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_fp8<WGAttrCtlEnum::Default_>>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_fp8_fp8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_fp8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
using WarpGemmMfma_f32_16x16x128_fp8_bf8 = WarpGemmImpl<WarpGemmAtrributeMfma<
WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_bf8<WGAttrCtlEnum::Default_>>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_fp8_bf8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_bf8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
using WarpGemmMfma_f32_16x16x128_bf8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfma<
WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_fp8<WGAttrCtlEnum::Default_>>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_bf8_fp8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_fp8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
using WarpGemmMfma_f32_16x16x128_bf8_bf8 = WarpGemmImpl<WarpGemmAtrributeMfma<
WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8<WGAttrCtlEnum::Default_>>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_bf8_bf8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_fp8_fp8 = WarpGemmImpl<

View File

@@ -92,10 +92,10 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float,
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<>; };
@@ -110,6 +110,14 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float,
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<WGAttrNumAccessEnum::Quad>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 128, false, false, false, WGAttrNumAccessEnum::Quad> {
using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8<WGAttrNumAccessEnum::Quad>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 16, 16, 128, false, false, false, WGAttrNumAccessEnum::Quad> {
using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8<WGAttrNumAccessEnum::Quad>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 16, 16, 128, false, false, false, WGAttrNumAccessEnum::Quad> {
using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8<WGAttrNumAccessEnum::Quad>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 128, false, false, false, WGAttrNumAccessEnum::Quad> {
using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8<WGAttrNumAccessEnum::Quad>; };
// int8
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 16, false> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8; };

View File

@@ -375,6 +375,8 @@ int run_gemm_combinations(std::string const& data_type)
{
is_success =
run_gemm_test<GemmConfigComputeV3>(ARG_COUNT, argv) && is_success;
is_success =
run_gemm_test<GemmConfigComputeV3_2>(ARG_COUNT, argv) && is_success;
}
catch(const ArgumentsNotSupportedException& e)
{