This commit is contained in:
Ding, Yi
2026-03-11 23:03:20 -04:00
commit e6cd3f1e3f
6330 changed files with 1132789 additions and 0 deletions

View File

@@ -0,0 +1,504 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp"
namespace ck_tile {
// fp32
using WarpGemmMfmaF32F32F32M16N16K4 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplF32F32F32M16N16K4<WGAttrCtlEnum::Default_>>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF32F32F32M16N16K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplF32F32F32M16N16K4<WGAttrCtlEnum::Default_>,
4,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF32F32F32M16N16K8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplF32F32F32M16N16K4<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF32F32F32M32N32K8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplF32F32F32M32N32K2<WGAttrCtlEnum::Default_>,
4,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF32F32F32M32N32K4 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplF32F32F32M32N32K2<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplF32F32F32M16N16K4<WGAttrCtlEnum::Default_>,
4,
AttrNumAccess>>;
// fp16
using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
#endif
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
#endif
using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAttributeMfmaIterateK_SwizzleA<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
1>>;
using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl<WarpGemmAttributeMfmaIterateK_SwizzleA<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
#endif
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
#endif
#if defined(__gfx950__)
using WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>,
1>>;
using WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
1>>;
#endif
using WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
#if defined(__gfx950__)
using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
#else
using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
#endif
using WarpGemmMfmaF16F16F32M4N64K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplF16F16F32M4N64K4<WGAttrCtlEnum::Default_>,
4>>;
using WarpGemmMfmaF16F16F32M64N4K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplF16F16F32M64N4K4<WGAttrCtlEnum::Default_>,
4>>;
// fp16 2:4 structured sparsity
using WarpGemmSmfmacF16F16F32M32N32K16 = WarpGemmSmfmacImpl<WarpGemmAttributeSmfmac<
WarpGemmAttributeSmfmacImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
using WarpGemmSmfmacF16F16F32M16N16K32 = WarpGemmSmfmacImpl<WarpGemmAttributeSmfmac<
WarpGemmAttributeSmfmacImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>>>;
// bf16
using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaBf16Bf16F32M16N16K16 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
#endif
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single,
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
AttrNumAccessA,
AttrNumAccessB>>;
template <WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single,
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
using WarpGemmMfmaBf16Bf16F32M16N16K64 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
2,
AttrNumAccessA,
AttrNumAccessB>>;
#else
template <WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single,
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
2,
AttrNumAccessA>>;
template <WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single,
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
using WarpGemmMfmaBf16Bf16F32M16N16K64 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
4,
AttrNumAccessA,
AttrNumAccessB>>;
#endif
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAttributeMfmaIterateK_SwizzleA<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
1>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA =
WarpGemmImpl<WarpGemmAttributeMfmaIterateK_SwizzleA<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
#endif
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
#endif
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
#if defined(__gfx950__)
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
#else
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
#endif
using WarpGemmMfmaBf16Bf16F32M4N64K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4<WGAttrCtlEnum::Default_>,
4>>;
using WarpGemmMfmaBf16Bf16F32M64N4K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4<WGAttrCtlEnum::Default_>,
4>>;
// fp8
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
2>>;
#endif
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>,
2>>;
#endif
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>,
2>>;
#endif
using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_16x16x32_fp8_bf8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x32_fp8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfma_f32_16x16x32_fp8_fp8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_16x16x32_fp8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_16x16x32_bf8_bf8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_16x16x32_bf8_bf8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfma_f32_16x16x64_fp8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfma_f32_16x16x64_bf8_bf8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>,
2>>;
template <typename A, typename B, WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_f8f6f4 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<A, B>, AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_fp8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<fp8_t, fp8_t>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_fp8_bf8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<fp8_t, bf8_t>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_bf8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<bf8_t, fp8_t>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_bf8_bf8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<bf8_t, bf8_t>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_fp4_fp4_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<pk_fp4_t, pk_fp4_t>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_fp8_fp8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_fp8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_fp8_bf8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_bf8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_bf8_fp8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_fp8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_bf8_bf8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_bf8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_fp8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_fp8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_fp8_bf8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_bf8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_bf8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_fp8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_bf8_bf8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_bf8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_fp4_fp4_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x64_f8f6f4<pk_fp4_t, pk_fp4_t>,
AttrNumAccess>>;
using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
template <index_t swizzle_factor = 2>
using WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
2,
swizzle_factor>>;
// int8
using WarpGemmMfma_i32_32x32x16_i8_i8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_i32_16x16x32_i8_i8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>>>;
} // namespace ck_tile

View File

