[CK_TILE] CK_TILE GEMM WMMA Support for GFX11/GFX12 (#2466)

* WMMA GEMM F16 Implementation

Signed-off-by: root <tianyuwu@amd.com>

* Self-review

Signed-off-by: root <tianyuwu@amd.com>

* ASIC check minor tweak

Signed-off-by: root <tianyuwu@amd.com>

* add missing include file

* Set GPU_TARGETS to gfx11/12 generic

Signed-off-by: root <tianyuwu@amd.com>

* INT8 GFX12

Signed-off-by: root <tianyuwu@amd.com>

* add int8x16 branch

* Fix CI script

Signed-off-by: root <tianyuwu@amd.com>

* Fix typo

Signed-off-by: root <tianyuwu@amd.com>

* Add CK_Tile WMMA example

Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>

* Fix CI

Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>

* fix clang format

* Set M/N_Warp Back to Constant

Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>

* Use GemmConfigComputeV3 by default

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Enable CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT for gfx12

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Remove CK_Tile wmma gemm examples from the CI list

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Add atomic add fallback method for gfx11

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Fix typo

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Omit copyright year

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Support non-square cases

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Fix CI

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Add get_device_ip()

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Revert "Add atomic add fallback method for gfx11"

This reverts commit 4f664969c01b37976c8518c19833d9f1574cd746.

Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>

* Revert "Enable CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT for gfx12"

This reverts commit 949129a3858a825b2a2c4d3ec01663df18a165a5.

* Revise method name and typos

Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>

* clang-format

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Try fix CI

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Revert "Try fix CI"

This reverts commit 084c683227e64ab6a8137db00c8165fb05bdc902.

* clang-format

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Fix typo caused by merge

Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>

* Fix typo caused by merging

Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>

---------

Signed-off-by: root <tianyuwu@amd.com>
Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>
Co-authored-by: joye <joye@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>

[ROCm/composable_kernel commit: 68134b60e4]
This commit is contained in:
Tianyuan Wu
2025-08-16 07:22:27 +08:00
committed by GitHub
parent 42d775e488
commit ec7ee5b7b7
54 changed files with 1388 additions and 403 deletions

View File

@@ -327,6 +327,7 @@ endif()
if(USE_OPT_GFX11)
add_compile_options(-mcumode)
add_compile_options(-mno-wavefrontsize64)
add_compile_definitions(CK_TILE_WAVE32_ENABLED)
message(STATUS "CK compiled with USE_OPT_GFX11 set to ${USE_OPT_GFX11}")
endif()
@@ -336,6 +337,12 @@ if(ENABLE_ASM_DUMP)
message("CK compiled with ENABLE_ASM_DUMP set to ${ENABLE_ASM_DUMP}")
endif()
if(USE_OPT_GFX12 AND (SUPPORTED_GPU_TARGETS MATCHES "gfx12"))
add_compile_options(-mno-wavefrontsize64)
add_compile_definitions(CK_TILE_WAVE32_ENABLED)
message(STATUS "CK compiled with USE_OPT_GFX12 set to ${USE_OPT_GFX12}")
endif()
## Threads
set(THREADS_PREFER_PTHREAD_FLAG ON)
find_package(Threads REQUIRED)

4
Jenkinsfile vendored
View File

@@ -1474,7 +1474,7 @@ pipeline {
}
agent{ label rocmnode("gfx1101") }
environment{
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx11-generic" -DCMAKE_CXX_FLAGS=" -O3 " """
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx11-generic" -DUSE_OPT_GFX11=ON -DCMAKE_CXX_FLAGS=" -O3 " """
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
-DGPU_TARGETS="gfx11-generic" \
@@ -1495,7 +1495,7 @@ pipeline {
}
agent{ label rocmnode("gfx1201") }
environment{
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx12-generic" -DCMAKE_CXX_FLAGS=" -O3 " """
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx12-generic" -DUSE_OPT_GFX12=ON -DCMAKE_CXX_FLAGS=" -O3 " """
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
-DGPU_TARGETS="gfx12-generic" \

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>

View File

@@ -172,6 +172,27 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
static constexpr int kBlockPerCu = 2;
};
template <typename PrecType>
struct GemmConfigComputeV3_WMMA : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 4;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr int kBlockPerCu = 2;
};
template <typename PrecType>
struct GemmConfigComputeV4 : public GemmConfigBase
{

View File

@@ -346,5 +346,4 @@ int main(int argc, char* argv[])
// Return a non-zero code to indicate failure
return EXIT_FAILURE;
}
return EXIT_SUCCESS;
}

0
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp Executable file → Normal file
View File

View File

@@ -9,6 +9,7 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/ignore.hpp"
#define CK_TILE_S_CNT_MAX 0b1100'1111'0111'1111
#define CK_TILE_VMCNT(cnt) \
@@ -59,7 +60,7 @@ enum struct memory_operation_enum : std::uint16_t
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
{
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
#if defined(__GFX9__) || (!defined(__HIP_DEVICE_COMPILE__) && !defined(CK_TILE_WAVE32_ENABLED))
return 64;
#else
return 32;
@@ -230,4 +231,20 @@ CK_TILE_HOST_DEVICE constexpr const char* address_space_to_string(address_space_
}
}
// Architecture tags
struct gfx11_t
{
};
struct gfx12_t
{
};
CK_TILE_DEVICE static constexpr auto get_device_arch()
{
#if defined(__gfx11__)
return gfx11_t{};
#else // if defined(__gfx12__)
return gfx12_t{};
#endif
}
} // namespace ck_tile

View File

@@ -6,6 +6,10 @@
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#define HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN \
__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2f16) && \
__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2bf16)
namespace ck_tile {
template <typename T, typename ComputeType>
@@ -32,6 +36,14 @@ CK_TILE_HOST_DEVICE bf16x4_t add_bf16x4_t(const bf16x4_t& a, const bf16x4_t& b)
return rtn;
}
CK_TILE_HOST_DEVICE fp16x2_t add_f16x2_t(const fp16x2_t& a, const fp16x2_t& b)
{
fp16x2_t rtn;
rtn[0] = add<fp16_t, float>(a[0], b[0]);
rtn[1] = add<fp16_t, float>(a[1], b[1]);
return rtn;
}
CK_TILE_HOST_DEVICE fp8x4_t add_fp8x4_t(const fp8x4_t& a, const fp8x4_t& b)
{
fp8x4_t rtn;
@@ -304,6 +316,44 @@ CK_TILE_DEVICE void atomic_add<bf8x8_t>(bf8x8_t* p_dst, bf8x8_t const& x)
} while(cur_v.u64 != old_v);
}
//
// Atomic add for fp16x2_t
//
template <>
CK_TILE_DEVICE void atomic_add<fp16x2_t>(fp16x2_t* p_dst, fp16x2_t const& x)
{
#if HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN
__builtin_amdgcn_global_atomic_fadd_v2f16(c_style_pointer_cast<fp16x2_t*>(p_dst), x);
#else
union U32F162_ADDR
{
uint32_t* u32_a;
fp16x2_t* f162_a;
};
union U32F162
{
uint32_t u32;
fp16x2_t f162;
};
U32F162_ADDR dword_addr;
U32F162 cur_v;
U32F162 new_;
uint32_t old_v, new_v;
dword_addr.f162_a = p_dst;
cur_v.u32 = *dword_addr.u32_a;
do
{
old_v = cur_v.u32;
new_.f162 = add_f16x2_t(cur_v.f162, x);
new_v = new_.u32;
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
} while(cur_v.u32 != old_v);
#endif
}
template <typename T, index_t N>
CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
{
@@ -311,6 +361,7 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
(std::is_same<T, uint32_t>::value && (N == 1)) ||
(std::is_same<T, float>::value && (N == 1 || N == 2)) ||
(std::is_same<T, double>::value && (N == 1 || N == 2)) ||
(std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, fp8_t>::value && (N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf8_t>::value && (N == 4 || N == 8 || N == 16)),
@@ -406,6 +457,13 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst) + 1, x.template get_as<bf8x8_t>()[I1]);
}
}
else if constexpr(std::is_same<T, fp16_t>::value)
{
static_for<0, N / 2, 1>{}([&](auto i) {
atomic_add(c_style_pointer_cast<fp16x2_t*>(p_dst) + i,
x.template get_as<fp16x2_t>()[i]);
});
}
}
template <typename T, index_t N>

