mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 22:22:27 +00:00
This commit is contained in:
504
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
Normal file
504
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
Normal file
@@ -0,0 +1,504 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
|
||||
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// fp32
|
||||
|
||||
using WarpGemmMfmaF32F32F32M16N16K4 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplF32F32F32M16N16K4<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF32F32F32M16N16K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplF32F32F32M16N16K4<WGAttrCtlEnum::Default_>,
|
||||
4,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF32F32F32M16N16K8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplF32F32F32M16N16K4<WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF32F32F32M32N32K8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplF32F32F32M32N32K2<WGAttrCtlEnum::Default_>,
|
||||
4,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF32F32F32M32N32K4 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplF32F32F32M32N32K2<WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplF32F32F32M16N16K4<WGAttrCtlEnum::Default_>,
|
||||
4,
|
||||
AttrNumAccess>>;
|
||||
|
||||
// fp16
|
||||
|
||||
using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#endif
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#endif
|
||||
|
||||
// fp16, swizzled-A attribute: M32N32K8 implementation with kKIter = 1.
using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl<
    WarpGemmAttributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
                                           1>>;

// fp16, swizzled-A attribute: M32N32K8 implementation with kKIter = 2 (K = 16).
using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl<
    WarpGemmAttributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
                                           2>>;

// fp16, transposed-C distribution attribute over one M32N32K8 implementation step.
using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution = WarpGemmImpl<
    WarpGemmAttributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;

// fp16, transposed-C distribution attribute over one M16N16K16 implementation step.
using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution = WarpGemmImpl<
    WarpGemmAttributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#endif
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#endif
|
||||
|
||||
#if defined(__gfx950__)
|
||||
using WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
||||
1>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
||||
1>>;
|
||||
#endif
|
||||
|
||||
using WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
|
||||
#else
|
||||
using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
#endif
|
||||
|
||||
using WarpGemmMfmaF16F16F32M4N64K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M4N64K4<WGAttrCtlEnum::Default_>,
|
||||
4>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M64N4K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M64N4K4<WGAttrCtlEnum::Default_>,
|
||||
4>>;
|
||||
|
||||
// fp16 2:4 structured sparsity
|
||||
using WarpGemmSmfmacF16F16F32M32N32K16 = WarpGemmSmfmacImpl<WarpGemmAttributeSmfmac<
|
||||
WarpGemmAttributeSmfmacImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmSmfmacF16F16F32M16N16K32 = WarpGemmSmfmacImpl<WarpGemmAttributeSmfmac<
|
||||
WarpGemmAttributeSmfmacImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
// bf16
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K16 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#endif
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single,
|
||||
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccessA,
|
||||
AttrNumAccessB>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single,
|
||||
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K64 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
AttrNumAccessA,
|
||||
AttrNumAccessB>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single,
|
||||
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
AttrNumAccessA>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single,
|
||||
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K64 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
||||
4,
|
||||
AttrNumAccessA,
|
||||
AttrNumAccessB>>;
|
||||
#endif
|
||||
|
||||
// bf16, swizzled-A attribute: M32N32K8 implementation with kKIter = 1.
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<
    WarpGemmAttributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
                                           1>>;

// bf16, swizzled-A attribute: M32N32K8 implementation with kKIter = 2 (K = 16).
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA = WarpGemmImpl<
    WarpGemmAttributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
                                           2>>;

// bf16, transposed-C distribution attribute over one M32N32K8 implementation step.
using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution = WarpGemmImpl<
    WarpGemmAttributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;

// bf16, transposed-C distribution attribute over one M16N16K16 implementation step.
using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution = WarpGemmImpl<
    WarpGemmAttributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#endif
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#endif
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
|
||||
#else
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
#endif
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M4N64K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4<WGAttrCtlEnum::Default_>,
|
||||
4>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M64N4K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4<WGAttrCtlEnum::Default_>,
|
||||
4>>;
|
||||
|
||||
// fp8
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
#endif
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
#endif
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
#endif
|
||||
|
||||
// --- single implementation step, 32x32x16 / 16x16x32 -------------------------------------

using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl<
    WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl<
    WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfma_f32_16x16x32_fp8_bf8 = WarpGemmImpl<
    WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_bf8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl<
    WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl<
    WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;

// 32x32x32 fp8 x bf8: the 32x32x16 fp8/bf8 implementation iterated with kKIter = 2.
using WarpGemmMfma_f32_32x32x32_fp8_bf8 = WarpGemmImpl<
    WarpGemmAttributeMfmaIterateK<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>,
                                  2>>;

using WarpGemmMfma_f32_16x16x32_fp8_fp8 = WarpGemmImpl<
    WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>>>;

// Same implementation as WarpGemmMfma_f32_16x16x32_fp8_fp8, transposed-C distribution attribute.
using WarpGemmMfma_f32_16x16x32_fp8_fp8_CTransposed = WarpGemmImpl<
    WarpGemmAttributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfma_f32_16x16x32_bf8_bf8 = WarpGemmImpl<
    WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>>>;

// Same implementation as WarpGemmMfma_f32_16x16x32_bf8_bf8, transposed-C distribution attribute.
using WarpGemmMfma_f32_16x16x32_bf8_bf8_CTransposed = WarpGemmImpl<
    WarpGemmAttributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>>>;

// --- 16x16x64: the 16x16x32 implementations iterated with kKIter = 2 ---------------------

using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl<
    WarpGemmAttributeMfmaIterateK<WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>,
                                  2>>;

using WarpGemmMfma_f32_16x16x64_fp8_fp8_CTransposed = WarpGemmImpl<
    WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
        WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>,
        2>>;

using WarpGemmMfma_f32_16x16x64_bf8_bf8_CTransposed = WarpGemmImpl<
    WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
        WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>,
        2>>;