@@ -0,0 +1,928 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
namespace ck_tile {
// Number of groups of consecutive elements to fill in a ABKLane
enum class WGAttrNumAccessEnum
{
Single = 1,
Double = 2,
Quad = 4,
Invalid = -1
};
template <WGAttrNumAccessEnum AttrNumAccess>
struct get_wgattr_num_access
{
private:
static constexpr index_t getAccesses()
{
if constexpr(AttrNumAccess == WGAttrNumAccessEnum::Single)
{
return 1;
}
else if constexpr(AttrNumAccess == WGAttrNumAccessEnum::Double)
{
return 2;
}
else if constexpr(AttrNumAccess == WGAttrNumAccessEnum::Quad)
{
return 4;
}
else
{
static_assert(false, "unsupported AttrNumAccess");
return 0;
}
}
public:
static constexpr auto value = getAccesses();
};
template <typename WarpGemmAttributeMfmaImpl_,
WGAttrNumAccessEnum AttrNumAccessA_ = WGAttrNumAccessEnum::Single,
WGAttrNumAccessEnum AttrNumAccessB_ = AttrNumAccessA_>
struct WarpGemmAttributeMfma
{
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
static constexpr auto AttrNumAccessA = AttrNumAccessA_;
static constexpr auto AttrNumAccessAV = get_wgattr_num_access<AttrNumAccessA>::value;
static constexpr auto AttrNumAccessB = AttrNumAccessB_;
static constexpr auto AttrNumAccessBV = get_wgattr_num_access<AttrNumAccessB>::value;
static constexpr bool UsePackNumAccess = AttrNumAccessA != AttrNumAccessB;
using ADataType = typename Impl::ADataType;
using BDataType = typename Impl::BDataType;
using CDataType = typename Impl::CDataType;
using AVecType = typename Impl::AVecType;
using BVecType = typename Impl::BVecType;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kM;
static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK;
static constexpr index_t kKPerThread = Impl::kABKPerLane;
static constexpr index_t kCMLane = Impl::kCMLane;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
template <index_t kMNLane, index_t AttrNumAccessV_>
static constexpr auto get_warp_dstr_encoding()
{
static_assert(kKPerThread % AttrNumAccessV_ == 0,
"kKPerThread must be divisible by NumAccess");
if constexpr(AttrNumAccessV_ == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<kMNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
}
else
{
// AttrNumAccess splits the kABKPerLane
// We can split them but still have them contiguous (packed) or have them interleaved.
// The reason to split the dimension but still have it packed is to match load transpose
// encoding when A and B use different AttrNumAccess (they have different types in LDS)
// Example
// A: 16bit, B: 8bit
// Load transpose B: lane0 -> K=0..7 (only 1 instruction)
// Load transpose A: lane0 -> K=0..3 first instruction, K=4..7 second instruction
// In this way the data in register are consistent between A and B
if constexpr(UsePackNumAccess)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<kMNLane>,
sequence<Impl::kABKLane,
AttrNumAccessV_,
Impl::kABKPerLane / AttrNumAccessV_>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2, 2>,
sequence<1, 2>>{};
}
else
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<kMNLane>,
sequence<AttrNumAccessV_,
Impl::kABKLane,
Impl::kABKPerLane / AttrNumAccessV_>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>{};
}
}
}
using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane, AttrNumAccessAV>());
using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane, AttrNumAccessBV>());
using CWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
sequence<Impl::kCNLane>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 0>>,
sequence<1, 1>,
sequence<0, 2>>;
// c_vec += a_vec * b_vec
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
Impl{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
}
// c_vec += a_vec * b_vec
template <index_t opselA, index_t opselB, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const int32_t& a_scale,
const BVecType& b_vec,
const int32_t& b_scale,
bool_constant<post_nop_> = {}) const
{
Impl{}.template operator()<opselA, opselB>(
c_vec, a_vec, a_scale, b_vec, b_scale, bool_constant<post_nop_>{});
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
return Impl{}(a_vec, b_vec);
}
// c_vec = a_vec * b_vec
template <index_t opselA, index_t opselB>
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec,
const int32_t& a_scale,
const BVecType& b_vec,
const int32_t& b_scale) const
{
return Impl{}.template operator()<opselA, opselB>(a_vec, a_scale, b_vec, b_scale);
}
};
template <typename WarpGemmAttributeMfmaImpl_,
index_t kKIter,
WGAttrNumAccessEnum AttrNumAccessA_ = WGAttrNumAccessEnum::Single,
WGAttrNumAccessEnum AttrNumAccessB_ = AttrNumAccessA_>
struct WarpGemmAttributeMfmaIterateK
{
static_assert(kKIter > 0, "wrong!");
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
static constexpr auto AttrNumAccessA = AttrNumAccessA_;
static constexpr auto AttrNumAccessAV = get_wgattr_num_access<AttrNumAccessA>::value;
static constexpr auto AttrNumAccessB = AttrNumAccessB_;
static constexpr auto AttrNumAccessBV = get_wgattr_num_access<AttrNumAccessB>::value;
static constexpr bool UsePackNumAccess = AttrNumAccessA != AttrNumAccessB;
using ADataType = typename Impl::ADataType;
using BDataType = typename Impl::BDataType;
using CDataType = typename Impl::CDataType;
using AVecType =
ext_vector_t<ADataType, vector_traits<typename Impl::AVecType>::vector_size * kKIter>;
using BVecType =
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kM;
static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
static constexpr index_t kCMLane = Impl::kCMLane;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
"Multi-block on both M & N directions is not supported");
template <index_t kMNLane, index_t kMNBlock, index_t kNMBlock, index_t AttrNumAccessV_>
CK_TILE_DEVICE static constexpr auto get_warp_dstr_encoding()
{
if constexpr(kMNBlock == 1 && kNMBlock == 1)
{
static_assert(kKPerThread % AttrNumAccessV_ == 0,
"kKPerThread must be divisible by NumAccess");
if constexpr(AttrNumAccessV_ == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<kMNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
}
else
{
if constexpr(UsePackNumAccess)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<kMNLane>,
sequence<Impl::kABKLane,
AttrNumAccessV_,
Impl::kABKPerLane * kKIter / AttrNumAccessV_>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2, 2>,
sequence<1, 2>>{};
}
else
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<kMNLane>,
sequence<AttrNumAccessV_,
Impl::kABKLane,
Impl::kABKPerLane * kKIter / AttrNumAccessV_>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>{};
}
}
}
else if constexpr(kMNBlock == 1 && 1 < kNMBlock)
{
static_assert(AttrNumAccessV_ == 1,
"Multiple access is not supported when using multi-block");
// each M/N blocks share the same data
return tile_distribution_encoding<
sequence<kNMBlock>,
tuple<sequence<kMNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<0, 2, 1>>,
tuple<sequence<0, 0, 0>>,
sequence<2>,
sequence<1>>{};
}
else if constexpr(1 < kMNBlock && kNMBlock == 1)
{
static_assert(AttrNumAccessV_ == 1,
"Multiple access is not supported when using multi-block");
// single block to multi-block thread mapping
return tile_distribution_encoding<
sequence<>,
tuple<sequence<kMNBlock, kMNLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<1, 2, 1>>,
tuple<sequence<0, 0, 1>>,
sequence<2>,
sequence<1>>{};
}
}
CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
{
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
sequence<Impl::kCNLane>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 0>>,
sequence<1, 1>,
sequence<0, 2>>{};
}
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
sequence<Impl::kBNBlock * Impl::kCNLane>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 0>>,
sequence<1, 1>,
sequence<0, 2>>{};
}
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<
sequence<Impl::kCM0PerLane, Impl::kAMBlock * Impl::kCMLane, Impl::kCM1PerLane>,
sequence<Impl::kCNLane>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 0>>,
sequence<1, 1>,
sequence<0, 2>>{};
}
}
using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane,
Impl::kAMBlock,
Impl::kBNBlock,
AttrNumAccessAV>());
using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane,
Impl::kBNBlock,
Impl::kAMBlock,
AttrNumAccessBV>());
using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding());
// c_vec += a_vec * b_vec
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
static_for<0, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
}
template <index_t iKIter, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
number<iKIter>,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
static_assert(iKIter < kKIter);
Impl{}(c_vec,
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
bool_constant<post_nop_>{});
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
constexpr auto I0 = number<0>{};
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
// c = a * b
auto c_vec = Impl{}(
reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
// c += a * b
static_for<1, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
return c_vec;
}
};
template <typename WarpGemmAttributeMfmaImpl_,
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
struct WarpGemmAttributeMfmaTransposedCDistribution
{
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
static constexpr auto AttrNumAccess = AttrNumAccess_;
static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
using ADataType = typename Impl::BDataType;
using BDataType = typename Impl::ADataType;
using CDataType = typename Impl::CDataType;
using AVecType = typename Impl::BVecType;
using BVecType = typename Impl::AVecType;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kN;
static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK;
static constexpr index_t kKPerThread = Impl::kABKPerLane;
static constexpr index_t kCMLane = Impl::kCMLane;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
using AWarpDstrEncoding =
typename WarpGemmAttributeMfma<Impl, AttrNumAccess>::BWarpDstrEncoding;
using BWarpDstrEncoding =
typename WarpGemmAttributeMfma<Impl, AttrNumAccess>::AWarpDstrEncoding;
using CWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kCNLane>,
sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>;
// c_vec += a_vec * b_vec
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
// swap A and B
Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
}
template <index_t opselA, index_t opselB, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const int32_t& a_scale,
const BVecType& b_vec,
const int32_t& b_scale,
bool_constant<post_nop_> = {}) const
{
// swap A and B
Impl{}.template operator()<opselB, opselA>(
c_vec, b_vec, b_scale, a_vec, a_scale, bool_constant<post_nop_>{});
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
// swap A and B
return Impl{}(b_vec, a_vec);
}
};
template <typename WarpGemmAttributeMfmaImpl_, index_t SFactor_ = 2>
struct WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB
{
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
using ADataType = typename Impl::BDataType;
using BDataType = typename Impl::ADataType;
using CDataType = typename Impl::CDataType;
using AVecType = typename Impl::BVecType;
using BVecType = typename Impl::AVecType;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kN;
static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK;
static constexpr index_t kKPerThread = Impl::kABKPerLane;
static constexpr index_t SFactor = SFactor_; // group how many CM1 together
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>;
#if 0
using BWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMLane / (Impl::kABKPerLane * Impl::kABKLane * 2),
Impl::kABKLane,
2,
Impl::kABKPerLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane>>,
tuple<sequence<2, 1, 1, 1, 1>>,
tuple<sequence<0, 0, 2, 1, 3>>,
sequence<2>,
sequence<1>>;
using CWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kCNLane>,
sequence<Impl::kCM0PerLane / 2, Impl::kCMLane, Impl::kCM1PerLane * 2>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>;
#else
// TODO: more test not only 32x32
using BWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
Impl::kCMLane,
SFactor,
Impl::kCM1PerLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane>>,
tuple<sequence<2, 1, 1, 1, 1>>,
tuple<sequence<0, 0, 2, 1, 3>>,
sequence<2>,
sequence<1>>;
using CWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kCNLane>,
sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>;
#endif
template <bool post_nop_ = false>
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
// swap A and B
Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
}
template <index_t opselA, index_t opselB, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const int32_t& a_scale,
const BVecType& b_vec,
const int32_t& b_scale,
bool_constant<post_nop_> = {}) const
{
// swap A and B
Impl{}.template operator()<opselB, opselA>(
c_vec, b_vec, b_scale, a_vec, a_scale, bool_constant<post_nop_>{});
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
// swap A and B
return Impl{}(b_vec, a_vec);
}
};
template <typename WarpGemmAttributeMfmaImpl_,
index_t kKIter,
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution
{
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
static constexpr auto AttrNumAccess = AttrNumAccess_;
// swap A and B
using ADataType = typename Impl::BDataType;
using BDataType = typename Impl::ADataType;
using CDataType = typename Impl::CDataType;
using AVecType =
ext_vector_t<ADataType, vector_traits<typename Impl::AVecType>::vector_size * kKIter>;
using BVecType =
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kN;
static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
"Multi-block on both M & N directions is not supported");
CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
{
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kCNLane>,
sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>{};
}
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNBlock * Impl::kCNLane>,
sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>{};
}
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<
sequence<Impl::kCNLane>,
sequence<Impl::kCM0PerLane, Impl::kAMBlock * Impl::kCMLane, Impl::kCM1PerLane>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>{};
}
}
using AWarpDstrEncoding =
typename WarpGemmAttributeMfmaIterateK<Impl, kKIter, AttrNumAccess>::BWarpDstrEncoding;
using BWarpDstrEncoding =
typename WarpGemmAttributeMfmaIterateK<Impl, kKIter, AttrNumAccess>::AWarpDstrEncoding;
using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding());
// c_vec += a_vec * b_vec
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
static_for<0, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
}
template <index_t iKIter, bool post_nop_ = false>
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
number<iKIter>,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
static_assert(iKIter < kKIter);
// swap A and B, value and type
Impl{}(c_vec,
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
bool_constant<post_nop_>{});
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
constexpr auto I0 = number<0>{};
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
// swap A and B, value and type
auto c_vec = Impl{}(
reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0],
reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0]);
static_for<1, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
return c_vec;
}
};
template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter, index_t SFactor_ = 2>
struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB
{
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
// swap A and B
using ADataType = typename Impl::BDataType;
using BDataType = typename Impl::ADataType;
using CDataType = typename Impl::CDataType;
using AVecType =
ext_vector_t<ADataType, vector_traits<typename Impl::AVecType>::vector_size * kKIter>;
using BVecType =
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kN;
static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
static constexpr index_t SFactor = SFactor_; // group how many CM1 together
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>;
#if 0
using BWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMLane / (Impl::kABKPerLane * Impl::kABKLane * 2),
Impl::kABKLane,
2,
Impl::kABKPerLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<2, 1, 1, 1, 1>>,
tuple<sequence<0, 0, 2, 1, 3>>,
sequence<2>,
sequence<1>>;
using CWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kCNLane>,
sequence<Impl::kCM0PerLane / 2, Impl::kCMLane, Impl::kCM1PerLane * 2>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>;
#else
// TODO: more test not only 32x32
using BWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
Impl::kCMLane,
SFactor,
Impl::kCM1PerLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<2, 1, 1, 1, 1>>,
tuple<sequence<0, 0, 2, 1, 3>>,
sequence<2>,
sequence<1>>;
using CWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kCNLane>,
sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>;
#endif
// c_vec += a_vec * b_vec
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
// swap A and B, value and type
static_for<0, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
}
template <index_t iKIter, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
number<iKIter>,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
static_assert(iKIter < kKIter);
// swap A and B, value and type
Impl{}(c_vec,
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
bool_constant<post_nop_>{});
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
constexpr auto I0 = number<0>{};
// swap A and B, value and type
auto c_vec = Impl{}(
reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0],
reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0]);
static_for<1, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
return c_vec;
}
};
template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter, index_t SFactor_ = 2>
struct WarpGemmAttributeMfmaIterateK_SwizzleA
{
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
using ADataType = typename Impl::ADataType;
using BDataType = typename Impl::BDataType;
using CDataType = typename Impl::CDataType;
using AVecType =
ext_vector_t<ADataType, vector_traits<typename Impl::AVecType>::vector_size * kKIter>;
using BVecType =
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kM;
static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
static constexpr index_t SFactor = SFactor_; // group how many CM1 together
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
Impl::kCMLane,
SFactor,
Impl::kCM1PerLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<2, 1, 1, 1, 1>>,
tuple<sequence<0, 0, 2, 1, 3>>,
sequence<2>,
sequence<1>>;
using BWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>;
using CWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>,
sequence<Impl::kCNLane>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 0>>,
sequence<1, 1>,
sequence<0, 2>>;
// c_vec += a_vec * b_vec
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
static_for<0, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
}
template <index_t iKIter, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
number<iKIter>,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
static_assert(iKIter < kKIter);
Impl{}(c_vec,
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
bool_constant<post_nop_>{});
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
constexpr auto I0 = number<0>{};
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
auto c_vec = Impl{}(
reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
static_for<1, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
return c_vec;
}
};
} // namespace ck_tile

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,85 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp"
namespace ck_tile {
/**
* @brief Class describing structured sparsity mfma instructions.
*
* @paragraph Overview "Overview"
* Currently only 2:4 structured sparsity is supported, which is based on requirement that in every
* groups of four continuous elements there are at most two non-zero, which results in processing
* only half of elements in smfmac instruction. Because of structured sparsity A vector in smfmac
* instruction will be smaller than B vector by the factor of CompressionRatio. The indexes of
* non-zero elements are stored in `index` which is an additional parameter to assembly instruction.
* Every pair of two bit indexes are containing information about which two elements in current
* group of 4 values are non-zero and should be used inside smfmac instruction. Structured sparsity
* format is supported only for A matrix for now.
*/
template <typename WarpGemmAttributeSmfmacImpl_>
struct WarpGemmAttributeSmfmac
{
using Impl = remove_cvref_t<WarpGemmAttributeSmfmacImpl_>;
using ADataType = typename Impl::ADataType;
using BDataType = typename Impl::BDataType;
using IdxDataType = typename Impl::IdxDataType;
using CDataType = typename Impl::CDataType;
using AVecType = typename Impl::AVecType;
using BVecType = typename Impl::BVecType;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kM;
static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK;
static constexpr index_t kKPerThread = Impl::kABKPerLane;
static constexpr index_t kCompressionRatio = Impl::CompressionRatio;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
"Multi-block WarpGemmAttributeSmfmacImpl is not supported");
using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>;
using BWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>;
using CWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
sequence<Impl::kCNLane>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 0>>,
sequence<1, 1>,
sequence<0, 2>>;
// c_vec += a_vec * b_vec[idx]
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
const int32_t& idx,
bool_constant<post_nop_> = {}) const
{
Impl{}(c_vec, a_vec, b_vec, idx, bool_constant<post_nop_>{});
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,114 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "warp_gemm_attribute_mfma_impl.hpp"
namespace ck_tile {
// fp16 2:4 structured sparsity
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeSmfmacImplF16F16F32M32N32K16
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = fp16_t;
using BDataType = fp16_t;
using IdxDataType = int32_t;
using CDataType = float;
using AVecType = ext_vector_t<fp16_t, 4>;
using BVecType = ext_vector_t<fp16_t, 8>;
using CVecType = ext_vector_t<float, 16>;
static constexpr index_t kM = 32;
static constexpr index_t kN = 32;
static constexpr index_t kK = 16;
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
static constexpr index_t kAMLane = 32;
static constexpr index_t kBNLane = 32;
static constexpr index_t kABKLane = 2;
static constexpr index_t kABKPerLane = 8;
static constexpr index_t kCMLane = 2;
static constexpr index_t kCNLane = 32;
static constexpr index_t kCM0PerLane = 4;
static constexpr index_t kCM1PerLane = 4;
static constexpr index_t CompressionRatio = 2;
// c_vec += a_vec * b_vec[idx]
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
const int32_t& idx,
bool_constant<post_nop_> = {}) const
{
#if defined(__gfx94_) or defined(__gfx95_)
c_vec = __builtin_amdgcn_smfmac_f32_32x32x16_f16(a_vec, b_vec, c_vec, idx, 0, 0);
#else
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = idx;
#endif
}
};
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeSmfmacImplF16F16F32M16N16K32
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = fp16_t;
using BDataType = fp16_t;
using IdxDataType = int32_t;
using CDataType = float;
using AVecType = ext_vector_t<fp16_t, 4>;
using BVecType = ext_vector_t<fp16_t, 8>;
using CVecType = ext_vector_t<float, 4>;
static constexpr index_t kM = 16;
static constexpr index_t kN = 16;
static constexpr index_t kK = 32;
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
static constexpr index_t kAMLane = 16;
static constexpr index_t kBNLane = 16;
static constexpr index_t kABKLane = 4;
static constexpr index_t kABKPerLane = 8;
static constexpr index_t kCMLane = 4;
static constexpr index_t kCNLane = 16;
static constexpr index_t kCM0PerLane = 1;
static constexpr index_t kCM1PerLane = 4;
static constexpr index_t CompressionRatio = 2;
// c_vec += a_vec * b_vec[idx]
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
const int32_t& idx,
bool_constant<post_nop_> = {}) const
{
#if defined(__gfx94_) or defined(__gfx95_)
c_vec = __builtin_amdgcn_smfmac_f32_16x16x32_f16(a_vec, b_vec, c_vec, idx, 0, 0);
#else
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = idx;
#endif
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,176 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/device_prop.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp"
namespace ck_tile {
// TODO: currently only support 16 bit input, which means only support tr16_b128; will use ADataType
// to determine the layout in the future
template <typename Impl>
struct AWarpDstrEncodingTrait
{
using type = tile_distribution_encoding<
sequence<Impl::kRepeat>,
tuple<sequence<Impl::kAMLane>,
sequence<Impl::kABK0PerLane, Impl::kABKLane, Impl::kABK1PerLane>>,
tuple<typename Impl::kABPs2RHssMajor>,
tuple<typename Impl::kABPs2RHssMinor>,
typename Impl::kABYs2RHsMajor,
typename Impl::kABYs2RHsMinor>;
};
template <typename Impl>
struct BWarpDstrEncodingTrait
{
using type = tile_distribution_encoding<
sequence<Impl::kRepeat>,
tuple<sequence<Impl::kBNLane>,
sequence<Impl::kABK0PerLane, Impl::kABKLane, Impl::kABK1PerLane>>,
tuple<typename Impl::kABPs2RHssMajor>,
tuple<typename Impl::kABPs2RHssMinor>,
typename Impl::kABYs2RHsMajor,
typename Impl::kABYs2RHsMinor>;
};
template <typename Impl>
struct CWarpDstrEncodingTrait
{
using type = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
sequence<Impl::kCNLane>>,
tuple<typename Impl::kCPs2RHssMajor>,
tuple<typename Impl::kCPs2RHssMinor>,
typename Impl::kCYs2RHsMajor,
typename Impl::kCYs2RHsMinor>;
};
template <typename Impl>
struct CTransposedWarpDstrEncodingTrait
{
using type = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kCNLane>,
sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
tuple<typename Impl::kCTPs2RHssMajor>,
tuple<typename Impl::kCTPs2RHssMinor>,
typename Impl::kCTYs2RHsMajor,
typename Impl::kCTYs2RHsMinor>;
};
template <typename WarpGemmAttributeWmmaImpl_, bool kTransC = false>
struct WarpGemmAttributeWmma
{
using Impl = remove_cvref_t<WarpGemmAttributeWmmaImpl_>;
// When kTransC is true and A/B types differ, we need an impl with swapped types
using TransposedImpl =
std::conditional_t<kTransC &&
!std::is_same_v<typename Impl::ADataType, typename Impl::BDataType>,
WarpGemmAttributeWmmaImpl<WmmaTraits<typename Impl::TraitsType::ArchType,
typename Impl::BDataType,
typename Impl::ADataType,
typename Impl::CDataType,
Impl::kM,
Impl::kN,
Impl::kK>>,
Impl>;
using ADataType = typename Impl::ADataType;
using BDataType = typename Impl::BDataType;
using CDataType = typename Impl::CDataType;
using AVecType = typename Impl::AVecType;
using BVecType = typename Impl::BVecType;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kM;
static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK;
static constexpr index_t kCMLane = Impl::kCMLane;
static constexpr index_t kKPerThread = Impl::kABK0PerLane * Impl::kABK1PerLane;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
// 16 bit input, kAMLane = 16, kABK0PerLane = 4, kABKLane = 2, kABK1PerLane = 2
// 8 bit input, kAMLane = 16, kABK0PerLane = 2, kABKLane = 2, kABK1PerLane = 4
using AWarpDstrEncoding = typename AWarpDstrEncodingTrait<Impl>::type;
using BWarpDstrEncoding = typename BWarpDstrEncodingTrait<Impl>::type;
// kCM0PerLane = 1, kCMLane = 2, kCM1PerLane = 2, kCNLane = 16
using CWarpDstrEncoding =
std::conditional_t<kTransC,
typename CTransposedWarpDstrEncodingTrait<Impl>::type,
typename CWarpDstrEncodingTrait<Impl>::type>;
// c_vec += a_vec * b_vec
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
if constexpr(kTransC)
{
TransposedImpl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
}
else
{
Impl{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
if constexpr(kTransC)
{
return TransposedImpl{}(b_vec, a_vec);
}
else
{
return Impl{}(a_vec, b_vec);
}
}
};
template <typename ADataType,
typename BDataType,
typename AccDataType,
index_t M_Warp_Tile,
index_t N_Warp_Tile,
index_t K_Warp_Tile>
CK_TILE_HOST bool check_wmma_supported()
{
if(is_gfx12_supported())
{
return has_wmma_traits_v<gfx12_t,
ADataType,
BDataType,
AccDataType,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile>;
}
else if(is_gfx11_supported())
{
return has_wmma_traits_v<gfx11_t,
ADataType,
BDataType,
AccDataType,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile>;
}
else
{
return false;
}
}
} // namespace ck_tile

View File

@@ -0,0 +1,141 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/config.hpp"
namespace ck_tile {
// Base traits for WMMA operations
template <typename Arch,
typename AType,
typename BType,
typename CType,
index_t M,
index_t N,
index_t K>
struct WmmaTraits;
// Generic WMMA implementation using traits
template <typename Traits>
struct WarpGemmAttributeWmmaImpl
{
using TraitsType = Traits;
using ADataType = typename Traits::ADataType;
using BDataType = typename Traits::BDataType;
using CDataType = typename Traits::CDataType;
using AVecType = typename Traits::AVecType;
using BVecType = typename Traits::BVecType;
using CVecType = typename Traits::CVecType;
// Forward all static constants and type aliases
static constexpr index_t kM = Traits::kM;
static constexpr index_t kN = Traits::kN;
static constexpr index_t kK = Traits::kK;
static constexpr index_t kAMBlock = Traits::kAMBlock;
static constexpr index_t kBNBlock = Traits::kBNBlock;
static constexpr index_t kRepeat = Traits::kRepeat;
static constexpr index_t kAMLane = Traits::kAMLane;
static constexpr index_t kBNLane = Traits::kBNLane;
static constexpr index_t kABK0PerLane = Traits::kABK0PerLane;
static constexpr index_t kABKLane = Traits::kABKLane;
static constexpr index_t kABK1PerLane = Traits::kABK1PerLane;
static constexpr index_t kCMLane = Traits::kCMLane;
static constexpr index_t kCNLane = Traits::kCNLane;
static constexpr index_t kCM0PerLane = Traits::kCM0PerLane;
static constexpr index_t kCM1PerLane = Traits::kCM1PerLane;
using kABPs2RHssMajor = typename Traits::kABPs2RHssMajor;
using kABPs2RHssMinor = typename Traits::kABPs2RHssMinor;
using kABYs2RHsMajor = typename Traits::kABYs2RHsMajor;
using kABYs2RHsMinor = typename Traits::kABYs2RHsMinor;
using kCPs2RHssMajor = typename Traits::kCPs2RHssMajor;
using kCPs2RHssMinor = typename Traits::kCPs2RHssMinor;
using kCYs2RHsMajor = typename Traits::kCYs2RHsMajor;
using kCYs2RHsMinor = typename Traits::kCYs2RHsMinor;
using kCTPs2RHssMajor = typename Traits::kCTPs2RHssMajor;
using kCTPs2RHssMinor = typename Traits::kCTPs2RHssMinor;
using kCTYs2RHsMajor = typename Traits::kCTYs2RHsMajor;
using kCTYs2RHsMinor = typename Traits::kCTYs2RHsMinor;
// c_vec += a_vec * b_vec
template <bool clamp = false, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
c_vec = Traits::template wmma_intrinsic<clamp>(a_vec, b_vec, c_vec);
}
// c_vec = a_vec * b_vec
template <bool clamp = false>
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
return bit_cast<CVecType>(
Traits::template wmma_intrinsic<clamp>(a_vec, b_vec, CVecType{0.f}));
}
};
using DeviceIp = remove_cvref_t<decltype(ck_tile::get_device_arch())>;
using WarpGemmAttributeWmmaImpl_f32_16x16x16_f16_f16 =
WarpGemmAttributeWmmaImpl<WmmaTraits<DeviceIp, fp16_t, fp16_t, float, 16, 16, 16>>;
using WarpGemmAttributeWmmaImpl_f32_16x16x16_bf16_bf16 =
WarpGemmAttributeWmmaImpl<WmmaTraits<DeviceIp, bf16_t, bf16_t, float, 16, 16, 16>>;
using WarpGemmAttributeWmmaImpl_i32_16x16x16_i8_i8 =
WarpGemmAttributeWmmaImpl<WmmaTraits<DeviceIp, int8_t, int8_t, int32_t, 16, 16, 16>>;
using WarpGemmAttributeWmmaImpl_f32_16x16x16_f8_f8 =
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx12_t, fp8_t, fp8_t, float, 16, 16, 16>>;
using WarpGemmAttributeWmmaImpl_f32_16x16x16_bf8_bf8 =
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx12_t, bf8_t, bf8_t, float, 16, 16, 16>>;
using WarpGemmAttributeWmmaImpl_f32_16x16x16_f8_bf8 =
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx12_t, fp8_t, bf8_t, float, 16, 16, 16>>;
using WarpGemmAttributeWmmaImpl_f32_16x16x16_bf8_f8 =
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx12_t, bf8_t, fp8_t, float, 16, 16, 16>>;
template <typename Arch,
typename AType,
typename BType,
typename CType,
index_t warp_m,
index_t warp_n,
index_t warp_k>
struct has_wmma_traits
{
template <typename T>
static auto
test(int) -> decltype(std::declval<
typename WmmaTraits<T, AType, BType, CType, warp_m, warp_n, warp_k>::
ADataType>(),
std::true_type{});
template <typename>
static std::false_type test(...);
static constexpr bool value = decltype(test<Arch>(0))::value;
};
template <typename Arch,
typename AType,
typename BType,
typename CType,
index_t warp_m,
index_t warp_n,
index_t warp_k>
constexpr bool has_wmma_traits_v =
has_wmma_traits<Arch, AType, BType, CType, warp_m, warp_n, warp_k>::value;
} // namespace ck_tile