View File

@@ -152,7 +152,7 @@
// buffer atomic add: floating point
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#elif defined(__gfx9__) // for GPU code
#elif defined(__gfx9__) || defined(__gfx12__) // for GPU code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#else // for GPU code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
@@ -274,6 +274,12 @@
#define CK_TILE_WA_ISSUE_2028 0
#endif
#ifndef CK_TILE_WAVE32_ENABLED
#if defined(__gfx11__) || defined(__gfx12__)
#define CK_TILE_WAVE32_ENABLED
#endif
#endif
// Y pointed to R, we don't see a valuable use case.
// Will enforce encoding to check Y not pointed to R if set to zero
#ifndef CK_TILE_ENC_SUPPORT_Y_TO_R

View File

@@ -52,6 +52,19 @@ inline std::string get_device_name()
}
}
inline bool is_gfx11_supported()
{
return get_device_name() == "gfx1100" || get_device_name() == "gfx1101" ||
get_device_name() == "gfx1102" || get_device_name() == "gfx1103" ||
get_device_name() == "gfx1150" || get_device_name() == "gfx1151" ||
get_device_name() == "gfx1152";
}
inline bool is_gfx12_supported()
{
return get_device_name() == "gfx1200" || get_device_name() == "gfx1201";
}
inline bool is_load_tr_supported()
{
// Check if load transpose is supported.

View File

@@ -203,13 +203,13 @@ struct CShuffleEpilogue
static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle);
static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle);
using WG = WarpGemmMfmaDispatcher<ATypeToUse,
BTypeToUse,
AccDataType,
MPerXdl,
NPerXdl,
KPerXdl,
isCTransposed>;
using WG = WarpGemmDispatcher<ATypeToUse,
BTypeToUse,
AccDataType,
MPerXdl,
NPerXdl,
KPerXdl,
isCTransposed>;
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;

View File

@@ -130,13 +130,13 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
static constexpr index_t kKPerXdl = Problem::kKPerXdl;
static constexpr index_t isCTransposed = Problem::isCTransposed;
using WG = WarpGemmMfmaDispatcher<ADataType,
BTypeToUse,
AccDataType,
kMPerXdl,
kNPerXdl,
kKPerXdl,
isCTransposed>;
using WG = WarpGemmDispatcher<ADataType,
BTypeToUse,
AccDataType,
kMPerXdl,
kNPerXdl,
kKPerXdl,
isCTransposed>;
using CWarpDstr = typename WG::CWarpDstr;

View File

@@ -430,13 +430,13 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
// using AccDataType = float;
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy<
typename Problem::ADataType,

View File

@@ -43,7 +43,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
using WarpGemm = WarpGemmMfmaDispatcher<
using WarpGemm = WarpGemmDispatcher<
typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
@@ -78,18 +78,18 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
typename Problem::OGradDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
true,
false, // SwizzleAccess
false, // UseStructuredSparsity
(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32)
? WGAttrNumAccessEnum ::Double
: WGAttrNumAccessEnum ::Single>;
WarpGemmDispatcher<typename Problem::GemmDataType,
typename Problem::OGradDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
true,
false, // SwizzleAccess
false, // UseStructuredSparsity
(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32)
? WGAttrNumAccessEnum ::Double
: WGAttrNumAccessEnum ::Single>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
@@ -115,7 +115,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
using WarpGemm = WarpGemmMfmaDispatcher<
using WarpGemm = WarpGemmDispatcher<
typename Problem::OGradDataType,
typename Problem::VDataType,
typename Problem::AccDataType,
@@ -150,18 +150,18 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm3WarpTile>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
typename Problem::QDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}),
true,
false, // SwizzleAccess
false, // UseStructuredSparsity
(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32)
? WGAttrNumAccessEnum ::Double
: WGAttrNumAccessEnum ::Single>;
WarpGemmDispatcher<typename Problem::GemmDataType,
typename Problem::QDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}),
true,
false, // SwizzleAccess
false, // UseStructuredSparsity
(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32)
? WGAttrNumAccessEnum ::Double
: WGAttrNumAccessEnum ::Single>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
@@ -187,14 +187,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm4BlockWarps,
typename Problem::BlockFmhaShape::Gemm4WarpTile>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
false>;
using WarpGemm = WarpGemmDispatcher<typename Problem::GemmDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
false>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,

View File

@@ -25,7 +25,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
constexpr auto SwizzleA = false;
using WarpGemm = WarpGemmMfmaDispatcher< //
using WarpGemm = WarpGemmDispatcher< //
typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
@@ -66,7 +66,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
constexpr auto SwizzleA = false;
using WarpGemm = WarpGemmMfmaDispatcher< //
using WarpGemm = WarpGemmDispatcher< //
typename Problem::OGradDataType,
typename Problem::VDataType,
typename Problem::AccDataType,
@@ -106,7 +106,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
typename BlockFmhaShape::Gemm4BlockWarps,
typename BlockFmhaShape::Gemm4WarpTile>>;
using WarpGemm = WarpGemmMfmaDispatcher< //
using WarpGemm = WarpGemmDispatcher< //
typename Problem::GemmDataType,
typename Problem::KDataType,
typename Problem::AccDataType,

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -512,14 +512,13 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
true>;
using WarpGemm = WarpGemmDispatcher<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
true>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::QDataType,
@@ -546,22 +545,22 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
true,
false,
false,
((Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 16 &&
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32) ||
(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 32 &&
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 16))
? WGAttrNumAccessEnum::Double
: WGAttrNumAccessEnum::Single>;
using WarpGemm =
WarpGemmDispatcher<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
true,
false,
false,
((Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 16 &&
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32) ||
(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 32 &&
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 16))
? WGAttrNumAccessEnum::Double
: WGAttrNumAccessEnum::Single>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::PDataType,

View File

@@ -956,20 +956,19 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
{
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<>{};
// return
// WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
// WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
// WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename
// Problem::PDataType, typename Problem::VDataType>>>{};
}
else
{
return WarpGemmMfmaDispatcher<
typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
true>{};
return WarpGemmDispatcher<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
true>{};
}
}();

View File

@@ -568,7 +568,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
{
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
return WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<wg_ctrl>,
2>>{};
}
@@ -576,7 +576,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
std::is_same_v<typename Problem::GDataType, ck_tile::int8_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32)
{
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
return WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<wg_ctrl>,
2>>{};
}
@@ -695,7 +695,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
{
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
return WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<wg_ctrl>,
2>>{};
}
@@ -703,7 +703,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
std::is_same_v<typename Problem::DDataType, ck_tile::int8_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32)
{
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
return WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<wg_ctrl>,
2>>{};
}

View File