|
||||
|
||||
template <typename A, typename B, WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_16x16x128_f8f6f4 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<A, B>, AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_16x16x128_fp8_fp8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<fp8_t, fp8_t>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_16x16x128_fp8_bf8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<fp8_t, bf8_t>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_16x16x128_bf8_fp8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<bf8_t, fp8_t>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_16x16x128_bf8_bf8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<bf8_t, bf8_t>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_16x16x128_fp4_fp4_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<pk_fp4_t, pk_fp4_t>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_32x32x64_fp8_fp8 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_fp8<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_32x32x64_fp8_bf8 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_bf8<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_32x32x64_bf8_fp8 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_fp8<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_32x32x64_bf8_bf8 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_bf8<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_32x32x64_fp8_fp8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_fp8<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_32x32x64_fp8_bf8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_bf8<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_32x32x64_bf8_fp8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_fp8<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_32x32x64_bf8_bf8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_bf8<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_32x32x64_fp4_fp4_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x64_f8f6f4<pk_fp4_t, pk_fp4_t>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
template <index_t swizzle_factor = 2>
|
||||
using WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
swizzle_factor>>;
|
||||
|
||||
// int8
|
||||
using WarpGemmMfma_i32_32x32x16_i8_i8 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_i32_16x16x32_i8_i8 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
} // namespace ck_tile
|
||||
928
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
Normal file
928
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
Normal file
@@ -0,0 +1,928 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Number of groups of consecutive elements used to fill an ABKLane.
// The underlying value of each valid enumerator equals the group count it names.
enum class WGAttrNumAccessEnum
{
    Invalid = -1, // not a usable setting; get_wgattr_num_access refuses to map it
    Single  = 1,
    Double  = 2,
    Quad    = 4
};
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess>
|
||||
struct get_wgattr_num_access
|
||||
{
|
||||
private:
|
||||
static constexpr index_t getAccesses()
|
||||
{
|
||||
if constexpr(AttrNumAccess == WGAttrNumAccessEnum::Single)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
else if constexpr(AttrNumAccess == WGAttrNumAccessEnum::Double)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(AttrNumAccess == WGAttrNumAccessEnum::Quad)
|
||||
{
|
||||
return 4;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "unsupported AttrNumAccess");
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
static constexpr auto value = getAccesses();
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_,
|
||||
WGAttrNumAccessEnum AttrNumAccessA_ = WGAttrNumAccessEnum::Single,
|
||||
WGAttrNumAccessEnum AttrNumAccessB_ = AttrNumAccessA_>
|
||||
struct WarpGemmAttributeMfma
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
static constexpr auto AttrNumAccessA = AttrNumAccessA_;
|
||||
static constexpr auto AttrNumAccessAV = get_wgattr_num_access<AttrNumAccessA>::value;
|
||||
static constexpr auto AttrNumAccessB = AttrNumAccessB_;
|
||||
static constexpr auto AttrNumAccessBV = get_wgattr_num_access<AttrNumAccessB>::value;
|
||||
|
||||
static constexpr bool UsePackNumAccess = AttrNumAccessA != AttrNumAccessB;
|
||||
|
||||
using ADataType = typename Impl::ADataType;
|
||||
using BDataType = typename Impl::BDataType;
|
||||
using CDataType = typename Impl::CDataType;
|
||||
|
||||
using AVecType = typename Impl::AVecType;
|
||||
using BVecType = typename Impl::BVecType;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kM;
|
||||
static constexpr index_t kN = Impl::kN;
|
||||
static constexpr index_t kK = Impl::kK;
|
||||
static constexpr index_t kKPerThread = Impl::kABKPerLane;
|
||||
static constexpr index_t kCMLane = Impl::kCMLane;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
|
||||
|
||||
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
|
||||
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
|
||||
|
||||
template <index_t kMNLane, index_t AttrNumAccessV_>
|
||||
static constexpr auto get_warp_dstr_encoding()
|
||||
{
|
||||
static_assert(kKPerThread % AttrNumAccessV_ == 0,
|
||||
"kKPerThread must be divisible by NumAccess");
|
||||
if constexpr(AttrNumAccessV_ == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
// AttrNumAccess splits the kABKPerLane
|
||||
// We can split them but still have them contiguous (packed) or have them interleaved.
|
||||
// The reason to split the dimension but still have it packed is to match load transpose
|
||||
// encoding when A and B use different AttrNumAccess (they have different types in LDS)
|
||||
// Example
|
||||
// A: 16bit, B: 8bit
|
||||
// Load transpose B: lane0 -> K=0..7 (only 1 instruction)
|
||||
// Load transpose A: lane0 -> K=0..3 first instruction, K=4..7 second instruction
|
||||
// In this way the data in register are consistent between A and B
|
||||
if constexpr(UsePackNumAccess)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMNLane>,
|
||||
sequence<Impl::kABKLane,
|
||||
AttrNumAccessV_,
|
||||
Impl::kABKPerLane / AttrNumAccessV_>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<1, 2>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMNLane>,
|
||||
sequence<AttrNumAccessV_,
|
||||
Impl::kABKLane,
|
||||
Impl::kABKPerLane / AttrNumAccessV_>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
}
|
||||
}
|
||||
using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane, AttrNumAccessAV>());
|
||||
using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane, AttrNumAccessBV>());
|
||||
|
||||
using CWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
|
||||
sequence<Impl::kCNLane>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 1>,
|
||||
sequence<0, 2>>;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
Impl{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
|
||||
}
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <index_t opselA, index_t opselB, bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const int32_t& a_scale,
|
||||
const BVecType& b_vec,
|
||||
const int32_t& b_scale,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
Impl{}.template operator()<opselA, opselB>(
|
||||
c_vec, a_vec, a_scale, b_vec, b_scale, bool_constant<post_nop_>{});
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
return Impl{}(a_vec, b_vec);
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
template <index_t opselA, index_t opselB>
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec,
|
||||
const int32_t& a_scale,
|
||||
const BVecType& b_vec,
|
||||
const int32_t& b_scale) const
|
||||
{
|
||||
return Impl{}.template operator()<opselA, opselB>(a_vec, a_scale, b_vec, b_scale);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_,
|
||||
index_t kKIter,
|
||||
WGAttrNumAccessEnum AttrNumAccessA_ = WGAttrNumAccessEnum::Single,
|
||||
WGAttrNumAccessEnum AttrNumAccessB_ = AttrNumAccessA_>
|
||||
struct WarpGemmAttributeMfmaIterateK
|
||||
{
|
||||
static_assert(kKIter > 0, "wrong!");
|
||||
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
static constexpr auto AttrNumAccessA = AttrNumAccessA_;
|
||||
static constexpr auto AttrNumAccessAV = get_wgattr_num_access<AttrNumAccessA>::value;
|
||||
static constexpr auto AttrNumAccessB = AttrNumAccessB_;
|
||||
static constexpr auto AttrNumAccessBV = get_wgattr_num_access<AttrNumAccessB>::value;
|
||||
|
||||
static constexpr bool UsePackNumAccess = AttrNumAccessA != AttrNumAccessB;
|
||||
|
||||
using ADataType = typename Impl::ADataType;
|
||||
using BDataType = typename Impl::BDataType;
|
||||
using CDataType = typename Impl::CDataType;
|
||||
|
||||
using AVecType =
|
||||
ext_vector_t<ADataType, vector_traits<typename Impl::AVecType>::vector_size * kKIter>;
|
||||
using BVecType =
|
||||
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kM;
|
||||
static constexpr index_t kN = Impl::kN;
|
||||
static constexpr index_t kK = Impl::kK * kKIter;
|
||||
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
|
||||
static constexpr index_t kCMLane = Impl::kCMLane;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
|
||||
|
||||
static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
|
||||
"Multi-block on both M & N directions is not supported");
|
||||
|
||||
template <index_t kMNLane, index_t kMNBlock, index_t kNMBlock, index_t AttrNumAccessV_>
|
||||
CK_TILE_DEVICE static constexpr auto get_warp_dstr_encoding()
|
||||
{
|
||||
if constexpr(kMNBlock == 1 && kNMBlock == 1)
|
||||
{
|
||||
static_assert(kKPerThread % AttrNumAccessV_ == 0,
|
||||
"kKPerThread must be divisible by NumAccess");
|
||||
if constexpr(AttrNumAccessV_ == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(UsePackNumAccess)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMNLane>,
|
||||
sequence<Impl::kABKLane,
|
||||
AttrNumAccessV_,
|
||||
Impl::kABKPerLane * kKIter / AttrNumAccessV_>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<1, 2>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMNLane>,
|
||||
sequence<AttrNumAccessV_,
|
||||
Impl::kABKLane,
|
||||
Impl::kABKPerLane * kKIter / AttrNumAccessV_>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(kMNBlock == 1 && 1 < kNMBlock)
|
||||
{
|
||||
static_assert(AttrNumAccessV_ == 1,
|
||||
"Multiple access is not supported when using multi-block");
|
||||
// each M/N blocks share the same data
|
||||
return tile_distribution_encoding<
|
||||
sequence<kNMBlock>,
|
||||
tuple<sequence<kMNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<0, 2, 1>>,
|
||||
tuple<sequence<0, 0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else if constexpr(1 < kMNBlock && kNMBlock == 1)
|
||||
{
|
||||
static_assert(AttrNumAccessV_ == 1,
|
||||
"Multiple access is not supported when using multi-block");
|
||||
// single block to multi-block thread mapping
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMNBlock, kMNLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<1, 2, 1>>,
|
||||
tuple<sequence<0, 0, 1>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
|
||||
{
|
||||
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
|
||||
sequence<Impl::kCNLane>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 1>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
|
||||
sequence<Impl::kBNBlock * Impl::kCNLane>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 1>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<
|
||||
sequence<Impl::kCM0PerLane, Impl::kAMBlock * Impl::kCMLane, Impl::kCM1PerLane>,
|
||||
sequence<Impl::kCNLane>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 1>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
}
|
||||
|
||||
using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane,
|
||||
Impl::kAMBlock,
|
||||
Impl::kBNBlock,
|
||||
AttrNumAccessAV>());
|
||||
using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane,
|
||||
Impl::kBNBlock,
|
||||
Impl::kAMBlock,
|
||||
AttrNumAccessBV>());
|
||||
using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding());
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
static_for<0, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
|
||||
}
|
||||
|
||||
template <index_t iKIter, bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
number<iKIter>,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
|
||||
static_assert(iKIter < kKIter);
|
||||
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter],
|
||||
bool_constant<post_nop_>{});
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
constexpr auto I0 = number<0>{};
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
|
||||
// c = a * b
|
||||
auto c_vec = Impl{}(
|
||||
reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
|
||||
reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
|
||||
|
||||
// c += a * b
|
||||
static_for<1, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
|
||||
|
||||
return c_vec;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_,
|
||||
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
|
||||
struct WarpGemmAttributeMfmaTransposedCDistribution
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
static constexpr auto AttrNumAccess = AttrNumAccess_;
|
||||
static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
|
||||
|
||||
using ADataType = typename Impl::BDataType;
|
||||
using BDataType = typename Impl::ADataType;
|
||||
using CDataType = typename Impl::CDataType;
|
||||
|
||||
using AVecType = typename Impl::BVecType;
|
||||
using BVecType = typename Impl::AVecType;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kN;
|
||||
static constexpr index_t kN = Impl::kM;
|
||||
static constexpr index_t kK = Impl::kK;
|
||||
static constexpr index_t kKPerThread = Impl::kABKPerLane;
|
||||
static constexpr index_t kCMLane = Impl::kCMLane;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
|
||||
|
||||
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
|
||||
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
|
||||
|
||||
using AWarpDstrEncoding =
|
||||
typename WarpGemmAttributeMfma<Impl, AttrNumAccess>::BWarpDstrEncoding;
|
||||
using BWarpDstrEncoding =
|
||||
typename WarpGemmAttributeMfma<Impl, AttrNumAccess>::AWarpDstrEncoding;
|
||||
|
||||
using CWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kCNLane>,
|
||||
sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
// swap A and B
|
||||
Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
|
||||
}
|
||||
|
||||
template <index_t opselA, index_t opselB, bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const int32_t& a_scale,
|
||||
const BVecType& b_vec,
|
||||
const int32_t& b_scale,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
// swap A and B
|
||||
Impl{}.template operator()<opselB, opselA>(
|
||||
c_vec, b_vec, b_scale, a_vec, a_scale, bool_constant<post_nop_>{});
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
// swap A and B
|
||||
return Impl{}(b_vec, a_vec);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_, index_t SFactor_ = 2>
|
||||
struct WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
|
||||
using ADataType = typename Impl::BDataType;
|
||||
using BDataType = typename Impl::ADataType;
|
||||
using CDataType = typename Impl::CDataType;
|
||||
|
||||
using AVecType = typename Impl::BVecType;
|
||||
using BVecType = typename Impl::AVecType;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kN;
|
||||
static constexpr index_t kN = Impl::kM;
|
||||
static constexpr index_t kK = Impl::kK;
|
||||
static constexpr index_t kKPerThread = Impl::kABKPerLane;
|
||||
static constexpr index_t SFactor = SFactor_; // group how many CM1 together
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
|
||||
|
||||
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
|
||||
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
|
||||
|
||||
using AWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
#if 0
|
||||
using BWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane / (Impl::kABKPerLane * Impl::kABKLane * 2),
|
||||
Impl::kABKLane,
|
||||
2,
|
||||
Impl::kABKPerLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1, 1, 1, 1>>,
|
||||
tuple<sequence<0, 0, 2, 1, 3>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using CWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kCNLane>,
|
||||
sequence<Impl::kCM0PerLane / 2, Impl::kCMLane, Impl::kCM1PerLane * 2>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>;
|
||||
#else
|
||||
// TODO: more test not only 32x32
|
||||
using BWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
|
||||
Impl::kCMLane,
|
||||
SFactor,
|
||||
Impl::kCM1PerLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1, 1, 1, 1>>,
|
||||
tuple<sequence<0, 0, 2, 1, 3>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using CWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kCNLane>,
|
||||
sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>;
|
||||
#endif
|
||||
template <bool post_nop_ = false>
|
||||
// c_vec += a_vec * b_vec
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
// swap A and B
|
||||
Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
|
||||
}
|
||||
|
||||
template <index_t opselA, index_t opselB, bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const int32_t& a_scale,
|
||||
const BVecType& b_vec,
|
||||
const int32_t& b_scale,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
// swap A and B
|
||||
Impl{}.template operator()<opselB, opselA>(
|
||||
c_vec, b_vec, b_scale, a_vec, a_scale, bool_constant<post_nop_>{});
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
// swap A and B
|
||||
return Impl{}(b_vec, a_vec);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_,
|
||||
index_t kKIter,
|
||||
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
|
||||
struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
static constexpr auto AttrNumAccess = AttrNumAccess_;
|
||||
|
||||
// swap A and B
|
||||
using ADataType = typename Impl::BDataType;
|
||||
using BDataType = typename Impl::ADataType;
|
||||
using CDataType = typename Impl::CDataType;
|
||||
|
||||
using AVecType =
|
||||
ext_vector_t<ADataType, vector_traits<typename Impl::AVecType>::vector_size * kKIter>;
|
||||
using BVecType =
|
||||
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kN;
|
||||
static constexpr index_t kN = Impl::kM;
|
||||
static constexpr index_t kK = Impl::kK * kKIter;
|
||||
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
|
||||
|
||||
static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
|
||||
"Multi-block on both M & N directions is not supported");
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
|
||||
{
|
||||
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kCNLane>,
|
||||
sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNBlock * Impl::kCNLane>,
|
||||
sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<
|
||||
sequence<Impl::kCNLane>,
|
||||
sequence<Impl::kCM0PerLane, Impl::kAMBlock * Impl::kCMLane, Impl::kCM1PerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
}
|
||||
|
||||
using AWarpDstrEncoding =
|
||||
typename WarpGemmAttributeMfmaIterateK<Impl, kKIter, AttrNumAccess>::BWarpDstrEncoding;
|
||||
using BWarpDstrEncoding =
|
||||
typename WarpGemmAttributeMfmaIterateK<Impl, kKIter, AttrNumAccess>::AWarpDstrEncoding;
|
||||
using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding());
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
static_for<0, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
|
||||
}
|
||||
|
||||
template <index_t iKIter, bool post_nop_ = false>
|
||||
// c_vec += a_vec * b_vec
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
number<iKIter>,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
|
||||
static_assert(iKIter < kKIter);
|
||||
// swap A and B, value and type
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
bool_constant<post_nop_>{});
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
constexpr auto I0 = number<0>{};
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
|
||||
// swap A and B, value and type
|
||||
auto c_vec = Impl{}(
|
||||
reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0],
|
||||
reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0]);
|
||||
|
||||
static_for<1, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
|
||||
|
||||
return c_vec;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter, index_t SFactor_ = 2>
|
||||
struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
|
||||
// swap A and B
|
||||
using ADataType = typename Impl::BDataType;
|
||||
using BDataType = typename Impl::ADataType;
|
||||
using CDataType = typename Impl::CDataType;
|
||||
|
||||
using AVecType =
|
||||
ext_vector_t<ADataType, vector_traits<typename Impl::AVecType>::vector_size * kKIter>;
|
||||
using BVecType =
|
||||
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kN;
|
||||
static constexpr index_t kN = Impl::kM;
|
||||
static constexpr index_t kK = Impl::kK * kKIter;
|
||||
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
|
||||
static constexpr index_t SFactor = SFactor_; // group how many CM1 together
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
|
||||
|
||||
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
|
||||
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
|
||||
|
||||
using AWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
#if 0
|
||||
using BWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane / (Impl::kABKPerLane * Impl::kABKLane * 2),
|
||||
Impl::kABKLane,
|
||||
2,
|
||||
Impl::kABKPerLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1, 1, 1, 1>>,
|
||||
tuple<sequence<0, 0, 2, 1, 3>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using CWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kCNLane>,
|
||||
sequence<Impl::kCM0PerLane / 2, Impl::kCMLane, Impl::kCM1PerLane * 2>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>;
|
||||
#else
|
||||
// TODO: more test not only 32x32
|
||||
using BWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
|
||||
Impl::kCMLane,
|
||||
SFactor,
|
||||
Impl::kCM1PerLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1, 1, 1, 1>>,
|
||||
tuple<sequence<0, 0, 2, 1, 3>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using CWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kCNLane>,
|
||||
sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>;
|
||||
#endif
|
||||
// c_vec += a_vec * b_vec
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
// swap A and B, value and type
|
||||
static_for<0, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
|
||||
}
|
||||
|
||||
template <index_t iKIter, bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
number<iKIter>,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
|
||||
static_assert(iKIter < kKIter);
|
||||
// swap A and B, value and type
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
bool_constant<post_nop_>{});
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
// swap A and B, value and type
|
||||
auto c_vec = Impl{}(
|
||||
reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0],
|
||||
reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0]);
|
||||
|
||||
static_for<1, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
|
||||
|
||||
return c_vec;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter, index_t SFactor_ = 2>
|
||||
struct WarpGemmAttributeMfmaIterateK_SwizzleA
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
|
||||
using ADataType = typename Impl::ADataType;
|
||||
using BDataType = typename Impl::BDataType;
|
||||
using CDataType = typename Impl::CDataType;
|
||||
|
||||
using AVecType =
|
||||
ext_vector_t<ADataType, vector_traits<typename Impl::AVecType>::vector_size * kKIter>;
|
||||
using BVecType =
|
||||
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kM;
|
||||
static constexpr index_t kN = Impl::kN;
|
||||
static constexpr index_t kK = Impl::kK * kKIter;
|
||||
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
|
||||
static constexpr index_t SFactor = SFactor_; // group how many CM1 together
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
|
||||
|
||||
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
|
||||
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
|
||||
|
||||
using AWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
|
||||
Impl::kCMLane,
|
||||
SFactor,
|
||||
Impl::kCM1PerLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1, 1, 1, 1>>,
|
||||
tuple<sequence<0, 0, 2, 1, 3>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using BWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using CWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>,
|
||||
sequence<Impl::kCNLane>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 1>,
|
||||
sequence<0, 2>>;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
static_for<0, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
|
||||
}
|
||||
|
||||
template <index_t iKIter, bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
number<iKIter>,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
|
||||
static_assert(iKIter < kKIter);
|
||||
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter],
|
||||
bool_constant<post_nop_>{});
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
constexpr auto I0 = number<0>{};
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
|
||||
auto c_vec = Impl{}(
|
||||
reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
|
||||
reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
|
||||
|
||||
static_for<1, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
|
||||
|
||||
return c_vec;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
2064
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
Normal file
2064
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
Normal file
File diff suppressed because it is too large
Load Diff
85
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp
Normal file
85
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp
Normal file
@@ -0,0 +1,85 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief Class describing structured sparsity mfma instructions.
|
||||
*
|
||||
* @paragraph Overview "Overview"
|
||||
* Currently only 2:4 structured sparsity is supported, which is based on requirement that in every
|
||||
* groups of four continuous elements there are at most two non-zero, which results in processing
|
||||
* only half of elements in smfmac instruction. Because of structured sparsity A vector in smfmac
|
||||
* instruction will be smaller than B vector by the factor of CompressionRatio. The indexes of
|
||||
* non-zero elements are stored in `index` which is an additional parameter to assembly instruction.
|
||||
* Every pair of two bit indexes are containing information about which two elements in current
|
||||
* group of 4 values are non-zero and should be used inside smfmac instruction. Structured sparsity
|
||||
* format is supported only for A matrix for now.
|
||||
*/
|
||||
template <typename WarpGemmAttributeSmfmacImpl_>
|
||||
struct WarpGemmAttributeSmfmac
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeSmfmacImpl_>;
|
||||
|
||||
using ADataType = typename Impl::ADataType;
|
||||
using BDataType = typename Impl::BDataType;
|
||||
using IdxDataType = typename Impl::IdxDataType;
|
||||
using CDataType = typename Impl::CDataType;
|
||||
|
||||
using AVecType = typename Impl::AVecType;
|
||||
using BVecType = typename Impl::BVecType;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kM;
|
||||
static constexpr index_t kN = Impl::kN;
|
||||
static constexpr index_t kK = Impl::kK;
|
||||
static constexpr index_t kKPerThread = Impl::kABKPerLane;
|
||||
static constexpr index_t kCompressionRatio = Impl::CompressionRatio;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
|
||||
|
||||
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
|
||||
"Multi-block WarpGemmAttributeSmfmacImpl is not supported");
|
||||
|
||||
using AWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using BWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using CWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
|
||||
sequence<Impl::kCNLane>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 1>,
|
||||
sequence<0, 2>>;
|
||||
|
||||
// c_vec += a_vec * b_vec[idx]
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
const int32_t& idx,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
Impl{}(c_vec, a_vec, b_vec, idx, bool_constant<post_nop_>{});
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,114 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "warp_gemm_attribute_mfma_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// fp16 2:4 structured sparsity
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeSmfmacImplF16F16F32M32N32K16
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = fp16_t;
|
||||
using BDataType = fp16_t;
|
||||
using IdxDataType = int32_t;
|
||||
using CDataType = float;
|
||||
|
||||
using AVecType = ext_vector_t<fp16_t, 4>;
|
||||
using BVecType = ext_vector_t<fp16_t, 8>;
|
||||
using CVecType = ext_vector_t<float, 16>;
|
||||
|
||||
static constexpr index_t kM = 32;
|
||||
static constexpr index_t kN = 32;
|
||||
static constexpr index_t kK = 16;
|
||||
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kAMLane = 32;
|
||||
static constexpr index_t kBNLane = 32;
|
||||
static constexpr index_t kABKLane = 2;
|
||||
static constexpr index_t kABKPerLane = 8;
|
||||
|
||||
static constexpr index_t kCMLane = 2;
|
||||
static constexpr index_t kCNLane = 32;
|
||||
static constexpr index_t kCM0PerLane = 4;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
static constexpr index_t CompressionRatio = 2;
|
||||
|
||||
// c_vec += a_vec * b_vec[idx]
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
const int32_t& idx,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
#if defined(__gfx94_) or defined(__gfx95_)
|
||||
c_vec = __builtin_amdgcn_smfmac_f32_32x32x16_f16(a_vec, b_vec, c_vec, idx, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = idx;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeSmfmacImplF16F16F32M16N16K32
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = fp16_t;
|
||||
using BDataType = fp16_t;
|
||||
using IdxDataType = int32_t;
|
||||
using CDataType = float;
|
||||
|
||||
using AVecType = ext_vector_t<fp16_t, 4>;
|
||||
using BVecType = ext_vector_t<fp16_t, 8>;
|
||||
using CVecType = ext_vector_t<float, 4>;
|
||||
|
||||
static constexpr index_t kM = 16;
|
||||
static constexpr index_t kN = 16;
|
||||
static constexpr index_t kK = 32;
|
||||
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 4;
|
||||
static constexpr index_t kABKPerLane = 8;
|
||||
|
||||
static constexpr index_t kCMLane = 4;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
static constexpr index_t CompressionRatio = 2;
|
||||
|
||||
// c_vec += a_vec * b_vec[idx]
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
const int32_t& idx,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
#if defined(__gfx94_) or defined(__gfx95_)
|
||||
c_vec = __builtin_amdgcn_smfmac_f32_16x16x32_f16(a_vec, b_vec, c_vec, idx, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = idx;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
176
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp
Normal file
176
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp
Normal file
@@ -0,0 +1,176 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// TODO: currently only support 16 bit input, which means only support tr16_b128; will use ADataType
|
||||
// to determine the layout in the future
|
||||
template <typename Impl>
|
||||
struct AWarpDstrEncodingTrait
|
||||
{
|
||||
using type = tile_distribution_encoding<
|
||||
sequence<Impl::kRepeat>,
|
||||
tuple<sequence<Impl::kAMLane>,
|
||||
sequence<Impl::kABK0PerLane, Impl::kABKLane, Impl::kABK1PerLane>>,
|
||||
tuple<typename Impl::kABPs2RHssMajor>,
|
||||
tuple<typename Impl::kABPs2RHssMinor>,
|
||||
typename Impl::kABYs2RHsMajor,
|
||||
typename Impl::kABYs2RHsMinor>;
|
||||
};
|
||||
|
||||
template <typename Impl>
|
||||
struct BWarpDstrEncodingTrait
|
||||
{
|
||||
using type = tile_distribution_encoding<
|
||||
sequence<Impl::kRepeat>,
|
||||
tuple<sequence<Impl::kBNLane>,
|
||||
sequence<Impl::kABK0PerLane, Impl::kABKLane, Impl::kABK1PerLane>>,
|
||||
tuple<typename Impl::kABPs2RHssMajor>,
|
||||
tuple<typename Impl::kABPs2RHssMinor>,
|
||||
typename Impl::kABYs2RHsMajor,
|
||||
typename Impl::kABYs2RHsMinor>;
|
||||
};
|
||||
|
||||
template <typename Impl>
|
||||
struct CWarpDstrEncodingTrait
|
||||
{
|
||||
using type = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
|
||||
sequence<Impl::kCNLane>>,
|
||||
tuple<typename Impl::kCPs2RHssMajor>,
|
||||
tuple<typename Impl::kCPs2RHssMinor>,
|
||||
typename Impl::kCYs2RHsMajor,
|
||||
typename Impl::kCYs2RHsMinor>;
|
||||
};
|
||||
|
||||
template <typename Impl>
|
||||
struct CTransposedWarpDstrEncodingTrait
|
||||
{
|
||||
using type = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kCNLane>,
|
||||
sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
|
||||
tuple<typename Impl::kCTPs2RHssMajor>,
|
||||
tuple<typename Impl::kCTPs2RHssMinor>,
|
||||
typename Impl::kCTYs2RHsMajor,
|
||||
typename Impl::kCTYs2RHsMinor>;
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeWmmaImpl_, bool kTransC = false>
|
||||
struct WarpGemmAttributeWmma
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeWmmaImpl_>;
|
||||
|
||||
// When kTransC is true and A/B types differ, we need an impl with swapped types
|
||||
using TransposedImpl =
|
||||
std::conditional_t<kTransC &&
|
||||
!std::is_same_v<typename Impl::ADataType, typename Impl::BDataType>,
|
||||
WarpGemmAttributeWmmaImpl<WmmaTraits<typename Impl::TraitsType::ArchType,
|
||||
typename Impl::BDataType,
|
||||
typename Impl::ADataType,
|
||||
typename Impl::CDataType,
|
||||
Impl::kM,
|
||||
Impl::kN,
|
||||
Impl::kK>>,
|
||||
Impl>;
|
||||
|
||||
using ADataType = typename Impl::ADataType;
|
||||
using BDataType = typename Impl::BDataType;
|
||||
using CDataType = typename Impl::CDataType;
|
||||
|
||||
using AVecType = typename Impl::AVecType;
|
||||
using BVecType = typename Impl::BVecType;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kM;
|
||||
static constexpr index_t kN = Impl::kN;
|
||||
static constexpr index_t kK = Impl::kK;
|
||||
static constexpr index_t kCMLane = Impl::kCMLane;
|
||||
static constexpr index_t kKPerThread = Impl::kABK0PerLane * Impl::kABK1PerLane;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
|
||||
|
||||
// 16 bit input, kAMLane = 16, kABK0PerLane = 4, kABKLane = 2, kABK1PerLane = 2
|
||||
// 8 bit input, kAMLane = 16, kABK0PerLane = 2, kABKLane = 2, kABK1PerLane = 4
|
||||
using AWarpDstrEncoding = typename AWarpDstrEncodingTrait<Impl>::type;
|
||||
using BWarpDstrEncoding = typename BWarpDstrEncodingTrait<Impl>::type;
|
||||
|
||||
// kCM0PerLane = 1, kCMLane = 2, kCM1PerLane = 2, kCNLane = 16
|
||||
using CWarpDstrEncoding =
|
||||
std::conditional_t<kTransC,
|
||||
typename CTransposedWarpDstrEncodingTrait<Impl>::type,
|
||||
typename CWarpDstrEncodingTrait<Impl>::type>;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
if constexpr(kTransC)
|
||||
{
|
||||
TransposedImpl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Impl{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
|
||||
}
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
if constexpr(kTransC)
|
||||
{
|
||||
return TransposedImpl{}(b_vec, a_vec);
|
||||
}
|
||||
else
|
||||
{
|
||||
return Impl{}(a_vec, b_vec);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
index_t M_Warp_Tile,
|
||||
index_t N_Warp_Tile,
|
||||
index_t K_Warp_Tile>
|
||||
CK_TILE_HOST bool check_wmma_supported()
|
||||
{
|
||||
if(is_gfx12_supported())
|
||||
{
|
||||
return has_wmma_traits_v<gfx12_t,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile>;
|
||||
}
|
||||
else if(is_gfx11_supported())
|
||||
{
|
||||
return has_wmma_traits_v<gfx11_t,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile>;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
141
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp
Normal file
141
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp
Normal file
@@ -0,0 +1,141 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Base traits for WMMA operations
|
||||
template <typename Arch,
|
||||
typename AType,
|
||||
typename BType,
|
||||
typename CType,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K>
|
||||
struct WmmaTraits;
|
||||
|
||||
// Generic WMMA implementation using traits
|
||||
template <typename Traits>
|
||||
struct WarpGemmAttributeWmmaImpl
|
||||
{
|
||||
using TraitsType = Traits;
|
||||
using ADataType = typename Traits::ADataType;
|
||||
using BDataType = typename Traits::BDataType;
|
||||
using CDataType = typename Traits::CDataType;
|
||||
|
||||
using AVecType = typename Traits::AVecType;
|
||||
using BVecType = typename Traits::BVecType;
|
||||
using CVecType = typename Traits::CVecType;
|
||||
|
||||
// Forward all static constants and type aliases
|
||||
static constexpr index_t kM = Traits::kM;
|
||||
static constexpr index_t kN = Traits::kN;
|
||||
static constexpr index_t kK = Traits::kK;
|
||||
|
||||
static constexpr index_t kAMBlock = Traits::kAMBlock;
|
||||
static constexpr index_t kBNBlock = Traits::kBNBlock;
|
||||
|
||||
static constexpr index_t kRepeat = Traits::kRepeat;
|
||||
static constexpr index_t kAMLane = Traits::kAMLane;
|
||||
static constexpr index_t kBNLane = Traits::kBNLane;
|
||||
static constexpr index_t kABK0PerLane = Traits::kABK0PerLane;
|
||||
static constexpr index_t kABKLane = Traits::kABKLane;
|
||||
static constexpr index_t kABK1PerLane = Traits::kABK1PerLane;
|
||||
|
||||
static constexpr index_t kCMLane = Traits::kCMLane;
|
||||
static constexpr index_t kCNLane = Traits::kCNLane;
|
||||
static constexpr index_t kCM0PerLane = Traits::kCM0PerLane;
|
||||
static constexpr index_t kCM1PerLane = Traits::kCM1PerLane;
|
||||
|
||||
using kABPs2RHssMajor = typename Traits::kABPs2RHssMajor;
|
||||
using kABPs2RHssMinor = typename Traits::kABPs2RHssMinor;
|
||||
using kABYs2RHsMajor = typename Traits::kABYs2RHsMajor;
|
||||
using kABYs2RHsMinor = typename Traits::kABYs2RHsMinor;
|
||||
|
||||
using kCPs2RHssMajor = typename Traits::kCPs2RHssMajor;
|
||||
using kCPs2RHssMinor = typename Traits::kCPs2RHssMinor;
|
||||
using kCYs2RHsMajor = typename Traits::kCYs2RHsMajor;
|
||||
using kCYs2RHsMinor = typename Traits::kCYs2RHsMinor;
|
||||
|
||||
using kCTPs2RHssMajor = typename Traits::kCTPs2RHssMajor;
|
||||
using kCTPs2RHssMinor = typename Traits::kCTPs2RHssMinor;
|
||||
using kCTYs2RHsMajor = typename Traits::kCTYs2RHsMajor;
|
||||
using kCTYs2RHsMinor = typename Traits::kCTYs2RHsMinor;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <bool clamp = false, bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
c_vec = Traits::template wmma_intrinsic<clamp>(a_vec, b_vec, c_vec);
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
return bit_cast<CVecType>(
|
||||
Traits::template wmma_intrinsic<clamp>(a_vec, b_vec, CVecType{0.f}));
|
||||
}
|
||||
};
|
||||
|
||||
using DeviceIp = remove_cvref_t<decltype(ck_tile::get_device_arch())>;
|
||||
using WarpGemmAttributeWmmaImpl_f32_16x16x16_f16_f16 =
|
||||
WarpGemmAttributeWmmaImpl<WmmaTraits<DeviceIp, fp16_t, fp16_t, float, 16, 16, 16>>;
|
||||
|
||||
using WarpGemmAttributeWmmaImpl_f32_16x16x16_bf16_bf16 =
|
||||
WarpGemmAttributeWmmaImpl<WmmaTraits<DeviceIp, bf16_t, bf16_t, float, 16, 16, 16>>;
|
||||
|
||||
using WarpGemmAttributeWmmaImpl_i32_16x16x16_i8_i8 =
|
||||
WarpGemmAttributeWmmaImpl<WmmaTraits<DeviceIp, int8_t, int8_t, int32_t, 16, 16, 16>>;
|
||||
|
||||
using WarpGemmAttributeWmmaImpl_f32_16x16x16_f8_f8 =
|
||||
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx12_t, fp8_t, fp8_t, float, 16, 16, 16>>;
|
||||
|
||||
using WarpGemmAttributeWmmaImpl_f32_16x16x16_bf8_bf8 =
|
||||
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx12_t, bf8_t, bf8_t, float, 16, 16, 16>>;
|
||||
|
||||
using WarpGemmAttributeWmmaImpl_f32_16x16x16_f8_bf8 =
|
||||
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx12_t, fp8_t, bf8_t, float, 16, 16, 16>>;
|
||||
|
||||
using WarpGemmAttributeWmmaImpl_f32_16x16x16_bf8_f8 =
|
||||
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx12_t, bf8_t, fp8_t, float, 16, 16, 16>>;
|
||||
|
||||
template <typename Arch,
|
||||
typename AType,
|
||||
typename BType,
|
||||
typename CType,
|
||||
index_t warp_m,
|
||||
index_t warp_n,
|
||||
index_t warp_k>
|
||||
struct has_wmma_traits
|
||||
{
|
||||
template <typename T>
|
||||
static auto
|
||||
test(int) -> decltype(std::declval<
|
||||
typename WmmaTraits<T, AType, BType, CType, warp_m, warp_n, warp_k>::
|
||||
ADataType>(),
|
||||
std::true_type{});
|
||||
|
||||
template <typename>
|
||||
static std::false_type test(...);
|
||||
|
||||
static constexpr bool value = decltype(test<Arch>(0))::value;
|
||||
};
|
||||
|
||||
template <typename Arch,
|
||||
typename AType,
|
||||
typename BType,
|
||||
typename CType,
|
||||
index_t warp_m,
|
||||
index_t warp_n,
|
||||
index_t warp_k>
|
||||
constexpr bool has_wmma_traits_v =
|
||||
has_wmma_traits<Arch, AType, BType, CType, warp_m, warp_n, warp_k>::value;
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,95 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "warp_gemm_attribute_wmma_impl_base_traits.hpp"
|
||||
namespace ck_tile {
|
||||
// fp16 specialization - GFX11
|
||||
template <>
|
||||
struct WmmaTraits<gfx11_t, fp16_t, fp16_t, float, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx11_t, fp16_t, fp16_t, float>
|
||||
{
|
||||
using ArchType = gfx11_t;
|
||||
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
{
|
||||
#ifdef __gfx11__
|
||||
return __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_vec, b_vec, c_vec);
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = c_vec;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// bf16 specialization - GFX11
|
||||
template <>
|
||||
struct WmmaTraits<gfx11_t, bf16_t, bf16_t, float, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx11_t, bf16_t, bf16_t, float>
|
||||
{
|
||||
using ArchType = gfx11_t;
|
||||
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
{
|
||||
#ifdef __gfx11__
|
||||
return __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(a_vec, b_vec, c_vec);
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = c_vec;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// fp16 specialization - GFX12
|
||||
template <>
|
||||
struct WmmaTraits<gfx12_t, fp16_t, fp16_t, float, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx12_t, fp16_t, fp16_t, float>
|
||||
{
|
||||
using ArchType = gfx12_t;
|
||||
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
return __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_vec, b_vec, c_vec);
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = c_vec;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// bf16 specialization - GFX12
|
||||
template <>
|
||||
struct WmmaTraits<gfx12_t, bf16_t, bf16_t, float, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx12_t, bf16_t, bf16_t, float>
|
||||
{
|
||||
using ArchType = gfx12_t;
|
||||
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
return __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_vec, b_vec, c_vec);
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = c_vec;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,148 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "warp_gemm_attribute_wmma_impl_base_traits.hpp"
|
||||
namespace ck_tile {
|
||||
// int8 specialization - GFX11
|
||||
template <>
|
||||
struct WmmaTraits<gfx11_t, int8_t, int8_t, int32_t, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx11_t, int8_t, int8_t, int32_t>
|
||||
{
|
||||
using ArchType = gfx11_t;
|
||||
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
{
|
||||
#ifdef __gfx11__
|
||||
return __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, // neg_a
|
||||
bit_cast<int32x4_t>(a_vec),
|
||||
true, // neg_b
|
||||
bit_cast<int32x4_t>(b_vec),
|
||||
bit_cast<int32x8_t>(c_vec),
|
||||
clamp);
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = c_vec;
|
||||
return CVecType{0};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// int8 specialization - GFX12
|
||||
template <>
|
||||
struct WmmaTraits<gfx12_t, int8_t, int8_t, int32_t, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx12_t, int8_t, int8_t, int32_t>
|
||||
{
|
||||
using ArchType = gfx12_t;
|
||||
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
return __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, // neg_a
|
||||
bit_cast<int32x2_t>(a_vec),
|
||||
true, // neg_b
|
||||
bit_cast<int32x2_t>(b_vec),
|
||||
bit_cast<int32x8_t>(c_vec),
|
||||
clamp);
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = c_vec;
|
||||
return CVecType{0};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// fp8/bf8 specialization - GFX12
|
||||
template <>
|
||||
struct WmmaTraits<gfx12_t, fp8_t, fp8_t, float, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx12_t, fp8_t, fp8_t, float>
|
||||
{
|
||||
using ArchType = gfx12_t;
|
||||
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
return __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12(
|
||||
bit_cast<int32x2_t>(a_vec), bit_cast<int32x2_t>(b_vec), bit_cast<fp32x8_t>(c_vec));
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = c_vec;
|
||||
return CVecType{0};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct WmmaTraits<gfx12_t, bf8_t, bf8_t, float, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx12_t, bf8_t, bf8_t, float>
|
||||
{
|
||||
using ArchType = gfx12_t;
|
||||
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
return __builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12(
|
||||
bit_cast<int32x2_t>(a_vec), bit_cast<int32x2_t>(b_vec), bit_cast<fp32x8_t>(c_vec));
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = c_vec;
|
||||
return CVecType{0};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct WmmaTraits<gfx12_t, fp8_t, bf8_t, float, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx12_t, fp8_t, bf8_t, float>
|
||||
{
|
||||
using ArchType = gfx12_t;
|
||||
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
return __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12(
|
||||
bit_cast<int32x2_t>(a_vec), bit_cast<int32x2_t>(b_vec), bit_cast<fp32x8_t>(c_vec));
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = c_vec;
|
||||
return CVecType{0};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct WmmaTraits<gfx12_t, bf8_t, fp8_t, float, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx12_t, bf8_t, fp8_t, float>
|
||||
{
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
return __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12(
|
||||
bit_cast<int32x2_t>(a_vec), bit_cast<int32x2_t>(b_vec), bit_cast<fp32x8_t>(c_vec));
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = c_vec;
|
||||
return CVecType{0};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,106 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
namespace ck_tile {
|
||||
// Architecture-keyed base holding the element/vector types, tile and lane constants
// and mapping sequences shared by the WmmaTraits specializations. Declared only here;
// it is partially specialized below for gfx11_t and gfx12_t.
template <typename Arch, typename ADType, typename BDType, typename CDType>
|
||||
struct WmmaTraitsBase;
|
||||
|
||||
// GFX11 specialization
|
||||
template <typename ADType, typename BDType, typename CDType>
|
||||
struct WmmaTraitsBase<gfx11_t, ADType, BDType, CDType>
|
||||
{
|
||||
using ArchType = gfx11_t;
|
||||
|
||||
using ADataType = ADType;
|
||||
using BDataType = BDType;
|
||||
using CDataType = CDType;
|
||||
|
||||
using AVecType = ext_vector_t<ADataType, 16>;
|
||||
using BVecType = ext_vector_t<BDataType, 16>;
|
||||
using CVecType = ext_vector_t<CDataType, 8>;
|
||||
|
||||
static constexpr index_t kM = 16;
|
||||
static constexpr index_t kN = 16;
|
||||
static constexpr index_t kK = 16;
|
||||
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kRepeat = 2;
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABK0PerLane = 1;
|
||||
static constexpr index_t kABKLane = 1;
|
||||
static constexpr index_t kABK1PerLane = 16;
|
||||
|
||||
static constexpr index_t kCMLane = 2;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 8;
|
||||
static constexpr index_t kCM1PerLane = 1;
|
||||
|
||||
using kABPs2RHssMajor = sequence<0, 2, 1>;
|
||||
using kABPs2RHssMinor = sequence<0, 1, 0>;
|
||||
using kABYs2RHsMajor = sequence<2, 2>;
|
||||
using kABYs2RHsMinor = sequence<0, 2>;
|
||||
|
||||
using kCPs2RHssMajor = sequence<1, 2>;
|
||||
using kCPs2RHssMinor = sequence<1, 0>;
|
||||
using kCYs2RHsMajor = sequence<1, 1>;
|
||||
using kCYs2RHsMinor = sequence<0, 2>;
|
||||
|
||||
using kCTPs2RHssMajor = sequence<2, 1>;
|
||||
using kCTPs2RHssMinor = sequence<1, 0>;
|
||||
using kCTYs2RHsMajor = sequence<2, 2>;
|
||||
using kCTYs2RHsMinor = sequence<0, 2>;
|
||||
};
|
||||
|
||||
// GFX12 specialization
|
||||
template <typename ADType, typename BDType, typename CDType>
|
||||
struct WmmaTraitsBase<gfx12_t, ADType, BDType, CDType>
|
||||
{
|
||||
using ArchType = gfx12_t;
|
||||
|
||||
using ADataType = ADType;
|
||||
using BDataType = BDType;
|
||||
using CDataType = CDType;
|
||||
|
||||
using AVecType = ext_vector_t<ADataType, 8>;
|
||||
using BVecType = ext_vector_t<BDataType, 8>;
|
||||
using CVecType = ext_vector_t<CDataType, 8>;
|
||||
|
||||
static constexpr index_t kM = 16;
|
||||
static constexpr index_t kN = 16;
|
||||
static constexpr index_t kK = 16;
|
||||
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kRepeat = 1;
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABK0PerLane = 1;
|
||||
static constexpr index_t kABKLane = 2;
|
||||
static constexpr index_t kABK1PerLane = 8;
|
||||
|
||||
static constexpr index_t kCMLane = 2;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 8;
|
||||
|
||||
using kABPs2RHssMajor = sequence<2, 1>;
|
||||
using kABPs2RHssMinor = sequence<1, 0>;
|
||||
using kABYs2RHsMajor = sequence<2, 2>;
|
||||
using kABYs2RHsMinor = sequence<0, 2>;
|
||||
|
||||
using kCPs2RHssMajor = sequence<1, 2>;
|
||||
using kCPs2RHssMinor = sequence<1, 0>;
|
||||
using kCYs2RHsMajor = sequence<1, 1>;
|
||||
using kCYs2RHsMinor = sequence<0, 2>;
|
||||
|
||||
using kCTPs2RHssMajor = sequence<2, 1>;
|
||||
using kCTPs2RHssMinor = sequence<1, 0>;
|
||||
using kCTYs2RHsMajor = sequence<2, 2>;
|
||||
using kCTYs2RHsMinor = sequence<0, 2>;
|
||||
};
|
||||
} // namespace ck_tile
|
||||
207
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
Normal file
207
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
Normal file
@@ -0,0 +1,207 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace impl {
|
||||
namespace warp_gemm_dispatcher {
|
||||
|
||||
// C++20 using enum
|
||||
static inline constexpr auto ESingle = WGAttrNumAccessEnum::Single;
|
||||
static inline constexpr auto EDouble = WGAttrNumAccessEnum::Double;
|
||||
static inline constexpr auto EQuad = WGAttrNumAccessEnum::Quad;
|
||||
|
||||
template <typename AType,
|
||||
typename BType,
|
||||
typename AccType,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
bool TransposeC,
|
||||
bool SwizzleA = false,
|
||||
bool UseStructuredSparsity = false,
|
||||
WGAttrNumAccessEnum AttrNumAccessA = ESingle,
|
||||
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
|
||||
struct Dispatcher;
|
||||
|
||||
// clang-format off
|
||||
// fp32
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
// Arguments omitted at the end of a row take the defaults declared on the primary Dispatcher template.
|
||||
template<> struct Dispatcher<float, float, float, 16, 16, 4, false> { using Type = WarpGemmMfmaF32F32F32M16N16K4; };
|
||||
template<> struct Dispatcher<float, float, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF32F32F32M16N16K16<>; };
|
||||
template<> struct Dispatcher<float, float, float, 16, 16, 8, false> { using Type = WarpGemmMfmaF32F32F32M16N16K8<>; };
|
||||
template<> struct Dispatcher<float, float, float, 32, 32, 4, false> { using Type = WarpGemmMfmaF32F32F32M32N32K4<>; };
|
||||
template<> struct Dispatcher<float, float, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF32F32F32M32N32K8<>; };
|
||||
// Same shape with an explicit EDouble access count, forwarded to the warp GEMM alias.
template<> struct Dispatcher<float, float, float, 32, 32, 8, false, false, false, EDouble> { using Type = WarpGemmMfmaF32F32F32M32N32K8<EDouble>; };
|
||||
template<> struct Dispatcher<float, float, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution<>; };
|
||||
// fp16
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
// Rows with TransposeC = true select the ...TransposedCDistribution variant; rows ending in
// EDouble forward that access count to the warp GEMM alias.
|
||||
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16<>; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<>; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 16, false, false, false, EDouble> { using Type = WarpGemmMfmaF16F16F32M32N32K16<EDouble>; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 16, true, false, false, EDouble> { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<EDouble>; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32<>; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 16, 16, 32, false, false, false, EDouble> { using Type = WarpGemmMfmaF16F16F32M16N16K32<EDouble>; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 16, 16, 32, true, false, false, EDouble> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<EDouble>; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 4, 64, 16, false> { using Type = WarpGemmMfmaF16F16F32M4N64K16; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 64, 4, 16, false> { using Type = WarpGemmMfmaF16F16F32M64N4K16; };
|
||||
// WMMA cases
// The fp16 16x16x16 shape is target dependent: when compiling for gfx11/gfx12 a single
// partial specialization covers both TransposeC values and selects the WMMA warp GEMM;
// on every other target the two MFMA variants are used instead.
|
||||
#if defined(__gfx11__) || defined(__gfx12__)
|
||||
template<bool TransposeC> struct Dispatcher<half_t, half_t, float, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_f32_16x16x16_f16_f16<TransposeC>;};
|
||||
#else
|
||||
template<> struct Dispatcher<half_t, half_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF16F16F32M16N16K16; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
|
||||
#endif
|
||||
|
||||
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 8, true, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 16, true, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution; };
|
||||
|
||||
// fp16 2:4 structural sparsity
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 16, false, false, true> { using Type = WarpGemmSmfmacF16F16F32M32N32K16; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 16, 16, 32, false, false, true> { using Type = WarpGemmSmfmacF16F16F32M16N16K32; };
|
||||
|
||||
// bf16
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16<>; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<>; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 16, false, false, false, EDouble> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16<EDouble>; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 16, true, false, false, EDouble> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<EDouble>; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<>; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 32, false, false, false, EDouble, ESingle> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<EDouble, ESingle>; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 64, false, false, false, EQuad, ESingle> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K64<EQuad, ESingle>; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 64, false, false, false, EQuad> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K64<EQuad>; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 64, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K64<>; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<>; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 32, false, false, false, EDouble> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<EDouble>; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 32, true, false, false, EDouble> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<EDouble>; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 4, 64, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M4N64K16; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 64, 4, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M64N4K16; };
|
||||
// WMMA cases
|
||||
#if defined(__gfx11__) || defined(__gfx12__)
|
||||
template<bool TransposeC> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_f32_16x16x16_bf16_bf16<TransposeC>; };
|
||||
#else
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; };
|
||||
#endif
|
||||
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 8, true, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 16, true, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution; };
|
||||
|
||||
// fp8
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 32, true> { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8_CTransposed; };
|
||||
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; };
|
||||
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; };
|
||||
template<> struct Dispatcher<fp8_t, bf8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_fp8_bf8; };
|
||||
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_bf8; };
|
||||
template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; };
|
||||
template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 32, true> { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8_CTransposed; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 64, true> { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8_CTransposed; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; };
|
||||
|
||||
// scale mfma based f8f6f4
|
||||
template<typename A, typename B, WGAttrNumAccessEnum I>
|
||||
struct Dispatcher<A, B, float, 16, 16, 128, false, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_f8f6f4<A, B, I>; };
|
||||
template<WGAttrNumAccessEnum I> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 128, true, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8_CTransposed<I>; };
|
||||
template<WGAttrNumAccessEnum I> struct Dispatcher<fp8_t, bf8_t, float, 16, 16, 128, true, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8_CTransposed<I>; };
|
||||
template<WGAttrNumAccessEnum I> struct Dispatcher<bf8_t, fp8_t, float, 16, 16, 128, true, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8_CTransposed<I>; };
|
||||
template<WGAttrNumAccessEnum I> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 128, true, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8_CTransposed<I>; };
|
||||
|
||||
template<WGAttrNumAccessEnum I> struct Dispatcher<pk_fp4_t, pk_fp4_t, float, 16, 16, 128, true, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_fp4_fp4_CTransposed<I>; };
|
||||
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; };
|
||||
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<>; };
|
||||
template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<>; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<>; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 64, false, false, false, EDouble> { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<EDouble>; };
|
||||
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 64, false, false, false, EDouble> { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<EDouble>; };
|
||||
template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 64, false, false, false, EDouble> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<EDouble>; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 64, false, false, false, EDouble> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<EDouble>; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 64, false, false, false, EQuad> { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<EQuad>; };
|
||||
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 64, false, false, false, EQuad> { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<EQuad>; };
|
||||
template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 64, false, false, false, EQuad> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<EQuad>; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 64, false, false, false, EQuad> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<EQuad>; };
|
||||
|
||||
template<WGAttrNumAccessEnum I> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 64, true, false, false, I> { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8_CTransposed<I>; };
|
||||
template<WGAttrNumAccessEnum I> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 64, true, false, false, I> { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8_CTransposed<I>; };
|
||||
template<WGAttrNumAccessEnum I> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 64, true, false, false, I> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8_CTransposed<I>; };
|
||||
template<WGAttrNumAccessEnum I> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 64, true, false, false, I> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8_CTransposed<I>; };
|
||||
|
||||
template<WGAttrNumAccessEnum I> struct Dispatcher<pk_fp4_t, pk_fp4_t, float, 32, 32, 64, true, false, false, I> { using Type = WarpGemmMfma_f32_32x32x64_fp4_fp4_CTransposed<I>; };
|
||||
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8<>; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 32, false, false, false, EDouble> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8<EDouble>; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8<>; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 32, false, false, false, EDouble> { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8<EDouble>; };
|
||||
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, true> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8_CTransposed; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8<>; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, false, false, false, EDouble> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8<EDouble>; };
|
||||
|
||||
//WMMA cases
|
||||
template<bool TransposeC> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_f32_16x16x16_f8_f8<TransposeC>; };
|
||||
template<bool TransposeC> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_f32_16x16x16_bf8_bf8<TransposeC>; };
|
||||
template<bool TransposeC> struct Dispatcher<fp8_t, bf8_t, float, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_f32_16x16x16_f8_bf8<TransposeC>; };
|
||||
template<bool TransposeC> struct Dispatcher<bf8_t, fp8_t, float, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_f32_16x16x16_bf8_f8<TransposeC>; };
|
||||
|
||||
// int8
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct Dispatcher<int8_t, int8_t, int32_t, 32, 32, 16, false> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8; };
|
||||
template<> struct Dispatcher<int8_t, int8_t, int32_t, 32, 32, 16, true> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed; };
|
||||
template<> struct Dispatcher<int8_t, int8_t, int32_t, 16, 16, 32, false> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8; };
|
||||
template<> struct Dispatcher<int8_t, int8_t, int32_t, 16, 16, 32, true> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed; };
|
||||
// WMMA cases
|
||||
template<bool TransposeC> struct Dispatcher<int8_t, int8_t, int32_t, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_i32_16x16x16_i8_i8<TransposeC>;};
|
||||
|
||||
// clang-format on
|
||||
} // namespace warp_gemm_dispatcher
|
||||
} // namespace impl
|
||||
|
||||
template <typename AType,
|
||||
typename BType,
|
||||
typename AccType,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
bool TransposeC,
|
||||
bool SwizzleA = false,
|
||||
bool UseStructuredSparsity = false,
|
||||
WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single,
|
||||
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
|
||||
using WarpGemmDispatcher = typename impl::warp_gemm_dispatcher::Dispatcher< //
|
||||
AType,
|
||||
BType,
|
||||
AccType,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
KPerWave,
|
||||
TransposeC,
|
||||
SwizzleA,
|
||||
UseStructuredSparsity,
|
||||
AttrNumAccessA,
|
||||
AttrNumAccessB>::Type;
|
||||
|
||||
} // namespace ck_tile
|
||||
183
include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp
Normal file
183
include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp
Normal file
@@ -0,0 +1,183 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename WarpGemmAttribute_>
|
||||
struct WarpGemmImpl
|
||||
{
|
||||
using WarpGemmAttribute = remove_cvref_t<WarpGemmAttribute_>;
|
||||
|
||||
static constexpr index_t kM = WarpGemmAttribute::kM;
|
||||
static constexpr index_t kN = WarpGemmAttribute::kN;
|
||||
static constexpr index_t kK = WarpGemmAttribute::kK;
|
||||
static constexpr index_t kCMLane = WarpGemmAttribute::kCMLane;
|
||||
/// @brief The number of elements in K dimension processed by single thread in wavefront.
|
||||
///
|
||||
/// @note Note that WarpGemm may run MFMA instruction multiple times (on different K).
|
||||
/// In such situation this value reflects this fact.
|
||||
static constexpr index_t kKPerThread = WarpGemmAttribute::kKPerThread;
|
||||
|
||||
using ADataType = typename WarpGemmAttribute::ADataType;
|
||||
using BDataType = typename WarpGemmAttribute::BDataType;
|
||||
using CDataType = typename WarpGemmAttribute::CDataType;
|
||||
|
||||
using AWarpDstrEncoding = typename WarpGemmAttribute::AWarpDstrEncoding;
|
||||
using BWarpDstrEncoding = typename WarpGemmAttribute::BWarpDstrEncoding;
|
||||
using CWarpDstrEncoding = typename WarpGemmAttribute::CWarpDstrEncoding;
|
||||
|
||||
using AWarpDstr = remove_cvref_t<decltype(make_static_tile_distribution(AWarpDstrEncoding{}))>;
|
||||
using BWarpDstr = remove_cvref_t<decltype(make_static_tile_distribution(BWarpDstrEncoding{}))>;
|
||||
using CWarpDstr = remove_cvref_t<decltype(make_static_tile_distribution(CWarpDstrEncoding{}))>;
|
||||
|
||||
using AWarpTensor = static_distributed_tensor<ADataType, AWarpDstr>;
|
||||
using BWarpTensor = static_distributed_tensor<BDataType, BWarpDstr>;
|
||||
using CWarpTensor = static_distributed_tensor<CDataType, CWarpDstr>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access()
|
||||
{
|
||||
return WarpGemmAttribute_::get_num_of_access();
|
||||
}
|
||||
|
||||
template <typename CTensor, typename ATensor, typename BTensor, bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CTensor& c, const ATensor& a, const BTensor& b, bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CWarpTensor> &&
|
||||
detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
|
||||
detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
|
||||
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
|
||||
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
|
||||
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
|
||||
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
|
||||
auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
WarpGemmAttribute{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
|
||||
|
||||
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
|
||||
}
|
||||
|
||||
template <typename CTensor,
|
||||
typename ATensor,
|
||||
typename BTensor,
|
||||
index_t i_subk,
|
||||
bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CTensor& c,
|
||||
const ATensor& a,
|
||||
const BTensor& b,
|
||||
number<i_subk>,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
|
||||
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
|
||||
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
|
||||
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
|
||||
auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
WarpGemmAttribute{}(c_vec, a_vec, b_vec, number<i_subk>{}, bool_constant<post_nop_>{});
|
||||
|
||||
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
|
||||
}
|
||||
|
||||
template <index_t opselA,
|
||||
index_t opselB,
|
||||
typename CTensor,
|
||||
typename ATensor,
|
||||
typename BTensor,
|
||||
bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CTensor& c,
|
||||
const ATensor& a,
|
||||
const BTensor& b,
|
||||
const int32_t& a_scale,
|
||||
const int32_t& b_scale,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CWarpTensor> &&
|
||||
detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
|
||||
detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
|
||||
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
|
||||
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
|
||||
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
|
||||
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
|
||||
auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
WarpGemmAttribute{}.template operator()<opselA, opselB>(
|
||||
c_vec, a_vec, a_scale, b_vec, b_scale, bool_constant<post_nop_>{});
|
||||
|
||||
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
|
||||
}
|
||||
|
||||
template <typename ATensor, typename BTensor>
|
||||
CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b) const
|
||||
{
|
||||
using CTensor = CWarpTensor;
|
||||
static_assert(detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
|
||||
detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
|
||||
CTensor c;
|
||||
|
||||
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
|
||||
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
|
||||
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
|
||||
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
auto c_vec = WarpGemmAttribute{}(a_vec, b_vec);
|
||||
|
||||
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
|
||||
|
||||
return c;
|
||||
}
|
||||
|
||||
template <index_t opselA, index_t opselB, typename ATensor, typename BTensor>
|
||||
CK_TILE_DEVICE auto operator()(const ATensor& a,
|
||||
const BTensor& b,
|
||||
const int32_t& a_scale,
|
||||
const int32_t& b_scale) const
|
||||
{
|
||||
using CTensor = CWarpTensor;
|
||||
static_assert(detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
|
||||
detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
|
||||
CTensor c;
|
||||
|
||||
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
|
||||
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
|
||||
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
|
||||
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
auto c_vec =
|
||||
WarpGemmAttribute{}.template operator()<opselA, opselB>(a_vec, a_scale, b_vec, b_scale);
|
||||
|
||||
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
|
||||
|
||||
return c;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
131
include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp
Normal file
131
include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp
Normal file
@@ -0,0 +1,131 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
|
||||
namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero
|
||||
* elements into lower part of a_vec to half its effective size.
|
||||
* @param a_vec Vector to be compressed.
|
||||
* @tparam ADataType The data type of a_vec
|
||||
* @tparam CompressedSize The target compression size
|
||||
* @tparam AVec The vector type of a_vec (deduced)
|
||||
* @return Packed 32‑bit word containing **CompressedSize** 2‑bit fields.
|
||||
* Each field encodes the original position (0–3) of the corresponding
|
||||
* non‑zero element in the input. If fewer than CompressedSize
|
||||
* non‑zeros are found, remaining fields default to 2 (see below).
|
||||
*/
|
||||
template <typename ADataType, index_t CompressedSize, typename AVec>
|
||||
static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec)
|
||||
{
|
||||
// idx holds one 2‑bit index per output element (total CompressedSize entries).
|
||||
// It is initialized to the pattern 0b10 for every field. This matches
|
||||
// what the hardware expects when there are fewer than two non‑zero values
|
||||
// in a 4‑element group – the unused output is treated as coming from slot 2.
|
||||
// The loop below will clear and set each field as real non‑zeros are seen.
|
||||
int32_t idx = 0;
|
||||
static_for<0, CompressedSize, 1>{}([&](auto k) { idx |= (2 << (2 * k)); });
|
||||
|
||||
static_for<0, CompressedSize / 2, 1>{}([&](auto i) {
|
||||
ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]};
|
||||
int32_t non_zero_pos = 0;
|
||||
|
||||
static_for<0, 3, 1>{}([&](auto j) {
|
||||
if(a_vec[i * 4 + j] != 0.0f)
|
||||
{
|
||||
nonzero_elems[non_zero_pos] = a_vec[i * 4 + j];
|
||||
// clear the two‑bit field for this output and insert j
|
||||
idx &= ~(0b11 << 2 * (i * 2 + non_zero_pos));
|
||||
idx |= j << 2 * (i * 2 + non_zero_pos);
|
||||
++non_zero_pos;
|
||||
}
|
||||
});
|
||||
a_vec[i * 2] = nonzero_elems[0];
|
||||
a_vec[i * 2 + 1] = nonzero_elems[1];
|
||||
});
|
||||
|
||||
return idx;
|
||||
}
|
||||
|
||||
template <typename WarpGemmAttribute_>
|
||||
struct WarpGemmSmfmacImpl
|
||||
{
|
||||
using WarpGemmAttribute = remove_cvref_t<WarpGemmAttribute_>;
|
||||
|
||||
static constexpr index_t kM = WarpGemmAttribute::kM;
|
||||
static constexpr index_t kN = WarpGemmAttribute::kN;
|
||||
static constexpr index_t kK = WarpGemmAttribute::kK;
|
||||
/// @brief The number of elements in K dimension processed by single thread in wavefront.
|
||||
///
|
||||
/// @note Note that WarpGemm may run MFMA instruction multiple times (on different K).
|
||||
/// In such situation this value reflects this fact.
|
||||
static constexpr index_t kKPerThread = WarpGemmAttribute::kKPerThread;
|
||||
|
||||
using ADataType = typename WarpGemmAttribute::ADataType;
|
||||
using BDataType = typename WarpGemmAttribute::BDataType;
|
||||
using CDataType = typename WarpGemmAttribute::CDataType;
|
||||
|
||||
using AWarpDstrEncoding = typename WarpGemmAttribute::AWarpDstrEncoding;
|
||||
using BWarpDstrEncoding = typename WarpGemmAttribute::BWarpDstrEncoding;
|
||||
using CWarpDstrEncoding = typename WarpGemmAttribute::CWarpDstrEncoding;
|
||||
|
||||
using AWarpDstr = remove_cvref_t<decltype(make_static_tile_distribution(AWarpDstrEncoding{}))>;
|
||||
using BWarpDstr = remove_cvref_t<decltype(make_static_tile_distribution(BWarpDstrEncoding{}))>;
|
||||
using CWarpDstr = remove_cvref_t<decltype(make_static_tile_distribution(CWarpDstrEncoding{}))>;
|
||||
|
||||
using AWarpTensor = static_distributed_tensor<ADataType, AWarpDstr>;
|
||||
using BWarpTensor = static_distributed_tensor<BDataType, BWarpDstr>;
|
||||
using CWarpTensor = static_distributed_tensor<CDataType, CWarpDstr>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access()
|
||||
{
|
||||
return WarpGemmAttribute_::get_num_of_access();
|
||||
}
|
||||
|
||||
template <index_t CompressedSize, typename AVec>
|
||||
CK_TILE_DEVICE int32_t compress_a_vec(AVec& a_vec)
|
||||
{
|
||||
return compress_a_impl<ADataType, CompressedSize>(a_vec);
|
||||
}
|
||||
|
||||
template <typename CTensor, typename ATensor, typename BTensor, bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CTensor& c, const ATensor& a, const BTensor& b, bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CWarpTensor> &&
|
||||
detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
|
||||
detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
|
||||
constexpr auto CompressionRatio = WarpGemmAttribute::kCompressionRatio;
|
||||
|
||||
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
|
||||
static constexpr index_t CompressedSize =
|
||||
ATensor::get_thread_buffer_size() / CompressionRatio;
|
||||
using AVecCompressed = ext_vector_t<ADataType, CompressedSize>;
|
||||
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
|
||||
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
|
||||
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
|
||||
auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
|
||||
|
||||
const int32_t idx = compress_a_vec<CompressedSize>(a_vec);
|
||||
|
||||
static_assert(CompressedSize == 4);
|
||||
// @TODO can we simply set a_vec_pruned to a_vec[0:3]?
|
||||
const AVecCompressed a_vec_pruned = {a_vec[0], a_vec[1], a_vec[2], a_vec[3]};
|
||||
|
||||
// c_vec += a_vec * b_vec[idx]
|
||||
WarpGemmAttribute{}(c_vec, a_vec_pruned, b_vec, idx, bool_constant<post_nop_>{});
|
||||
|
||||
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
42
include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp
Normal file
42
include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp
Normal file
@@ -0,0 +1,42 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <bool kTransC = false>
|
||||
using WarpGemmWmma_f32_16x16x16_f16_f16 =
|
||||
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x16_f16_f16, kTransC>>;
|
||||
|
||||
template <bool kTransC = false>
|
||||
using WarpGemmWmma_f32_16x16x16_bf16_bf16 =
|
||||
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x16_bf16_bf16, kTransC>>;
|
||||
|
||||
template <bool kTransC = false>
|
||||
using WarpGemmWmma_i32_16x16x16_i8_i8 =
|
||||
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_i32_16x16x16_i8_i8, kTransC>>;
|
||||
|
||||
template <bool kTransC = false>
|
||||
using WarpGemmWmma_f32_16x16x16_f8_f8 =
|
||||
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x16_f8_f8, kTransC>>;
|
||||
|
||||
template <bool kTransC = false>
|
||||
using WarpGemmWmma_f32_16x16x16_bf8_bf8 =
|
||||
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x16_bf8_bf8, kTransC>>;
|
||||
|
||||
template <bool kTransC = false>
|
||||
using WarpGemmWmma_f32_16x16x16_f8_bf8 =
|
||||
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x16_f8_bf8, kTransC>>;
|
||||
|
||||
template <bool kTransC = false>
|
||||
using WarpGemmWmma_f32_16x16x16_bf8_f8 =
|
||||
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x16_bf8_f8, kTransC>>;
|
||||
|
||||
} // namespace ck_tile
|
||||
45
include/ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp
Normal file
45
include/ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp
Normal file
@@ -0,0 +1,45 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// C distribution of gfx11 WMMA is not compatible with A distribution:
|
||||
// C: 2 lanes per row (lane and lane + 16), 8 values per lane are interleaved.
|
||||
// A: 1 lane per row, 16 values, lane and lane + 16 have the same values.
|
||||
// This function transforms one ditribution to another for GEMM-GEMM scenarios.
|
||||
template <typename OutTensor, typename InTensor>
|
||||
CK_TILE_DEVICE static constexpr void PermuteWarpGemmCToA(OutTensor& out, const InTensor& in)
|
||||
{
|
||||
#if defined(__gfx11__)
|
||||
static_assert(sizeof(typename OutTensor::DataType) == 2);
|
||||
static_assert(std::is_same_v<typename OutTensor::DataType, typename InTensor::DataType>);
|
||||
|
||||
constexpr index_t n_out = OutTensor::get_thread_buffer_size();
|
||||
static_assert(n_out == InTensor::get_thread_buffer_size() * 2);
|
||||
|
||||
// Perm byte selectors are swapped for the second row (16 lanes) because it needs to be done
|
||||
// once instead to swapping w and v everytime
|
||||
const uint32_t byte_selector0 = get_lane_id() < 16 ? 0x05'04'01'00 : 0x01'00'05'04;
|
||||
const uint32_t byte_selector1 = get_lane_id() < 16 ? 0x07'06'03'02 : 0x03'02'07'06;
|
||||
static_for<0, n_out, 1>{}([&](auto i) {
|
||||
const auto v = in.get_thread_buffer().template get_as<uint32_t>(i);
|
||||
// Swap rows (lane <-> lane ^ 16)
|
||||
const auto w = __builtin_amdgcn_permlanex16(0, v, 0x76543210, 0xfedcba98, false, true);
|
||||
// Interleave values of lane and lane ^ 16
|
||||
out.get_thread_buffer().template set_as<uint32_t>(
|
||||
number<i * 2 + 0>{}, __builtin_amdgcn_perm(w, v, byte_selector0));
|
||||
out.get_thread_buffer().template set_as<uint32_t>(
|
||||
number<i * 2 + 1>{}, __builtin_amdgcn_perm(w, v, byte_selector1));
|
||||
});
|
||||
#else
|
||||
static_assert(false, "PermuteWarpGemmCToA is only for gfx11");
|
||||
ignore = out;
|
||||
ignore = in;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user