View File

@@ -0,0 +1,95 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "warp_gemm_attribute_wmma_impl_base_traits.hpp"
namespace ck_tile {
// fp16 specialization - GFX11
template <>
struct WmmaTraits<gfx11_t, fp16_t, fp16_t, float, 16, 16, 16>
: WmmaTraitsBase<gfx11_t, fp16_t, fp16_t, float>
{
using ArchType = gfx11_t;
template <bool clamp = false>
CK_TILE_DEVICE static CVecType
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
{
#ifdef __gfx11__
return __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_vec, b_vec, c_vec);
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = c_vec;
return CVecType{0.f};
#endif
}
};
// bf16 specialization - GFX11
template <>
struct WmmaTraits<gfx11_t, bf16_t, bf16_t, float, 16, 16, 16>
: WmmaTraitsBase<gfx11_t, bf16_t, bf16_t, float>
{
using ArchType = gfx11_t;
template <bool clamp = false>
CK_TILE_DEVICE static CVecType
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
{
#ifdef __gfx11__
return __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(a_vec, b_vec, c_vec);
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = c_vec;
return CVecType{0.f};
#endif
}
};
// fp16 specialization - GFX12
template <>
struct WmmaTraits<gfx12_t, fp16_t, fp16_t, float, 16, 16, 16>
: WmmaTraitsBase<gfx12_t, fp16_t, fp16_t, float>
{
using ArchType = gfx12_t;
template <bool clamp = false>
CK_TILE_DEVICE static CVecType
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
{
#ifdef __gfx12__
return __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_vec, b_vec, c_vec);
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = c_vec;
return CVecType{0.f};
#endif
}
};
// bf16 specialization - GFX12
template <>
struct WmmaTraits<gfx12_t, bf16_t, bf16_t, float, 16, 16, 16>
: WmmaTraitsBase<gfx12_t, bf16_t, bf16_t, float>
{
using ArchType = gfx12_t;
template <bool clamp = false>
CK_TILE_DEVICE static CVecType
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
{
#ifdef __gfx12__
return __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_vec, b_vec, c_vec);
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = c_vec;
return CVecType{0.f};
#endif
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,148 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "warp_gemm_attribute_wmma_impl_base_traits.hpp"
namespace ck_tile {
// int8 specialization - GFX11
template <>
struct WmmaTraits<gfx11_t, int8_t, int8_t, int32_t, 16, 16, 16>
: WmmaTraitsBase<gfx11_t, int8_t, int8_t, int32_t>
{
using ArchType = gfx11_t;
template <bool clamp = false>
CK_TILE_DEVICE static CVecType
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
{
#ifdef __gfx11__
return __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, // neg_a
bit_cast<int32x4_t>(a_vec),
true, // neg_b
bit_cast<int32x4_t>(b_vec),
bit_cast<int32x8_t>(c_vec),
clamp);
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = c_vec;
return CVecType{0};
#endif
}
};
// int8 specialization - GFX12
template <>
struct WmmaTraits<gfx12_t, int8_t, int8_t, int32_t, 16, 16, 16>
: WmmaTraitsBase<gfx12_t, int8_t, int8_t, int32_t>
{
using ArchType = gfx12_t;
template <bool clamp = false>
CK_TILE_DEVICE static CVecType
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
{
#ifdef __gfx12__
return __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, // neg_a
bit_cast<int32x2_t>(a_vec),
true, // neg_b
bit_cast<int32x2_t>(b_vec),
bit_cast<int32x8_t>(c_vec),
clamp);
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = c_vec;
return CVecType{0};
#endif
}
};
// fp8/bf8 specialization - GFX12
template <>
struct WmmaTraits<gfx12_t, fp8_t, fp8_t, float, 16, 16, 16>
: WmmaTraitsBase<gfx12_t, fp8_t, fp8_t, float>
{
using ArchType = gfx12_t;
template <bool clamp = false>
CK_TILE_DEVICE static CVecType
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
{
#ifdef __gfx12__
return __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12(
bit_cast<int32x2_t>(a_vec), bit_cast<int32x2_t>(b_vec), bit_cast<fp32x8_t>(c_vec));
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = c_vec;
return CVecType{0};
#endif
}
};
template <>
struct WmmaTraits<gfx12_t, bf8_t, bf8_t, float, 16, 16, 16>
: WmmaTraitsBase<gfx12_t, bf8_t, bf8_t, float>
{
using ArchType = gfx12_t;
template <bool clamp = false>
CK_TILE_DEVICE static CVecType
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
{
#ifdef __gfx12__
return __builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12(
bit_cast<int32x2_t>(a_vec), bit_cast<int32x2_t>(b_vec), bit_cast<fp32x8_t>(c_vec));
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = c_vec;
return CVecType{0};
#endif
}
};
template <>
struct WmmaTraits<gfx12_t, fp8_t, bf8_t, float, 16, 16, 16>
: WmmaTraitsBase<gfx12_t, fp8_t, bf8_t, float>
{
using ArchType = gfx12_t;
template <bool clamp = false>
CK_TILE_DEVICE static CVecType
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
{
#ifdef __gfx12__
return __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12(
bit_cast<int32x2_t>(a_vec), bit_cast<int32x2_t>(b_vec), bit_cast<fp32x8_t>(c_vec));
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = c_vec;
return CVecType{0};
#endif
}
};
template <>
struct WmmaTraits<gfx12_t, bf8_t, fp8_t, float, 16, 16, 16>
: WmmaTraitsBase<gfx12_t, bf8_t, fp8_t, float>
{
template <bool clamp = false>
CK_TILE_DEVICE static CVecType
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
{
#ifdef __gfx12__
return __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12(
bit_cast<int32x2_t>(a_vec), bit_cast<int32x2_t>(b_vec), bit_cast<fp32x8_t>(c_vec));
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = c_vec;
return CVecType{0};
#endif
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,106 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
namespace ck_tile {
template <typename Arch, typename ADType, typename BDType, typename CDType>
struct WmmaTraitsBase;
// GFX11 specialization
template <typename ADType, typename BDType, typename CDType>
struct WmmaTraitsBase<gfx11_t, ADType, BDType, CDType>
{
using ArchType = gfx11_t;
using ADataType = ADType;
using BDataType = BDType;
using CDataType = CDType;
using AVecType = ext_vector_t<ADataType, 16>;
using BVecType = ext_vector_t<BDataType, 16>;
using CVecType = ext_vector_t<CDataType, 8>;
static constexpr index_t kM = 16;
static constexpr index_t kN = 16;
static constexpr index_t kK = 16;
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
static constexpr index_t kRepeat = 2;
static constexpr index_t kAMLane = 16;
static constexpr index_t kBNLane = 16;
static constexpr index_t kABK0PerLane = 1;
static constexpr index_t kABKLane = 1;
static constexpr index_t kABK1PerLane = 16;
static constexpr index_t kCMLane = 2;
static constexpr index_t kCNLane = 16;
static constexpr index_t kCM0PerLane = 8;
static constexpr index_t kCM1PerLane = 1;
using kABPs2RHssMajor = sequence<0, 2, 1>;
using kABPs2RHssMinor = sequence<0, 1, 0>;
using kABYs2RHsMajor = sequence<2, 2>;
using kABYs2RHsMinor = sequence<0, 2>;
using kCPs2RHssMajor = sequence<1, 2>;
using kCPs2RHssMinor = sequence<1, 0>;
using kCYs2RHsMajor = sequence<1, 1>;
using kCYs2RHsMinor = sequence<0, 2>;
using kCTPs2RHssMajor = sequence<2, 1>;
using kCTPs2RHssMinor = sequence<1, 0>;
using kCTYs2RHsMajor = sequence<2, 2>;
using kCTYs2RHsMinor = sequence<0, 2>;
};
// GFX12 specialization
template <typename ADType, typename BDType, typename CDType>
struct WmmaTraitsBase<gfx12_t, ADType, BDType, CDType>
{
using ArchType = gfx12_t;
using ADataType = ADType;
using BDataType = BDType;
using CDataType = CDType;
using AVecType = ext_vector_t<ADataType, 8>;
using BVecType = ext_vector_t<BDataType, 8>;
using CVecType = ext_vector_t<CDataType, 8>;
static constexpr index_t kM = 16;
static constexpr index_t kN = 16;
static constexpr index_t kK = 16;
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
static constexpr index_t kRepeat = 1;
static constexpr index_t kAMLane = 16;
static constexpr index_t kBNLane = 16;
static constexpr index_t kABK0PerLane = 1;
static constexpr index_t kABKLane = 2;
static constexpr index_t kABK1PerLane = 8;
static constexpr index_t kCMLane = 2;
static constexpr index_t kCNLane = 16;
static constexpr index_t kCM0PerLane = 1;
static constexpr index_t kCM1PerLane = 8;
using kABPs2RHssMajor = sequence<2, 1>;
using kABPs2RHssMinor = sequence<1, 0>;
using kABYs2RHsMajor = sequence<2, 2>;
using kABYs2RHsMinor = sequence<0, 2>;
using kCPs2RHssMajor = sequence<1, 2>;
using kCPs2RHssMinor = sequence<1, 0>;
using kCYs2RHsMajor = sequence<1, 1>;
using kCYs2RHsMinor = sequence<0, 2>;
using kCTPs2RHssMajor = sequence<2, 1>;
using kCTPs2RHssMinor = sequence<1, 0>;
using kCTYs2RHsMajor = sequence<2, 2>;
using kCTYs2RHsMinor = sequence<0, 2>;
};
} // namespace ck_tile

View File

@@ -0,0 +1,207 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp"
namespace ck_tile {
namespace impl {
namespace warp_gemm_dispatcher {
// C++20 using enum
static inline constexpr auto ESingle = WGAttrNumAccessEnum::Single;
static inline constexpr auto EDouble = WGAttrNumAccessEnum::Double;
static inline constexpr auto EQuad = WGAttrNumAccessEnum::Quad;
template <typename AType,
typename BType,
typename AccType,
index_t MPerWave,
index_t NPerWave,
index_t KPerWave,
bool TransposeC,
bool SwizzleA = false,
bool UseStructuredSparsity = false,
WGAttrNumAccessEnum AttrNumAccessA = ESingle,
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
struct Dispatcher;
// clang-format off
// fp32
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
template<> struct Dispatcher<float, float, float, 16, 16, 4, false> { using Type = WarpGemmMfmaF32F32F32M16N16K4; };
template<> struct Dispatcher<float, float, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF32F32F32M16N16K16<>; };
template<> struct Dispatcher<float, float, float, 16, 16, 8, false> { using Type = WarpGemmMfmaF32F32F32M16N16K8<>; };
template<> struct Dispatcher<float, float, float, 32, 32, 4, false> { using Type = WarpGemmMfmaF32F32F32M32N32K4<>; };
template<> struct Dispatcher<float, float, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF32F32F32M32N32K8<>; };
template<> struct Dispatcher<float, float, float, 32, 32, 8, false, false, false, EDouble> { using Type = WarpGemmMfmaF32F32F32M32N32K8<EDouble>; };
template<> struct Dispatcher<float, float, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution<>; };
// fp16
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; };
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16<>; };
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<>; };
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 16, false, false, false, EDouble> { using Type = WarpGemmMfmaF16F16F32M32N32K16<EDouble>; };
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 16, true, false, false, EDouble> { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<EDouble>; };
template<> struct Dispatcher<half_t, half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32<>; };
template<> struct Dispatcher<half_t, half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>; };
template<> struct Dispatcher<half_t, half_t, float, 16, 16, 32, false, false, false, EDouble> { using Type = WarpGemmMfmaF16F16F32M16N16K32<EDouble>; };
template<> struct Dispatcher<half_t, half_t, float, 16, 16, 32, true, false, false, EDouble> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<EDouble>; };
template<> struct Dispatcher<half_t, half_t, float, 4, 64, 16, false> { using Type = WarpGemmMfmaF16F16F32M4N64K16; };
template<> struct Dispatcher<half_t, half_t, float, 64, 4, 16, false> { using Type = WarpGemmMfmaF16F16F32M64N4K16; };
// WMMA cases
#if defined(__gfx11__) || defined(__gfx12__)
template<bool TransposeC> struct Dispatcher<half_t, half_t, float, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_f32_16x16x16_f16_f16<TransposeC>;};
#else
template<> struct Dispatcher<half_t, half_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF16F16F32M16N16K16; };
template<> struct Dispatcher<half_t, half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
#endif
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 8, true, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution; };
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 16, true, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution; };
// fp16 2:4 structural sparsity
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 16, false, false, true> { using Type = WarpGemmSmfmacF16F16F32M32N32K16; };
template<> struct Dispatcher<half_t, half_t, float, 16, 16, 32, false, false, true> { using Type = WarpGemmSmfmacF16F16F32M16N16K32; };
// bf16
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16<>; };
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<>; };
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 16, false, false, false, EDouble> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16<EDouble>; };
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 16, true, false, false, EDouble> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<EDouble>; };
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<>; };
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 32, false, false, false, EDouble, ESingle> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<EDouble, ESingle>; };
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 64, false, false, false, EQuad, ESingle> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K64<EQuad, ESingle>; };
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 64, false, false, false, EQuad> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K64<EQuad>; };
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 64, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K64<>; };
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<>; };
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 32, false, false, false, EDouble> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<EDouble>; };
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 32, true, false, false, EDouble> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<EDouble>; };
template<> struct Dispatcher<bf16_t, bf16_t, float, 4, 64, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M4N64K16; };
template<> struct Dispatcher<bf16_t, bf16_t, float, 64, 4, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M64N4K16; };
// WMMA cases
#if defined(__gfx11__) || defined(__gfx12__)
template<bool TransposeC> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_f32_16x16x16_bf16_bf16<TransposeC>; };
#else
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; };
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; };
#endif
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 8, true, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution; };
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 16, true, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution; };
// fp8
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; };
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 32, true> { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8_CTransposed; };
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; };
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; };
template<> struct Dispatcher<fp8_t, bf8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_fp8_bf8; };
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_bf8; };
template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; };
template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; };
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; };
template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8; };
template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 32, true> { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8_CTransposed; };
template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; };
template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 64, true> { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8_CTransposed; };
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; };
// scale mfma based f8f6f4
template<typename A, typename B, WGAttrNumAccessEnum I>
struct Dispatcher<A, B, float, 16, 16, 128, false, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_f8f6f4<A, B, I>; };
template<WGAttrNumAccessEnum I> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 128, true, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8_CTransposed<I>; };
template<WGAttrNumAccessEnum I> struct Dispatcher<fp8_t, bf8_t, float, 16, 16, 128, true, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8_CTransposed<I>; };
template<WGAttrNumAccessEnum I> struct Dispatcher<bf8_t, fp8_t, float, 16, 16, 128, true, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8_CTransposed<I>; };
template<WGAttrNumAccessEnum I> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 128, true, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8_CTransposed<I>; };
template<WGAttrNumAccessEnum I> struct Dispatcher<pk_fp4_t, pk_fp4_t, float, 16, 16, 128, true, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_fp4_fp4_CTransposed<I>; };
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; };
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<>; };
template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<>; };
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<>; };
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 64, false, false, false, EDouble> { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<EDouble>; };
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 64, false, false, false, EDouble> { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<EDouble>; };
template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 64, false, false, false, EDouble> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<EDouble>; };
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 64, false, false, false, EDouble> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<EDouble>; };
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 64, false, false, false, EQuad> { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<EQuad>; };
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 64, false, false, false, EQuad> { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<EQuad>; };
template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 64, false, false, false, EQuad> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<EQuad>; };
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 64, false, false, false, EQuad> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<EQuad>; };
template<WGAttrNumAccessEnum I> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 64, true, false, false, I> { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8_CTransposed<I>; };
template<WGAttrNumAccessEnum I> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 64, true, false, false, I> { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8_CTransposed<I>; };
template<WGAttrNumAccessEnum I> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 64, true, false, false, I> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8_CTransposed<I>; };
template<WGAttrNumAccessEnum I> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 64, true, false, false, I> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8_CTransposed<I>; };
template<WGAttrNumAccessEnum I> struct Dispatcher<pk_fp4_t, pk_fp4_t, float, 32, 32, 64, true, false, false, I> { using Type = WarpGemmMfma_f32_32x32x64_fp4_fp4_CTransposed<I>; };
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8<>; };
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 32, false, false, false, EDouble> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8<EDouble>; };
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8<>; };
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 32, false, false, false, EDouble> { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8<EDouble>; };
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, true> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8_CTransposed; };
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8<>; };
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, false, false, false, EDouble> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8<EDouble>; };
//WMMA cases
template<bool TransposeC> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_f32_16x16x16_f8_f8<TransposeC>; };
template<bool TransposeC> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_f32_16x16x16_bf8_bf8<TransposeC>; };
template<bool TransposeC> struct Dispatcher<fp8_t, bf8_t, float, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_f32_16x16x16_f8_bf8<TransposeC>; };
template<bool TransposeC> struct Dispatcher<bf8_t, fp8_t, float, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_f32_16x16x16_bf8_f8<TransposeC>; };
// int8
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
template<> struct Dispatcher<int8_t, int8_t, int32_t, 32, 32, 16, false> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8; };
template<> struct Dispatcher<int8_t, int8_t, int32_t, 32, 32, 16, true> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed; };
template<> struct Dispatcher<int8_t, int8_t, int32_t, 16, 16, 32, false> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8; };
template<> struct Dispatcher<int8_t, int8_t, int32_t, 16, 16, 32, true> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed; };
// WMMA cases
template<bool TransposeC> struct Dispatcher<int8_t, int8_t, int32_t, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_i32_16x16x16_i8_i8<TransposeC>;};
// clang-format on
} // namespace warp_gemm_dispatcher
} // namespace impl
template <typename AType,
typename BType,
typename AccType,
index_t MPerWave,
index_t NPerWave,
index_t KPerWave,
bool TransposeC,
bool SwizzleA = false,
bool UseStructuredSparsity = false,
WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single,
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
using WarpGemmDispatcher = typename impl::warp_gemm_dispatcher::Dispatcher< //
AType,
BType,
AccType,
MPerWave,
NPerWave,
KPerWave,
TransposeC,
SwizzleA,
UseStructuredSparsity,
AttrNumAccessA,
AttrNumAccessB>::Type;
} // namespace ck_tile