@@ -58,9 +58,15 @@
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -54,16 +54,16 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16<>{}, 2, 2);
}
#else
using WG = WarpGemmMfmaDispatcher<ck_tile::half_t,
ck_tile::half_t,
float,
32,
32,
16,
true,
false,
false,
wg_attr_num_access>;
using WG = WarpGemmDispatcher<ck_tile::half_t,
ck_tile::half_t,
float,
32,
32,
16,
true,
false,
false,
wg_attr_num_access>;
return make_tuple(WG{}, 4, 1);
#endif
}
@@ -71,16 +71,16 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
std::is_same_v<typename Problem::BDataType, bf16_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
using WG = WarpGemmMfmaDispatcher<ck_tile::bf16_t,
ck_tile::bf16_t,
float,
32,
32,
16,
true,
false,
false,
wg_attr_num_access>;
using WG = WarpGemmDispatcher<ck_tile::bf16_t,
ck_tile::bf16_t,
float,
32,
32,
16,
true,
false,
false,
wg_attr_num_access>;
return make_tuple(WG{}, 4, 1);
}
else

View File

@@ -182,7 +182,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
constexpr index_t WaveSize = 64;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
@@ -242,7 +242,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
constexpr index_t WaveSize = 64;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});

View File

@@ -182,7 +182,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{});
constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{});
constexpr index_t WaveSize = 64;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});

View File

@@ -32,16 +32,17 @@ struct GemmPipelineAgBgCrCompV4DefaultPolicy
? WGAttrNumAccessEnum::Double
: WGAttrNumAccessEnum::Single;
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType, // AccDataType
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
false,
wg_attr_num_access>;
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType, // AccDataType
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
false,
wg_attr_num_access>;
using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,

View File

@@ -21,15 +21,16 @@ struct GemmPipelineAgBgCrCompV5DefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
// using AccDataType = float;
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType, // AccDataType
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType, // AccDataType
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,

View File

@@ -390,16 +390,17 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
using AccDataType = float;
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
AccDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
using AccDataType = float;
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
AccDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,

View File

@@ -635,16 +635,17 @@ struct UniversalGemmPipelineAgBgCrPolicy
: vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad
: WGAttrNumAccessEnum::Invalid;
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
Problem::UseStructuredSparsity,
wg_attr_num_access>;
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
Problem::UseStructuredSparsity,
wg_attr_num_access>;
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,

View File

