mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
[CK_TILE] Use read_tr in universal gemm (#2436)
* Use read_tr in universal gemm * Enable all instances back * Revert example37 changes * Resolve comments * resolve comments 2 * Fix assertion msg * fix the gemm basic * change index_t to bool for preshuffle variable * Solve the comment --------- Co-authored-by: Thomas Ning <Thomas.Ning@amd.com> Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com> Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Co-authored-by: AviralGoelAMD <aviral.goel@amd.com>
This commit is contained in:
@@ -21,22 +21,29 @@ using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#endif
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#endif
|
||||
|
||||
using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
|
||||
@@ -56,25 +63,33 @@ using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution =
|
||||
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#endif
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#endif
|
||||
|
||||
#if defined(__gfx950__)
|
||||
@@ -123,22 +138,29 @@ using WarpGemmMfmaBf16Bf16F32M16N16K16 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#endif
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#endif
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
|
||||
@@ -159,25 +181,33 @@ using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution =
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#endif
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#endif
|
||||
|
||||
#if defined(__gfx950__)
|
||||
@@ -247,17 +277,25 @@ using WarpGemmMfma_f32_16x16x128_bf8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfma<
|
||||
using WarpGemmMfma_f32_16x16x128_bf8_bf8 = WarpGemmImpl<WarpGemmAtrributeMfma<
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_32x32x64_fp8_fp8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_fp8<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_fp8<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_32x32x64_fp8_bf8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_bf8<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_bf8<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_32x32x64_bf8_fp8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_fp8<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_fp8<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_32x32x64_bf8_bf8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_bf8<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_bf8<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -8,10 +8,22 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_>
|
||||
// Number of groups of consecutive elements to fill in a ABKLane
|
||||
enum class WGAttrNumAccessEnum
|
||||
{
|
||||
Single = 1,
|
||||
Double = 2,
|
||||
Quad = 4,
|
||||
Invalid = -1
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_,
|
||||
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
|
||||
struct WarpGemmAtrributeMfma
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
static constexpr auto AttrNumAccess = AttrNumAccess_;
|
||||
static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
|
||||
|
||||
using ADataType = typename Impl::ADataType;
|
||||
using BDataType = typename Impl::BDataType;
|
||||
@@ -31,21 +43,35 @@ struct WarpGemmAtrributeMfma
|
||||
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
|
||||
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
|
||||
|
||||
using AWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using BWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
template <index_t kMNLane>
|
||||
static constexpr auto get_warp_dstr_encoding()
|
||||
{
|
||||
if constexpr(AttrNumAccessV == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(kKPerThread % AttrNumAccessV == 0,
|
||||
"kKPerThread must be divisible by NumAccess");
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMNLane>,
|
||||
sequence<AttrNumAccessV, Impl::kABKLane, Impl::kABKPerLane / AttrNumAccessV>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
}
|
||||
using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane>());
|
||||
using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane>());
|
||||
|
||||
using CWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
@@ -73,12 +99,16 @@ struct WarpGemmAtrributeMfma
|
||||
}
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter>
|
||||
template <typename WarpGemmAttributeMfmaImpl_,
|
||||
index_t kKIter,
|
||||
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
|
||||
struct WarpGemmAtrributeMfmaIterateK
|
||||
{
|
||||
static_assert(kKIter > 0, "wrong!");
|
||||
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
static constexpr auto AttrNumAccess = AttrNumAccess_;
|
||||
static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
|
||||
|
||||
using ADataType = typename Impl::ADataType;
|
||||
using BDataType = typename Impl::BDataType;
|
||||
@@ -104,17 +134,37 @@ struct WarpGemmAtrributeMfmaIterateK
|
||||
{
|
||||
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
if constexpr(AttrNumAccessV == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(kKPerThread % AttrNumAccessV == 0,
|
||||
"kKPerThread must be divisible by NumAccess");
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane>,
|
||||
sequence<AttrNumAccessV,
|
||||
Impl::kABKLane,
|
||||
Impl::kABKPerLane * kKIter / AttrNumAccessV>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
}
|
||||
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
|
||||
{
|
||||
static_assert(AttrNumAccessV == 1,
|
||||
"Multiple access is not supported when using multi-block");
|
||||
// each M blocks share the same data
|
||||
return tile_distribution_encoding<
|
||||
sequence<Impl::kBNBlock>,
|
||||
@@ -127,6 +177,8 @@ struct WarpGemmAtrributeMfmaIterateK
|
||||
}
|
||||
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
|
||||
{
|
||||
static_assert(AttrNumAccessV == 1,
|
||||
"Multiple access is not supported when using multi-block");
|
||||
// single block to multi-block thread mapping
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
@@ -143,17 +195,38 @@ struct WarpGemmAtrributeMfmaIterateK
|
||||
{
|
||||
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
if constexpr(AttrNumAccessV == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
static_assert(kKPerThread % AttrNumAccessV == 0,
|
||||
"kKPerThread must be divisible by NumAccess");
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>,
|
||||
sequence<AttrNumAccessV,
|
||||
Impl::kABKLane,
|
||||
Impl::kABKPerLane * kKIter / AttrNumAccessV>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
}
|
||||
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
|
||||
{
|
||||
static_assert(AttrNumAccessV == 1,
|
||||
"Multiple access is not supported when using multi-block");
|
||||
// single block to multi-block thread mapping
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
@@ -166,6 +239,8 @@ struct WarpGemmAtrributeMfmaIterateK
|
||||
}
|
||||
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
|
||||
{
|
||||
static_assert(AttrNumAccessV == 1,
|
||||
"Multiple access is not supported when using multi-block");
|
||||
// each N blocks share the same data
|
||||
return tile_distribution_encoding<
|
||||
sequence<Impl::kAMBlock>,
|
||||
@@ -289,10 +364,13 @@ struct WarpGemmAtrributeMfmaIterateK
|
||||
}
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_>
|
||||
template <typename WarpGemmAttributeMfmaImpl_,
|
||||
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
|
||||
struct WarpGemmAtrributeMfmaTransposedCDistribution
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
static constexpr auto AttrNumAccess = AttrNumAccess_;
|
||||
static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
|
||||
|
||||
using ADataType = typename Impl::BDataType;
|
||||
using BDataType = typename Impl::ADataType;
|
||||
@@ -312,21 +390,35 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
|
||||
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
|
||||
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
|
||||
|
||||
using AWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using BWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
template <index_t kMNLane>
|
||||
static constexpr auto get_warp_dstr_encoding()
|
||||
{
|
||||
if constexpr(AttrNumAccessV == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(kKPerThread % AttrNumAccessV == 0,
|
||||
"kKPerThread must be divisible by NumAccess");
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMNLane>,
|
||||
sequence<AttrNumAccessV, Impl::kABKLane, Impl::kABKPerLane / AttrNumAccessV>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
}
|
||||
using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane>());
|
||||
using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane>());
|
||||
|
||||
using CWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
@@ -450,10 +542,13 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
|
||||
}
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter>
|
||||
template <typename WarpGemmAttributeMfmaImpl_,
|
||||
index_t kKIter,
|
||||
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
|
||||
struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
static constexpr auto AttrNumAccess = AttrNumAccess_;
|
||||
|
||||
// swap A and B
|
||||
using ADataType = typename Impl::BDataType;
|
||||
@@ -478,80 +573,14 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
|
||||
{
|
||||
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
|
||||
{
|
||||
// single block to multi-block thread mapping
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNBlock, Impl::kBNLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<1, 2, 1>>,
|
||||
tuple<sequence<0, 0, 1>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
|
||||
{
|
||||
// each N blocks share the same data
|
||||
return tile_distribution_encoding<
|
||||
sequence<Impl::kAMBlock>,
|
||||
tuple<sequence<Impl::kBNLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<0, 2, 1>>,
|
||||
tuple<sequence<0, 0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
return WarpGemmAtrributeMfmaIterateK<Impl, kKIter, AttrNumAccess>::
|
||||
get_bwarp_dstr_encoding();
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
|
||||
{
|
||||
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
|
||||
{
|
||||
// each M blocks share the same data
|
||||
return tile_distribution_encoding<
|
||||
sequence<Impl::kBNBlock>,
|
||||
tuple<sequence<Impl::kAMLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<0, 2, 1>>,
|
||||
tuple<sequence<0, 0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
|
||||
{
|
||||
// single block to multi-block thread mapping
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMBlock, Impl::kAMLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<1, 2, 1>>,
|
||||
tuple<sequence<0, 0, 1>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
return WarpGemmAtrributeMfmaIterateK<Impl, kKIter, AttrNumAccess>::
|
||||
get_awarp_dstr_encoding();
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
|
||||
|
||||
@@ -16,8 +16,9 @@ template <typename AType,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
bool TransposeC,
|
||||
bool SwizzleA = false,
|
||||
bool UseStructuredSparsity = false>
|
||||
bool SwizzleA = false,
|
||||
bool UseStructuredSparsity = false,
|
||||
WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
struct WarpGemmMfmaDispatcher;
|
||||
|
||||
// clang-format off
|
||||
@@ -25,12 +26,20 @@ struct WarpGemmMfmaDispatcher;
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaF16F16F32M32N32K16<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF16F16F32M16N16K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaF16F16F32M16N16K32<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 4, 64, 16, false> { using Type = WarpGemmMfmaF16F16F32M4N64K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 64, 4, 16, false> { using Type = WarpGemmMfmaF16F16F32M64N4K16; };
|
||||
|
||||
@@ -46,12 +55,20 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaBf16Bf16F32M32N32K16<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 4, 64, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M4N64K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 64, 4, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M64N4K16; };
|
||||
|
||||
@@ -80,10 +97,18 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float,
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8; };
|
||||
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<WGAttrNumAccessEnum::Quad>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<WGAttrNumAccessEnum::Quad>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<WGAttrNumAccessEnum::Quad>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<WGAttrNumAccessEnum::Quad>; };
|
||||
|
||||
// int8
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
@@ -102,8 +127,9 @@ template <typename AType,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
bool TransposeC,
|
||||
bool SwizzleA = false,
|
||||
bool UseStructuredSparsity = false>
|
||||
bool SwizzleA = false,
|
||||
bool UseStructuredSparsity = false,
|
||||
WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaDispatcher = typename impl::WarpGemmMfmaDispatcher<AType,
|
||||
BType,
|
||||
AccType,
|
||||
@@ -112,6 +138,7 @@ using WarpGemmMfmaDispatcher = typename impl::WarpGemmMfmaDispatcher<AType,
|
||||
KPerWave,
|
||||
TransposeC,
|
||||
SwizzleA,
|
||||
UseStructuredSparsity>::Type;
|
||||
UseStructuredSparsity,
|
||||
AttrNumAccess>::Type;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user