View File

@@ -0,0 +1,183 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename WarpGemmAttribute_>
struct WarpGemmImpl
{
using WarpGemmAttribute = remove_cvref_t<WarpGemmAttribute_>;
static constexpr index_t kM = WarpGemmAttribute::kM;
static constexpr index_t kN = WarpGemmAttribute::kN;
static constexpr index_t kK = WarpGemmAttribute::kK;
static constexpr index_t kCMLane = WarpGemmAttribute::kCMLane;
/// @brief The number of elements in K dimension processed by single thread in wavefront.
///
/// @note Note that WarpGemm may run MFMA instruction multiple times (on different K).
/// In such situation this value reflects this fact.
static constexpr index_t kKPerThread = WarpGemmAttribute::kKPerThread;
using ADataType = typename WarpGemmAttribute::ADataType;
using BDataType = typename WarpGemmAttribute::BDataType;
using CDataType = typename WarpGemmAttribute::CDataType;
using AWarpDstrEncoding = typename WarpGemmAttribute::AWarpDstrEncoding;
using BWarpDstrEncoding = typename WarpGemmAttribute::BWarpDstrEncoding;
using CWarpDstrEncoding = typename WarpGemmAttribute::CWarpDstrEncoding;
using AWarpDstr = remove_cvref_t<decltype(make_static_tile_distribution(AWarpDstrEncoding{}))>;
using BWarpDstr = remove_cvref_t<decltype(make_static_tile_distribution(BWarpDstrEncoding{}))>;
using CWarpDstr = remove_cvref_t<decltype(make_static_tile_distribution(CWarpDstrEncoding{}))>;
using AWarpTensor = static_distributed_tensor<ADataType, AWarpDstr>;
using BWarpTensor = static_distributed_tensor<BDataType, BWarpDstr>;
using CWarpTensor = static_distributed_tensor<CDataType, CWarpDstr>;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access()
{
return WarpGemmAttribute_::get_num_of_access();
}
template <typename CTensor, typename ATensor, typename BTensor, bool post_nop_ = false>
CK_TILE_DEVICE void
operator()(CTensor& c, const ATensor& a, const BTensor& b, bool_constant<post_nop_> = {}) const
{
static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CWarpTensor> &&
detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
constexpr auto I0 = number<0>{};
const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
// c_vec += a_vec * b_vec
WarpGemmAttribute{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
}
template <typename CTensor,
typename ATensor,
typename BTensor,
index_t i_subk,
bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CTensor& c,
const ATensor& a,
const BTensor& b,
number<i_subk>,
bool_constant<post_nop_> = {}) const
{
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
constexpr auto I0 = number<0>{};
const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
// c_vec += a_vec * b_vec
WarpGemmAttribute{}(c_vec, a_vec, b_vec, number<i_subk>{}, bool_constant<post_nop_>{});
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
}
template <index_t opselA,
index_t opselB,
typename CTensor,
typename ATensor,
typename BTensor,
bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CTensor& c,
const ATensor& a,
const BTensor& b,
const int32_t& a_scale,
const int32_t& b_scale,
bool_constant<post_nop_> = {}) const
{
static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CWarpTensor> &&
detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
constexpr auto I0 = number<0>{};
const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
// c_vec += a_vec * b_vec
WarpGemmAttribute{}.template operator()<opselA, opselB>(
c_vec, a_vec, a_scale, b_vec, b_scale, bool_constant<post_nop_>{});
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
}
template <typename ATensor, typename BTensor>
CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b) const
{
using CTensor = CWarpTensor;
static_assert(detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
CTensor c;
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
constexpr auto I0 = number<0>{};
const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
// c_vec = a_vec * b_vec
auto c_vec = WarpGemmAttribute{}(a_vec, b_vec);
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
return c;
}
template <index_t opselA, index_t opselB, typename ATensor, typename BTensor>
CK_TILE_DEVICE auto operator()(const ATensor& a,
const BTensor& b,
const int32_t& a_scale,
const int32_t& b_scale) const
{
using CTensor = CWarpTensor;
static_assert(detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
CTensor c;
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
constexpr auto I0 = number<0>{};
const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
// c_vec = a_vec * b_vec
auto c_vec =
WarpGemmAttribute{}.template operator()<opselA, opselB>(a_vec, a_scale, b_vec, b_scale);
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
return c;
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,131 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
namespace ck_tile {
/**
* @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero
* elements into lower part of a_vec to half its effective size.
* @param a_vec Vector to be compressed.
* @tparam ADataType The data type of a_vec
* @tparam CompressedSize The target compression size
* @tparam AVec The vector type of a_vec (deduced)
* @return Packed 32bit word containing **CompressedSize** 2bit fields.
* Each field encodes the original position (03) of the corresponding
* nonzero element in the input. If fewer than CompressedSize
* nonzeros are found, remaining fields default to 2 (see below).
*/
template <typename ADataType, index_t CompressedSize, typename AVec>
static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec)
{
// idx holds one 2bit index per output element (total CompressedSize entries).
// It is initialized to the pattern 0b10 for every field. This matches
// what the hardware expects when there are fewer than two nonzero values
// in a 4element group the unused output is treated as coming from slot 2.
// The loop below will clear and set each field as real nonzeros are seen.
int32_t idx = 0;
static_for<0, CompressedSize, 1>{}([&](auto k) { idx |= (2 << (2 * k)); });
static_for<0, CompressedSize / 2, 1>{}([&](auto i) {
ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]};
int32_t non_zero_pos = 0;
static_for<0, 3, 1>{}([&](auto j) {
if(a_vec[i * 4 + j] != 0.0f)
{
nonzero_elems[non_zero_pos] = a_vec[i * 4 + j];
// clear the twobit field for this output and insert j
idx &= ~(0b11 << 2 * (i * 2 + non_zero_pos));
idx |= j << 2 * (i * 2 + non_zero_pos);
++non_zero_pos;
}
});
a_vec[i * 2] = nonzero_elems[0];
a_vec[i * 2 + 1] = nonzero_elems[1];
});
return idx;
}
template <typename WarpGemmAttribute_>
struct WarpGemmSmfmacImpl
{
using WarpGemmAttribute = remove_cvref_t<WarpGemmAttribute_>;
static constexpr index_t kM = WarpGemmAttribute::kM;
static constexpr index_t kN = WarpGemmAttribute::kN;
static constexpr index_t kK = WarpGemmAttribute::kK;
/// @brief The number of elements in K dimension processed by single thread in wavefront.
///
/// @note Note that WarpGemm may run MFMA instruction multiple times (on different K).
/// In such situation this value reflects this fact.
static constexpr index_t kKPerThread = WarpGemmAttribute::kKPerThread;
using ADataType = typename WarpGemmAttribute::ADataType;
using BDataType = typename WarpGemmAttribute::BDataType;
using CDataType = typename WarpGemmAttribute::CDataType;
using AWarpDstrEncoding = typename WarpGemmAttribute::AWarpDstrEncoding;
using BWarpDstrEncoding = typename WarpGemmAttribute::BWarpDstrEncoding;
using CWarpDstrEncoding = typename WarpGemmAttribute::CWarpDstrEncoding;
using AWarpDstr = remove_cvref_t<decltype(make_static_tile_distribution(AWarpDstrEncoding{}))>;
using BWarpDstr = remove_cvref_t<decltype(make_static_tile_distribution(BWarpDstrEncoding{}))>;
using CWarpDstr = remove_cvref_t<decltype(make_static_tile_distribution(CWarpDstrEncoding{}))>;
using AWarpTensor = static_distributed_tensor<ADataType, AWarpDstr>;
using BWarpTensor = static_distributed_tensor<BDataType, BWarpDstr>;
using CWarpTensor = static_distributed_tensor<CDataType, CWarpDstr>;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access()
{
return WarpGemmAttribute_::get_num_of_access();
}
template <index_t CompressedSize, typename AVec>
CK_TILE_DEVICE int32_t compress_a_vec(AVec& a_vec)
{
return compress_a_impl<ADataType, CompressedSize>(a_vec);
}
template <typename CTensor, typename ATensor, typename BTensor, bool post_nop_ = false>
CK_TILE_DEVICE void
operator()(CTensor& c, const ATensor& a, const BTensor& b, bool_constant<post_nop_> = {}) const
{
static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CWarpTensor> &&
detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
constexpr auto CompressionRatio = WarpGemmAttribute::kCompressionRatio;
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
static constexpr index_t CompressedSize =
ATensor::get_thread_buffer_size() / CompressionRatio;
using AVecCompressed = ext_vector_t<ADataType, CompressedSize>;
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
constexpr auto I0 = number<0>{};
auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
const int32_t idx = compress_a_vec<CompressedSize>(a_vec);
static_assert(CompressedSize == 4);
// @TODO can we simply set a_vec_pruned to a_vec[0:3]?
const AVecCompressed a_vec_pruned = {a_vec[0], a_vec[1], a_vec[2], a_vec[3]};
// c_vec += a_vec * b_vec[idx]
WarpGemmAttribute{}(c_vec, a_vec_pruned, b_vec, idx, bool_constant<post_nop_>{});
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,42 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp"
namespace ck_tile {
template <bool kTransC = false>
using WarpGemmWmma_f32_16x16x16_f16_f16 =
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x16_f16_f16, kTransC>>;
template <bool kTransC = false>
using WarpGemmWmma_f32_16x16x16_bf16_bf16 =
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x16_bf16_bf16, kTransC>>;
template <bool kTransC = false>
using WarpGemmWmma_i32_16x16x16_i8_i8 =
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_i32_16x16x16_i8_i8, kTransC>>;
template <bool kTransC = false>
using WarpGemmWmma_f32_16x16x16_f8_f8 =
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x16_f8_f8, kTransC>>;
template <bool kTransC = false>
using WarpGemmWmma_f32_16x16x16_bf8_bf8 =
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x16_bf8_bf8, kTransC>>;
template <bool kTransC = false>
using WarpGemmWmma_f32_16x16x16_f8_bf8 =
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x16_f8_bf8, kTransC>>;
template <bool kTransC = false>
using WarpGemmWmma_f32_16x16x16_bf8_f8 =
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x16_bf8_f8, kTransC>>;
} // namespace ck_tile

View File

@@ -0,0 +1,45 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// C distribution of gfx11 WMMA is not compatible with A distribution:
// C: 2 lanes per row (lane and lane + 16), 8 values per lane are interleaved.
// A: 1 lane per row, 16 values, lane and lane + 16 have the same values.
// This function transforms one ditribution to another for GEMM-GEMM scenarios.
template <typename OutTensor, typename InTensor>
CK_TILE_DEVICE static constexpr void PermuteWarpGemmCToA(OutTensor& out, const InTensor& in)
{
#if defined(__gfx11__)
static_assert(sizeof(typename OutTensor::DataType) == 2);
static_assert(std::is_same_v<typename OutTensor::DataType, typename InTensor::DataType>);
constexpr index_t n_out = OutTensor::get_thread_buffer_size();
static_assert(n_out == InTensor::get_thread_buffer_size() * 2);
// Perm byte selectors are swapped for the second row (16 lanes) because it needs to be done
// once instead to swapping w and v everytime
const uint32_t byte_selector0 = get_lane_id() < 16 ? 0x05'04'01'00 : 0x01'00'05'04;
const uint32_t byte_selector1 = get_lane_id() < 16 ? 0x07'06'03'02 : 0x03'02'07'06;
static_for<0, n_out, 1>{}([&](auto i) {
const auto v = in.get_thread_buffer().template get_as<uint32_t>(i);
// Swap rows (lane <-> lane ^ 16)
const auto w = __builtin_amdgcn_permlanex16(0, v, 0x76543210, 0xfedcba98, false, true);
// Interleave values of lane and lane ^ 16
out.get_thread_buffer().template set_as<uint32_t>(
number<i * 2 + 0>{}, __builtin_amdgcn_perm(w, v, byte_selector0));
out.get_thread_buffer().template set_as<uint32_t>(
number<i * 2 + 1>{}, __builtin_amdgcn_perm(w, v, byte_selector1));
});
#else
static_assert(false, "PermuteWarpGemmCToA is only for gfx11");
ignore = out;
ignore = in;
#endif
}
} // namespace ck_tile