@@ -280,13 +280,13 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
{
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
using BlockWeightPreshufflePolicy =
BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,

View File

@@ -15,19 +15,19 @@ namespace ck_tile {
// fp16
using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>,
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
@@ -36,42 +36,42 @@ using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterate
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>,
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
#endif
using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAttributeMfmaIterateK_SwizzleA<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
1>>;
using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl<WarpGemmAttributeMfmaIterateK_SwizzleA<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
@@ -80,13 +80,13 @@ using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
@@ -94,36 +94,36 @@ using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
#if defined(__gfx950__)
using WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>,
1>>;
using WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
1>>;
#endif
using WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
#if defined(__gfx950__)
using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
#else
using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
#endif
using WarpGemmMfmaF16F16F32M4N64K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
using WarpGemmMfmaF16F16F32M4N64K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplF16F16F32M4N64K4<WGAttrCtlEnum::Default_>,
4>>;
using WarpGemmMfmaF16F16F32M64N4K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
using WarpGemmMfmaF16F16F32M64N4K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplF16F16F32M64N4K4<WGAttrCtlEnum::Default_>,
4>>;
@@ -136,19 +136,19 @@ using WarpGemmSmfmacF16F16F32M16N16K32 = WarpGemmSmfmacImpl<WarpGemmAttributeSmf
// bf16
using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaBf16Bf16F32M16N16K16 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>,
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
@@ -157,43 +157,43 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaItera
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
#endif
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAttributeMfmaIterateK_SwizzleA<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
1>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
WarpGemmImpl<WarpGemmAttributeMfmaIterateK_SwizzleA<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
@@ -202,153 +202,153 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
#endif
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
#if defined(__gfx950__)
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
#else
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
#endif
using WarpGemmMfmaBf16Bf16F32M4N64K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
using WarpGemmMfmaBf16Bf16F32M4N64K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4<WGAttrCtlEnum::Default_>,
4>>;
using WarpGemmMfmaBf16Bf16F32M64N4K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
using WarpGemmMfmaBf16Bf16F32M64N4K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4<WGAttrCtlEnum::Default_>,
4>>;
// fp8
using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfma_f32_16x16x32_fp8_fp8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>>>;
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_16x16x32_bf8_bf8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>>>;
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>,
2>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_fp8_fp8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_fp8<WGAttrCtlEnum::Default_>,
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_fp8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_fp8_bf8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_bf8<WGAttrCtlEnum::Default_>,
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_bf8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_bf8_fp8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_fp8<WGAttrCtlEnum::Default_>,
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_fp8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_bf8_bf8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8<WGAttrCtlEnum::Default_>,
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_fp8_fp8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_fp8<WGAttrCtlEnum::Default_>,
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_fp8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_fp8_bf8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_bf8<WGAttrCtlEnum::Default_>,
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_bf8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_bf8_fp8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_fp8<WGAttrCtlEnum::Default_>,
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_fp8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_bf8_bf8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_bf8<WGAttrCtlEnum::Default_>,
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_bf8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
template <index_t swizzle_factor = 2>
using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t, WGAttrCtlEnum::Default_>,
2,
swizzle_factor>>;
// int8
using WarpGemmMfma_i32_32x32x16_i8_i8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<WGAttrCtlEnum::Default_>>>;
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_i32_16x16x32_i8_i8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>>>;
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>>>;
} // namespace ck_tile

View File

@@ -19,7 +19,7 @@ enum class WGAttrNumAccessEnum
template <typename WarpGemmAttributeMfmaImpl_,
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
struct WarpGemmAtrributeMfma
struct WarpGemmAttributeMfma
{
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
static constexpr auto AttrNumAccess = AttrNumAccess_;
@@ -103,7 +103,7 @@ struct WarpGemmAtrributeMfma
template <typename WarpGemmAttributeMfmaImpl_,
index_t kKIter,
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
struct WarpGemmAtrributeMfmaIterateK
struct WarpGemmAttributeMfmaIterateK
{
static_assert(kKIter > 0, "wrong!");
@@ -367,7 +367,7 @@ struct WarpGemmAtrributeMfmaIterateK
template <typename WarpGemmAttributeMfmaImpl_,
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
struct WarpGemmAtrributeMfmaTransposedCDistribution
struct WarpGemmAttributeMfmaTransposedCDistribution
{
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
static constexpr auto AttrNumAccess = AttrNumAccess_;
@@ -450,7 +450,7 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
};
template <typename WarpGemmAttributeMfmaImpl_, index_t SFactor_ = 2>
struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
struct WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB
{
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
@@ -546,7 +546,7 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
template <typename WarpGemmAttributeMfmaImpl_,
index_t kKIter,
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution
{
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
static constexpr auto AttrNumAccess = AttrNumAccess_;
@@ -574,13 +574,13 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
{
return WarpGemmAtrributeMfmaIterateK<Impl, kKIter, AttrNumAccess>::
return WarpGemmAttributeMfmaIterateK<Impl, kKIter, AttrNumAccess>::
get_bwarp_dstr_encoding();
}
CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
{
return WarpGemmAtrributeMfmaIterateK<Impl, kKIter, AttrNumAccess>::
return WarpGemmAttributeMfmaIterateK<Impl, kKIter, AttrNumAccess>::
get_awarp_dstr_encoding();
}
@@ -696,7 +696,7 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
};
template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter, index_t SFactor_ = 2>
struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB
{
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
@@ -840,7 +840,7 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
};
template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter, index_t SFactor_ = 2>
struct WarpGemmAtrributeMfmaIterateK_SwizzleA
struct WarpGemmAttributeMfmaIterateK_SwizzleA
{
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;

View File

@@ -0,0 +1,147 @@
// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/device_prop.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp"
namespace ck_tile {
// TODO: currently only support 16 bit input, which means only support tr16_b128; will use ADataType
// to determine the layout in the future
template <typename Impl>
struct AWarpDstrEncodingTrait
{
using type = tile_distribution_encoding<
sequence<Impl::kRepeat>,
tuple<sequence<Impl::kAMLane>,
sequence<Impl::kABK0PerLane, Impl::kABKLane, Impl::kABK1PerLane>>,
tuple<typename Impl::kABPs2RHssMajor>,
tuple<typename Impl::kABPs2RHssMinor>,
typename Impl::kABYs2RHsMajor,
typename Impl::kABYs2RHsMinor>;
};
template <typename Impl>
struct BWarpDstrEncodingTrait
{
using type = tile_distribution_encoding<
sequence<Impl::kRepeat>,
tuple<sequence<Impl::kBNLane>,
sequence<Impl::kABK0PerLane, Impl::kABKLane, Impl::kABK1PerLane>>,
tuple<typename Impl::kABPs2RHssMajor>,
tuple<typename Impl::kABPs2RHssMinor>,
typename Impl::kABYs2RHsMajor,
typename Impl::kABYs2RHsMinor>;
};
template <typename Impl>
struct CWarpDstrEncodingTrait
{
using type = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
sequence<Impl::kCNLane>>,
tuple<typename Impl::kCPs2RHssMajor>,
tuple<typename Impl::kCPs2RHssMinor>,
typename Impl::kCYs2RHsMajor,
typename Impl::kCYs2RHsMinor>;
};
template <typename WarpGemmAttributeWmmaImpl_, bool kTransC = false>
struct WarpGemmAttributeWmma
{
using Impl = remove_cvref_t<WarpGemmAttributeWmmaImpl_>;
using ADataType = typename Impl::ADataType;
using BDataType = typename Impl::BDataType;
using CDataType = typename Impl::CDataType;
using AVecType = typename Impl::AVecType;
using BVecType = typename Impl::BVecType;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kM;
static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK;
static constexpr index_t kKPerThread = Impl::kABK0PerLane * Impl::kABK1PerLane;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
// 16 bit input, kAMLane = 16, kABK0PerLane = 4, kABKLane = 2, kABK1PerLane = 2
// 8 bit input, kAMLane = 16, kABK0PerLane = 2, kABKLane = 2, kABK1PerLane = 4
using AWarpDstrEncoding = typename AWarpDstrEncodingTrait<Impl>::type;
using BWarpDstrEncoding = typename BWarpDstrEncodingTrait<Impl>::type;
// kCM0PerLane = 4, kCMLane = 2, kCM1PerLane = 2, kCNLane = 16 for 16 bit input
// kCM0PerLane = 2, kCMLane = 2, kCM1PerLane = 4, kCNLane = 16 for 8 bit input
using CWarpDstrEncoding = typename CWarpDstrEncodingTrait<Impl>::type;
// c_vec += a_vec * b_vec
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
if constexpr(kTransC)
{
Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
}
else
{
Impl{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
if constexpr(kTransC)
{
return Impl{}(b_vec, a_vec);
}
else
{
return Impl{}(a_vec, b_vec);
}
}
};
template <typename ADataType,
typename BDataType,
typename AccDataType,
index_t M_Warp_Tile,
index_t N_Warp_Tile,
index_t K_Warp_Tile>
CK_TILE_HOST bool check_wmma_supported()
{
if(is_gfx12_supported())
{
return has_wmma_traits_v<gfx12_t,
ADataType,
BDataType,
AccDataType,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile>;
}
else if(is_gfx11_supported())
{
return has_wmma_traits_v<gfx11_t,
ADataType,
BDataType,
AccDataType,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile>;
}
else
{
return false;
}
}
} // namespace ck_tile

View File

@@ -0,0 +1,132 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/config.hpp"
namespace ck_tile {
// Base traits for WMMA operations
template <typename Arch,
typename AType,
typename BType,
typename CType,
index_t M,
index_t N,
index_t K>
struct WmmaTraits;
// Generic WMMA implementation using traits
template <typename Traits>
struct WarpGemmAttributeWmmaImpl
{
using ADataType = typename Traits::ADataType;
using BDataType = typename Traits::BDataType;
using CDataType = typename Traits::CDataType;
using AVecType = typename Traits::AVecType;
using BVecType = typename Traits::BVecType;
using CVecType = typename Traits::CVecType;
// Forward all static constants and type aliases
static constexpr index_t kM = Traits::kM;
static constexpr index_t kN = Traits::kN;
static constexpr index_t kK = Traits::kK;
static constexpr index_t kRepeat = Traits::kRepeat;
static constexpr index_t kAMLane = Traits::kAMLane;
static constexpr index_t kBNLane = Traits::kBNLane;
static constexpr index_t kABK0PerLane = Traits::kABK0PerLane;
static constexpr index_t kABKLane = Traits::kABKLane;
static constexpr index_t kABK1PerLane = Traits::kABK1PerLane;
static constexpr index_t kCMLane = Traits::kCMLane;
static constexpr index_t kCNLane = Traits::kCNLane;
static constexpr index_t kCM0PerLane = Traits::kCM0PerLane;
static constexpr index_t kCM1PerLane = Traits::kCM1PerLane;
using kABPs2RHssMajor = typename Traits::kABPs2RHssMajor;
using kABPs2RHssMinor = typename Traits::kABPs2RHssMinor;
using kABYs2RHsMajor = typename Traits::kABYs2RHsMajor;
using kABYs2RHsMinor = typename Traits::kABYs2RHsMinor;
using kCPs2RHssMajor = typename Traits::kCPs2RHssMajor;
using kCPs2RHssMinor = typename Traits::kCPs2RHssMinor;
using kCYs2RHsMajor = typename Traits::kCYs2RHsMajor;
using kCYs2RHsMinor = typename Traits::kCYs2RHsMinor;
// c_vec += a_vec * b_vec
template <bool clamp = false, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
c_vec = Traits::template wmma_intrinsic<clamp>(a_vec, b_vec, c_vec);
}
// c_vec = a_vec * b_vec
template <bool clamp = false>
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
return bit_cast<CVecType>(
Traits::template wmma_intrinsic<clamp>(a_vec, b_vec, CVecType{0.f}));
}
};
using DeviceIp = remove_cvref_t<decltype(ck_tile::get_device_arch())>;
using WarpGemmAttributeWmmaImpl_f32_16x16x16_f16_f16 =
WarpGemmAttributeWmmaImpl<WmmaTraits<DeviceIp, fp16_t, fp16_t, float, 16, 16, 16>>;
using WarpGemmAttributeWmmaImpl_f32_16x16x16_bf16_bf16 =
WarpGemmAttributeWmmaImpl<WmmaTraits<DeviceIp, bf16_t, bf16_t, float, 16, 16, 16>>;
using WarpGemmAttributeWmmaImpl_i32_16x16x16_i8_i8 =
WarpGemmAttributeWmmaImpl<WmmaTraits<DeviceIp, int8_t, int8_t, int32_t, 16, 16, 16>>;
using WarpGemmAttributeWmmaImpl_f32_16x16x16_f8_f8 =
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx12_t, fp8_t, fp8_t, float, 16, 16, 16>>;
using WarpGemmAttributeWmmaImpl_f32_16x16x16_bf8_bf8 =
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx12_t, bf8_t, bf8_t, float, 16, 16, 16>>;
using WarpGemmAttributeWmmaImpl_f32_16x16x16_f8_bf8 =
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx12_t, fp8_t, bf8_t, float, 16, 16, 16>>;
using WarpGemmAttributeWmmaImpl_f32_16x16x16_bf8_f8 =
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx12_t, bf8_t, fp8_t, float, 16, 16, 16>>;
template <typename Arch,
typename AType,
typename BType,
typename CType,
index_t warp_m,
index_t warp_n,
index_t warp_k>
struct has_wmma_traits
{
template <typename T>
static auto
test(int) -> decltype(std::declval<
typename WmmaTraits<T, AType, BType, CType, warp_m, warp_n, warp_k>::
ADataType>(),
std::true_type{});
template <typename>
static std::false_type test(...);
static constexpr bool value = decltype(test<Arch>(0))::value;
};
template <typename Arch,
typename AType,
typename BType,
typename CType,
index_t warp_m,
index_t warp_n,
index_t warp_k>
constexpr bool has_wmma_traits_v =
has_wmma_traits<Arch, AType, BType, CType, warp_m, warp_n, warp_k>::value;
} // namespace ck_tile

View File

@@ -0,0 +1,87 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "warp_gemm_attribute_wmma_impl_base_traits.hpp"
namespace ck_tile {
// fp16 specialization - GFX11
template <>
struct WmmaTraits<gfx11_t, fp16_t, fp16_t, float, 16, 16, 16>
: WmmaTraitsBase<gfx11_t, fp16_t, fp16_t, float>
{
template <bool clamp = false>
CK_TILE_DEVICE static CVecType
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
{
#ifdef __gfx11__
return __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_vec, b_vec, c_vec);
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = c_vec;
return CVecType{0.f};
#endif
}
};
// bf16 specialization - GFX11
template <>
struct WmmaTraits<gfx11_t, bf16_t, bf16_t, float, 16, 16, 16>
: WmmaTraitsBase<gfx11_t, bf16_t, bf16_t, float>
{
template <bool clamp = false>
CK_TILE_DEVICE static CVecType
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
{
#ifdef __gfx11__
return __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(a_vec, b_vec, c_vec);
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = c_vec;
return CVecType{0.f};
#endif
}
};
// fp16 specialization - GFX12
template <>
struct WmmaTraits<gfx12_t, fp16_t, fp16_t, float, 16, 16, 16>
: WmmaTraitsBase<gfx12_t, fp16_t, fp16_t, float>
{
template <bool clamp = false>
CK_TILE_DEVICE static CVecType
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
{
#ifdef __gfx12__
return __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_vec, b_vec, c_vec);
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = c_vec;
return CVecType{0.f};
#endif
}
};
// bf16 specialization - GFX12
template <>
struct WmmaTraits<gfx12_t, bf16_t, bf16_t, float, 16, 16, 16>
: WmmaTraitsBase<gfx12_t, bf16_t, bf16_t, float>
{
template <bool clamp = false>
CK_TILE_DEVICE static CVecType
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
{
#ifdef __gfx12__
return __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_vec, b_vec, c_vec);
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = c_vec;
return CVecType{0.f};
#endif
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,138 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "warp_gemm_attribute_wmma_impl_base_traits.hpp"
namespace ck_tile {
// int8 specialization - GFX11
template <>
struct WmmaTraits<gfx11_t, int8_t, int8_t, int32_t, 16, 16, 16>
: WmmaTraitsBase<gfx11_t, int8_t, int8_t, int32_t>
{
template <bool clamp = false>
CK_TILE_DEVICE static CVecType
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
{
#ifdef __gfx11__
return __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, // neg_a
bit_cast<int32x4_t>(a_vec),
true, // neg_b
bit_cast<int32x4_t>(b_vec),
bit_cast<int32x8_t>(c_vec),
clamp);
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = c_vec;
return CVecType{0};
#endif
}
};
// int8 specialization - GFX12
template <>
struct WmmaTraits<gfx12_t, int8_t, int8_t, int32_t, 16, 16, 16>
: WmmaTraitsBase<gfx12_t, int8_t, int8_t, int32_t>
{
template <bool clamp = false>
CK_TILE_DEVICE static CVecType
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
{
#ifdef __gfx12__
return __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, // neg_a
bit_cast<int32x2_t>(a_vec),
true, // neg_b
bit_cast<int32x2_t>(b_vec),
bit_cast<int32x8_t>(c_vec),
clamp);
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = c_vec;
return CVecType{0};
#endif
}
};
// fp8/bf8 specialization - GFX12
template <>
struct WmmaTraits<gfx12_t, fp8_t, fp8_t, float, 16, 16, 16>
: WmmaTraitsBase<gfx12_t, fp8_t, fp8_t, float>
{
template <bool clamp = false>
CK_TILE_DEVICE static CVecType
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
{
#ifdef __gfx12__
return __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12(
bit_cast<int32x2_t>(a_vec), bit_cast<int32x2_t>(b_vec), bit_cast<fp32x8_t>(c_vec));
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = c_vec;
return CVecType{0};
#endif
}
};
template <>
struct WmmaTraits<gfx12_t, bf8_t, bf8_t, float, 16, 16, 16>
: WmmaTraitsBase<gfx12_t, bf8_t, bf8_t, float>
{
template <bool clamp = false>
CK_TILE_DEVICE static CVecType
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
{
#ifdef __gfx12__
return __builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12(
bit_cast<int32x2_t>(a_vec), bit_cast<int32x2_t>(b_vec), bit_cast<fp32x8_t>(c_vec));
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = c_vec;
return CVecType{0};
#endif
}
};
template <>
struct WmmaTraits<gfx12_t, fp8_t, bf8_t, float, 16, 16, 16>
: WmmaTraitsBase<gfx12_t, fp8_t, bf8_t, float>
{
template <bool clamp = false>
CK_TILE_DEVICE static CVecType
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
{
#ifdef __gfx12__
return __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12(
bit_cast<int32x2_t>(a_vec), bit_cast<int32x2_t>(b_vec), bit_cast<fp32x8_t>(c_vec));
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = c_vec;
return CVecType{0};
#endif
}
};
template <>
struct WmmaTraits<gfx12_t, bf8_t, fp8_t, float, 16, 16, 16>
: WmmaTraitsBase<gfx12_t, bf8_t, fp8_t, float>
{
template <bool clamp = false>
CK_TILE_DEVICE static CVecType
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
{
#ifdef __gfx12__
return __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12(
bit_cast<int32x2_t>(a_vec), bit_cast<int32x2_t>(b_vec), bit_cast<fp32x8_t>(c_vec));
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
ck_tile::ignore = c_vec;
return CVecType{0};
#endif
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,86 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {
template <typename Arch, typename ADType, typename BDType, typename CDType>
struct WmmaTraitsBase;
// GFX11 specialization
template <typename ADType, typename BDType, typename CDType>
struct WmmaTraitsBase<gfx11_t, ADType, BDType, CDType>
{
using ADataType = ADType;
using BDataType = BDType;
using CDataType = CDType;
using AVecType = ext_vector_t<ADataType, 16>;
using BVecType = ext_vector_t<BDataType, 16>;
using CVecType = ext_vector_t<CDataType, 8>;
static constexpr index_t kM = 16;
static constexpr index_t kN = 16;
static constexpr index_t kK = 16;
static constexpr index_t kRepeat = 2;
static constexpr index_t kAMLane = 16;
static constexpr index_t kBNLane = 16;
static constexpr index_t kABK0PerLane = 1;
static constexpr index_t kABKLane = 1;
static constexpr index_t kABK1PerLane = 16;
static constexpr index_t kCMLane = 2;
static constexpr index_t kCNLane = 16;
static constexpr index_t kCM0PerLane = 8;
static constexpr index_t kCM1PerLane = 1;
using kABPs2RHssMajor = sequence<0, 2, 1>;
using kABPs2RHssMinor = sequence<0, 1, 0>;
using kABYs2RHsMajor = sequence<2, 2>;
using kABYs2RHsMinor = sequence<0, 2>;
using kCPs2RHssMajor = sequence<1, 2>;
using kCPs2RHssMinor = sequence<1, 0>;
using kCYs2RHsMajor = sequence<1, 1>;
using kCYs2RHsMinor = sequence<0, 2>;
};
// GFX12 specialization
template <typename ADType, typename BDType, typename CDType>
struct WmmaTraitsBase<gfx12_t, ADType, BDType, CDType>
{
using ADataType = ADType;
using BDataType = BDType;
using CDataType = CDType;
using AVecType = ext_vector_t<ADataType, 8>;
using BVecType = ext_vector_t<BDataType, 8>;
using CVecType = ext_vector_t<CDataType, 8>;
static constexpr index_t kM = 16;
static constexpr index_t kN = 16;
static constexpr index_t kK = 16;
static constexpr index_t kRepeat = 1;
static constexpr index_t kAMLane = 16;
static constexpr index_t kBNLane = 16;
static constexpr index_t kABK0PerLane = 2;
static constexpr index_t kABKLane = 2;
static constexpr index_t kABK1PerLane = 4;
static constexpr index_t kCMLane = 2;
static constexpr index_t kCNLane = 16;
static constexpr index_t kCM0PerLane = 1;
static constexpr index_t kCM1PerLane = 8;
using kABPs2RHssMajor = sequence<2, 1>;
using kABPs2RHssMinor = sequence<1, 0>;
using kABYs2RHsMajor = sequence<2, 2>;
using kABYs2RHsMinor = sequence<0, 2>;
using kCPs2RHssMajor = sequence<1, 2>;
using kCPs2RHssMinor = sequence<1, 0>;
using kCYs2RHsMajor = sequence<1, 1>;
using kCYs2RHsMinor = sequence<0, 2>;
};
} // namespace ck_tile

View File

@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp"
namespace ck_tile {
@@ -19,115 +20,133 @@ template <typename AType,
bool SwizzleA = false,
bool UseStructuredSparsity = false,
WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
struct WarpGemmMfmaDispatcher;
struct WarpGemmDispatcher;
// clang-format off
// fp16
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, false, false, WGAttrNumAccessEnum::Double> {
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; };
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16<>; };
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<>; };
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, false, false, WGAttrNumAccessEnum::Double> {
using Type = WarpGemmMfmaF16F16F32M32N32K16<WGAttrNumAccessEnum::Double>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true, false, false, WGAttrNumAccessEnum::Double> {
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true, false, false, WGAttrNumAccessEnum::Double> {
using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<WGAttrNumAccessEnum::Double>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF16F16F32M16N16K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false, false, false, WGAttrNumAccessEnum::Double> {
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32<>; };
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>; };
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false, false, false, WGAttrNumAccessEnum::Double> {
using Type = WarpGemmMfmaF16F16F32M16N16K32<WGAttrNumAccessEnum::Double>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true, false, false, WGAttrNumAccessEnum::Double> {
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true, false, false, WGAttrNumAccessEnum::Double> {
using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<WGAttrNumAccessEnum::Double>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 4, 64, 16, false> { using Type = WarpGemmMfmaF16F16F32M4N64K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 64, 4, 16, false> { using Type = WarpGemmMfmaF16F16F32M64N4K16; };
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 4, 64, 16, false> { using Type = WarpGemmMfmaF16F16F32M4N64K16; };
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 64, 4, 16, false> { using Type = WarpGemmMfmaF16F16F32M64N4K16; };
// WMMA cases
#if defined(__gfx11__) || defined(__gfx12__)
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_f32_16x16x16_f16_f16<TransposeC>;};
#else
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF16F16F32M16N16K16; };
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
#endif
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution; };
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution; };
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution; };
// fp16 2:4 structural sparsity
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, false, true> { using Type = WarpGemmSmfmacF16F16F32M32N32K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false, false, true> { using Type = WarpGemmSmfmacF16F16F32M16N16K32; };
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, false, true> { using Type = WarpGemmSmfmacF16F16F32M32N32K16; };
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false, false, true> { using Type = WarpGemmSmfmacF16F16F32M16N16K32; };
// bf16
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, false, false, WGAttrNumAccessEnum::Double> {
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16<>; };
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<>; };
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, false, false, WGAttrNumAccessEnum::Double> {
using Type = WarpGemmMfmaBf16Bf16F32M32N32K16<WGAttrNumAccessEnum::Double>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true, false, false, WGAttrNumAccessEnum::Double> {
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true, false, false, WGAttrNumAccessEnum::Double> {
using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<WGAttrNumAccessEnum::Double>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false, false, false, WGAttrNumAccessEnum::Double> {
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<>; };
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<>; };
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false, false, false, WGAttrNumAccessEnum::Double> {
using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<WGAttrNumAccessEnum::Double>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true, false, false, WGAttrNumAccessEnum::Double> {
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true, false, false, WGAttrNumAccessEnum::Double> {
using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<WGAttrNumAccessEnum::Double>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 4, 64, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M4N64K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 64, 4, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M64N4K16; };
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 4, 64, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M4N64K16; };
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 64, 4, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M64N4K16; };
// WMMA cases
#if defined(__gfx11__) || defined(__gfx12__)
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_f32_16x16x16_bf16_bf16<TransposeC>; };
#else
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; };
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; };
#endif
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution; };
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution; };
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution; };
// fp8
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; };
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; };
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; };
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8; };
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; };
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; };
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; };
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; };
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; };
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8; };
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8; };
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; };
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8<>; };
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8<>; };
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8<>; };
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8<>; };
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; };
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<>; };
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<>; };
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<>; };
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<WGAttrNumAccessEnum::Quad>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<WGAttrNumAccessEnum::Quad>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<WGAttrNumAccessEnum::Quad>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<WGAttrNumAccessEnum::Quad>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 128, false, false, false, WGAttrNumAccessEnum::Quad> {
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 128, false, false, false, WGAttrNumAccessEnum::Quad> {
using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8<WGAttrNumAccessEnum::Quad>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 16, 16, 128, false, false, false, WGAttrNumAccessEnum::Quad> {
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 16, 16, 128, false, false, false, WGAttrNumAccessEnum::Quad> {
using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8<WGAttrNumAccessEnum::Quad>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 16, 16, 128, false, false, false, WGAttrNumAccessEnum::Quad> {
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 16, 16, 128, false, false, false, WGAttrNumAccessEnum::Quad> {
using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8<WGAttrNumAccessEnum::Quad>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 128, false, false, false, WGAttrNumAccessEnum::Quad> {
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 128, false, false, false, WGAttrNumAccessEnum::Quad> {
using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8<WGAttrNumAccessEnum::Quad>; };
//WMMA cases
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 16, TransposeC, false> { using Type =WarpGemmWmma_f32_16x16x16_f8_f8<TransposeC>; };
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 16, TransposeC, false> { using Type =WarpGemmWmma_f32_16x16x16_bf8_bf8<TransposeC>; };
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 16, 16, 16, TransposeC, false> { using Type =WarpGemmWmma_f32_16x16x16_f8_bf8<TransposeC>; };
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 16, 16, 16, TransposeC, false> { using Type =WarpGemmWmma_f32_16x16x16_bf8_f8<TransposeC>; };
// int8
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 16, false> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 16, true> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, false> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, true> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed; };
template<> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 16, false> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8; };
template<> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 16, true> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed; };
template<> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, false> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8; };
template<> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, true> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed; };
// WMMA cases
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, int32_t, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_i32_16x16x16_i8_i8<TransposeC>;};
// clang-format on
} // namespace impl
@@ -142,15 +161,15 @@ template <typename AType,
bool SwizzleA = false,
bool UseStructuredSparsity = false,
WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaDispatcher = typename impl::WarpGemmMfmaDispatcher<AType,
BType,
AccType,
MPerWave,
NPerWave,
KPerWave,
TransposeC,
SwizzleA,
UseStructuredSparsity,
AttrNumAccess>::Type;
using WarpGemmDispatcher = typename impl::WarpGemmDispatcher<AType,
BType,
AccType,
MPerWave,
NPerWave,
KPerWave,
TransposeC,
SwizzleA,
UseStructuredSparsity,
AttrNumAccess>::Type;
} // namespace ck_tile

View File

@@ -0,0 +1,37 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp"
namespace ck_tile {
template <bool kTransC = false>
using WarpGemmWmma_f32_16x16x16_f16_f16 =
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x16_f16_f16, kTransC>>;
template <bool kTransC = false>
using WarpGemmWmma_f32_16x16x16_bf16_bf16 =
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x16_bf16_bf16, kTransC>>;
template <bool kTransC = false>
using WarpGemmWmma_i32_16x16x16_i8_i8 =
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_i32_16x16x16_i8_i8, kTransC>>;
template <bool kTransC = false>
using WarpGemmWmma_f32_16x16x16_f8_f8 =
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x16_f8_f8, kTransC>>;
template <bool kTransC = false>
using WarpGemmWmma_f32_16x16x16_bf8_bf8 =
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x16_bf8_bf8, kTransC>>;
template <bool kTransC = false>
using WarpGemmWmma_f32_16x16x16_f8_bf8 =
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x16_f8_bf8, kTransC>>;
template <bool kTransC = false>
using WarpGemmWmma_f32_16x16x16_bf8_f8 =
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x16_bf8_f8, kTransC>>;
} // namespace ck_tile

View File

@@ -44,13 +44,13 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
constexpr bool Preshuffle = Problem::Traits::Preshuffle;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
false>;
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
false>;
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
if constexpr(Preshuffle)
@@ -92,13 +92,13 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
static_assert(Problem::kQuantGroupSize % WarpTile::at(I2) == 0,
"KPerWarpGemm must be a multiple of kQuantGroupSize!");
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
false>;
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
false>;
static_assert(std::is_same_v<typename Problem::ComputeDataType, fp8_t> ||
std::is_same_v<typename Problem::ComputeDataType, bf8_t>);
static_assert(std::is_same_v<typename Problem::CDataType, float>);

View File

@@ -30,6 +30,14 @@ if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
target_compile_options(test_ck_tile_gemm_pipeline_basic_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_test_executable(test_ck_tile_gemm_pipeline_basic_bf8 test_gemm_pipeline_basic_bf8.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_basic_bf8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
elseif(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12")
# On Radeon devices, build the WMMA version instead
add_gtest_executable(test_ck_tile_gemm_pipeline_mem_wmma test_gemm_pipeline_mem_wmma.cpp)
add_gtest_executable(test_ck_tile_gemm_pipeline_compv3_wmma test_gemm_pipeline_compv3_wmma.cpp)
add_gtest_executable(test_ck_tile_gemm_pipeline_compv4_wmma test_gemm_pipeline_compv4_wmma.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_mem_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(test_ck_tile_gemm_pipeline_compv3_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(test_ck_tile_gemm_pipeline_compv4_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS})
else()
message(DEBUG "Skipping ck_tile_gemm tests for current target")
endif()
@@ -46,4 +54,7 @@ if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95" OR GPU_TARGETS MAT
target_compile_options(test_ck_tile_gemm_pipeline_basic_fp16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_test_executable(test_ck_tile_gemm_pipeline_basic_bf16 test_gemm_pipeline_basic_bf16.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_basic_bf16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
elseif(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12")
add_gtest_executable(test_ck_tile_gemm_pipeline_persistent_wmma test_gemm_pipeline_persistent_wmma.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_persistent_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()

View File

@@ -3,7 +3,8 @@
#include "gtest/gtest.h"
template <typename T>
class TestCkTileGemmPipelineCompV3 : public TestCkTileGemmPipeline<T>
class TestCkTileGemmPipelineCompV3
: public TestCkTileGemmPipeline<T, TestCkTileGemmPipelineCompV3<T>>
{
};

View File

@@ -0,0 +1,17 @@
#include "test_gemm_pipeline_kernel_types.hpp"
#include "test_gemm_pipeline_wmma_base.hpp"
#include "gtest/gtest.h"
template <typename T>
class TestCkTileGemmPipelineCompV3Wmma
: public TestCkTileGemmPipelineWmmaBase<T, TestCkTileGemmPipelineCompV3Wmma<T>>
{
};
#define TEST_SUITE_NAME TestCkTileGemmPipelineCompV3Wmma
TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV3Wmma, KernelTypesCompV3Wmma);
#include "test_gemm_pipeline_ut_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -3,7 +3,8 @@
#include "gtest/gtest.h"
template <typename T>
class TestCkTileGemmPipelineCompV4 : public TestCkTileGemmPipeline<T>
class TestCkTileGemmPipelineCompV4
: public TestCkTileGemmPipeline<T, TestCkTileGemmPipelineCompV4<T>>
{
};

View File

@@ -0,0 +1,17 @@
#include "test_gemm_pipeline_kernel_types.hpp"
#include "test_gemm_pipeline_wmma_base.hpp"
#include "gtest/gtest.h"
template <typename T>
class TestCkTileGemmPipelineCompV4Wmma
: public TestCkTileGemmPipelineWmmaBase<T, TestCkTileGemmPipelineCompV4Wmma<T>>
{
};
#define TEST_SUITE_NAME TestCkTileGemmPipelineCompV4Wmma
TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV4Wmma, KernelTypesCompV4Wmma);
#include "test_gemm_pipeline_ut_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -9,13 +9,16 @@
#include "ck_tile/host.hpp"
#include "test_gemm_pipeline_util.hpp"
using I8 = ck_tile::int8_t;
using I32 = ck_tile::int32_t;
using INT8 = ck_tile::int8_t;
using INT32 = ck_tile::int32_t;
using F16 = ck_tile::half_t;
using F32 = float;
using F8 = ck_tile::fp8_t;
using BF16 = ck_tile::bf16_t;
using BF8 = ck_tile::bf8_t;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
@@ -30,52 +33,119 @@ using CompV4 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Co
using Persistent = std::true_type;
using NonPersistent = std::false_type;
using I16 = ck_tile::number<16>;
using I32 = ck_tile::number<32>;
using I64 = ck_tile::number<64>;
using I256 = ck_tile::number<256>;
// clang-format off
using KernelTypesMem = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Row, Row, Row, F8, F8, F32, F16, Interwave, Mem>,
std::tuple< Row, Row, Row, F8, F8, F32, F16, Intrawave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, Interwave, Mem>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, Intrawave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, Intrawave, Mem>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, Interwave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, Intrawave, Mem>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, Interwave, Mem>
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, M_BlockSize, N_BlockSize, K_BlockSize, M_TileSize, M_TileSize, K_TileSize, Scheduler, PipelineType
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>,
std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>,
std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>
>;
using KernelTypesMemWmma = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
std::tuple< Row, Row, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
std::tuple< Row, Row, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>
>;
using KernelTypesCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
std::tuple< Row, Row, Row, F8, F8, F32, F16, Intrawave, CompV3>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, Intrawave, CompV3>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, Intrawave, CompV3>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, Intrawave, CompV3>,
std::tuple< Row, Row, Row, I8, I8, I32, I32, Intrawave, CompV3>,
std::tuple< Row, Col, Row, I8, I8, I32, I32, Intrawave, CompV3>,
std::tuple< Col, Row, Row, I8, I8, I32, I32, Intrawave, CompV3>,
std::tuple< Col, Col, Row, I8, I8, I32, I32, Intrawave, CompV3>
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
std::tuple< Row, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>
>;
using KernelTypesCompV3Wmma = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Row, Row, Row, BF16, BF16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Row, Row, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Row, Col, Row, BF16, BF16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Row, Col, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Row, Col, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Col, Row, Row, BF16, BF16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Col, Row, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Col, Col, Row, BF16, BF16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
std::tuple< Col, Col, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>
>;
using KernelTypesCompV4 = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>
>;
using KernelTypesCompV4Wmma = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>
>;
using KernelTypesPersistent = ::testing::Types<
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3, Persistent>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3, NonPersistent>
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, M_BlockSize, N_BlockSize, K_BlockSize, M_TileSize, M_TileSize, K_TileSize, Scheduler, PipelineType, Persistent
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3, Persistent>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3, NonPersistent>
>;
using KernelTypesPersistentWmma = ::testing::Types<
std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3, Persistent>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3, NonPersistent>
>;
// clang-format on

View File

@@ -3,7 +3,7 @@
#include "gtest/gtest.h"
template <typename T>
class TestCkTileGemmPipelineMem : public TestCkTileGemmPipeline<T>
class TestCkTileGemmPipelineMem : public TestCkTileGemmPipeline<T, TestCkTileGemmPipelineMem<T>>
{
};

View File

@@ -0,0 +1,17 @@
#include "test_gemm_pipeline_kernel_types.hpp"
#include "test_gemm_pipeline_wmma_base.hpp"
#include "gtest/gtest.h"
template <typename T>
class TestCkTileGemmPipelineMemWmma
: public TestCkTileGemmPipelineWmmaBase<T, TestCkTileGemmPipelineMemWmma<T>>
{
};
#define TEST_SUITE_NAME TestCkTileGemmPipelineMemWmma
TYPED_TEST_SUITE(TestCkTileGemmPipelineMemWmma, KernelTypesMemWmma);
#include "test_gemm_pipeline_ut_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -3,7 +3,8 @@
#include "gtest/gtest.h"
template <typename T>
class TestCkTileGemmPipelinePersistent : public TestCkTileGemmPipeline<T>
class TestCkTileGemmPipelinePersistent
: public TestCkTileGemmPipeline<T, TestCkTileGemmPipelinePersistent<T>>
{
};

View File

@@ -0,0 +1,17 @@
#include "test_gemm_pipeline_kernel_types.hpp"
#include "test_gemm_pipeline_wmma_base.hpp"
#include "gtest/gtest.h"
template <typename T>
class TestCkTileGemmPipelinePersistentWmma
: public TestCkTileGemmPipelineWmmaBase<T, TestCkTileGemmPipelinePersistentWmma<T>>
{
};
#define TEST_SUITE_NAME TestCkTileGemmPipelinePersistentWmma
TYPED_TEST_SUITE(TEST_SUITE_NAME, KernelTypesPersistentWmma);
#include "test_gemm_pipeline_ut_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -69,7 +69,7 @@ struct GemmPipelineTypeSelector<GemmPipelineType::CompV4, Problem>
static constexpr auto GetName() { return "GemmPipelineAgBgCrCompV4"; }
};
template <typename Tuple>
template <typename Tuple, typename Derived>
class TestCkTileGemmPipeline : public ::testing::Test
{
protected:
@@ -80,32 +80,30 @@ class TestCkTileGemmPipeline : public ::testing::Test
using BDataType = std::tuple_element_t<4, Tuple>;
using AccDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>;
static constexpr auto Scheduler = std::tuple_element_t<7, Tuple>::value;
static constexpr auto PipelineType = std::tuple_element_t<8, Tuple>::value;
static constexpr auto Scheduler = std::tuple_element_t<13, Tuple>::value;
static constexpr auto PipelineType = std::tuple_element_t<14, Tuple>::value;
static constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, Tuple>{};
static constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, Tuple>{};
static constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, Tuple>{};
static constexpr ck_tile::index_t M_Warp_Tile = std::tuple_element_t<10, Tuple>{};
static constexpr ck_tile::index_t N_Warp_Tile = std::tuple_element_t<11, Tuple>{};
static constexpr ck_tile::index_t K_Warp_Tile = std::tuple_element_t<12, Tuple>{};
using DsLayout = ck_tile::tuple<>;
using DsDataType = ck_tile::tuple<>;
static constexpr bool Persistent =
ck_tile::tuple_element_or_default_t<Tuple, 9, std::false_type>::value;
// TODO: expose tile size through test t-param ?
ck_tile::tuple_element_or_default_t<Tuple, 15, std::false_type>::value;
template <bool PadM, bool PadN, bool PadK, bool Preshuffle>
void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
// TODO: This should be parameterized in tests
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = (PipelineType == GemmPipelineType::CompV4) ? 32 : 64;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr bool kPadM = PadM;
constexpr bool kPadN = PadN;
constexpr bool kPadK = PadK;
@@ -247,11 +245,48 @@ class TestCkTileGemmPipeline : public ::testing::Test
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
}
template <typename ADataType,
typename BDataType,
typename AccDataType,
ck_tile::index_t M_Warp_Tile,
ck_tile::index_t N_Warp_Tile,
ck_tile::index_t K_Warp_Tile>
bool check_data_type()
{
return static_cast<Derived*>(this)
->template check_data_type_impl<ADataType,
BDataType,
AccDataType,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile>();
}
template <typename ADataType,
typename BDataType,
typename AccDataType,
ck_tile::index_t M_Warp_Tile,
ck_tile::index_t N_Warp_Tile,
ck_tile::index_t K_Warp_Tile>
bool check_data_type_impl()
{
return true;
}
public:
std::vector<int> k_batches_;
void SetUp() override
{
if(!check_data_type<ADataType,
BDataType,
AccDataType,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile>())
{
GTEST_SKIP() << "Unsupported data type combination for gemm pipeline test.";
}
if constexpr(PipelineType == GemmPipelineType::CompV4)
{
// Only do k_batch = 1 when pipeline is CompV4

View File

@@ -0,0 +1,24 @@
#pragma once
#include "test_gemm_pipeline_util.hpp"
template <typename Tuple, typename Derived>
class TestCkTileGemmPipelineWmmaBase : public TestCkTileGemmPipeline<Tuple, Derived>
{
public:
template <typename ADataType,
typename BDataType,
typename AccDataType,
ck_tile::index_t M_Warp_Tile,
ck_tile::index_t N_Warp_Tile,
ck_tile::index_t K_Warp_Tile>
bool check_data_type_impl()
{
return ck_tile::check_wmma_supported<ADataType,
BDataType,
AccDataType,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile>();
}
};

View File